Coverage for src/bioimageio/core/weight_converters/torchscript_to_onnx.py: 94%
16 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1from pathlib import Path
3import torch.jit
4from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
5from torch._export.converter import TS2EPConverter
7from ._utils_torch_onnx import export_to_onnx, get_torch_sample_inputs
def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript module is loaded on the CPU, switched to eval mode,
    converted to an exported program with `TS2EPConverter` and then handed to
    `export_to_onnx`.

    Args:
        model_descr (ModelDescr):
            The model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 18.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        A description of the exported ONNX weights.
    """
    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    weight_reader = torchscript_descr.get_reader()
    # Map all storages to the CPU while loading: without `map_location`,
    # torch.jit.load restores tensors to the device they were saved from and
    # raises on machines that lack that device (e.g. a CUDA-saved model on a
    # CPU-only host). This also makes a separate `.to("cpu")` unnecessary.
    model = torch.jit.load(weight_reader, map_location="cpu")  # pyright: ignore[reportUnknownVariableType]
    model = model.eval()  # pyright: ignore[reportUnknownVariableType]

    torch_sample_inputs = get_torch_sample_inputs(model_descr)
    exported_program = TS2EPConverter(
        model,  # pyright: ignore[reportUnknownArgumentType]
        torch_sample_inputs,
    ).convert()

    return export_to_onnx(
        model_descr,
        exported_program.module(),
        output_path,
        verbose,
        opset_version,
        parent="torchscript",
    )