# bioimageio/core/weight_converters/torchscript_to_onnx.py
from pathlib import Path

import torch.jit

from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
from bioimageio.spec.utils import download

from .. import __version__
from ..digest_spec import get_member_id, get_test_inputs
from ..proc_setup import get_pre_and_postprocessing
from ._utils_onnx import get_dynamic_axes


def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 15,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    Args:
        model_descr (ModelDescr):
            The model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            If True, print detailed information during the ONNX export process.
            Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 15.

    Raises:
        ValueError:
            If the provided model does not have weights in the TorchScript format.

    Returns:
        OnnxWeightsDescr:
            A descriptor object that contains information about the exported ONNX weights.
    """
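
    # Ensure the model description actually provides TorchScript weights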
    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )
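
    # Build the model's test inputs and apply the described preprocessing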
    sample = get_test_inputs(model_descr)
    procs = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[sample]
    )
    procs.pre(sample)
    inputs_numpy = [
        sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
    ]
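    # Convert the preprocessed numpy arrays into torch tensors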
    inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
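
    # Download the TorchScript weights and load the scripted model on CPU in eval mode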
    weight_path = download(torchscript_descr).path
    model = torch.jit.load(weight_path)  # type: ignore
    model.to("cpu")
    model = model.eval()  # type: ignore
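
    # Run the TorchScript model once on the test inputs; wrap a single output tensor in a list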
    with torch.no_grad():
        outputs_original_torch = model(*inputs_torch)  # type: ignore
        if isinstance(outputs_original_torch, torch.Tensor):
            outputs_original_torch = [outputs_original_torch]
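
    # Export to ONNX, reusing the bioimage.io tensor ids as ONNX input/output names
    # and passing the dynamic axes derived from the model description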
    _ = torch.onnx.export(
        model,  # type: ignore
        tuple(inputs_torch),
        str(output_path),
        input_names=[str(d.id) for d in model_descr.inputs],
        output_names=[str(d.id) for d in model_descr.outputs],
        dynamic_axes=get_dynamic_axes(model_descr),
        verbose=verbose,
        opset_version=opset_version,
    )
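
    # Describe the exported ONNX weights, recording the TorchScript weights as their parent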
    return OnnxWeightsDescr(
        source=output_path,
        parent="torchscript",
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}.",
    )
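
# A minimal usage sketch (kept as a comment; not executed on import). The resource
# path "path/to/rdf.yaml" and the output file name below are placeholders, not values
# taken from this module; it assumes a v0.5 model description with TorchScript weights.
#
#     from pathlib import Path
#
#     from bioimageio.spec import load_description
#     from bioimageio.spec.model.v0_5 import ModelDescr
#
#     from bioimageio.core.weight_converters.torchscript_to_onnx import convert
#
#     model_descr = load_description("path/to/rdf.yaml")
#     assert isinstance(model_descr, ModelDescr)  # convert expects a v0.5 model description
#     onnx_descr = convert(model_descr, Path("converted_weights.onnx"), opset_version=15)
#     print(onnx_descr.source)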