Coverage for bioimageio/core/weight_converters/torchscript_to_onnx.py: 0% (26 statements)
from pathlib import Path

import torch.jit

from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr

from .. import __version__
from ..digest_spec import get_member_id, get_test_inputs
from ..proc_setup import get_pre_and_postprocessing
from ._utils_onnx import get_dynamic_axes


def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 15,
) -> OnnxWeightsDescr:
20 """
21 Convert model weights from the PyTorch state_dict format to the ONNX format.
23 Args:
24 model_descr (Union[v0_4.ModelDescr, v0_5.ModelDescr]):
25 The model description object that contains the model and its weights.
26 output_path (Path):
27 The file path where the ONNX model will be saved.
28 verbose (bool, optional):
29 If True, will print out detailed information during the ONNX export process. Defaults to False.
30 opset_version (int, optional):
31 The ONNX opset version to use for the export. Defaults to 15.
32 Raises:
33 ValueError:
34 If the provided model does not have weights in the torchscript format.
36 Returns:
37 v0_5.OnnxWeightsDescr:
38 A descriptor object that contains information about the exported ONNX weights.
39 """
    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    # get the model's test inputs, apply its preprocessing, and convert them to torch tensors
    sample = get_test_inputs(model_descr)
    procs = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[sample]
    )
    procs.pre(sample)
    inputs_numpy = [
        sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
    ]
    inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
    # load the TorchScript weights and put the model into evaluation mode on the CPU
    weight_reader = torchscript_descr.get_reader()
    model = torch.jit.load(weight_reader)  # type: ignore
    model.to("cpu")
    model = model.eval()  # type: ignore

    with torch.no_grad():
        # run a forward pass on the preprocessed test inputs, then export to ONNX
        outputs_original_torch = model(*inputs_torch)  # type: ignore
        if isinstance(outputs_original_torch, torch.Tensor):
            outputs_original_torch = [outputs_original_torch]

        _ = torch.onnx.export(
            model,  # type: ignore
            tuple(inputs_torch),
            str(output_path),
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            dynamic_axes=get_dynamic_axes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    return OnnxWeightsDescr(
        source=output_path,
        parent="torchscript",
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}.",
    )
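

# Usage sketch (illustrative, not part of the covered module): converting a model's
# TorchScript weights to ONNX via `convert`. The model description source and the
# output filename below are placeholder assumptions; `load_model_description` is
# assumed to be available from bioimageio.spec.
if __name__ == "__main__":
    from bioimageio.spec import load_model_description

    descr = load_model_description("path/to/rdf.yaml")  # hypothetical model description
    assert isinstance(descr, ModelDescr), "convert expects a v0.5 ModelDescr"
    onnx_descr = convert(descr, Path("model.onnx"), opset_version=15)
    print(f"exported {onnx_descr.source} ({onnx_descr.comment})")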