Coverage for src / bioimageio / core / weight_converters / torchscript_to_onnx.py: 95%
19 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from pathlib import Path
2from typing import Optional, Sequence, Union
4import torch.jit
5from torch._export.converter import TS2EPConverter
7from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
9from ..backends.pytorch_backend import get_devices
10from ._utils_torch_onnx import export_to_onnx, get_torch_sample_inputs
def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> OnnxWeightsDescr:
    """Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript weights are loaded with ``torch.jit.load``, moved to the
    first resolved device and put into eval mode. They are then converted to
    an exported program with ``TS2EPConverter``, using the model's sample
    inputs as example arguments. The result is finally exported to ONNX.

    Args:
        model_descr:
            The model description whose ``weights.torchscript`` entry is
            converted.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, print detailed information during the ONNX export
            process. Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 18.
        devices:
            Candidate devices, resolved via ``get_devices``. The model and
            the sample inputs are placed on the first resolved device.
            Defaults to None, in which case device selection is left to
            ``get_devices``.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript
            format.

    Returns:
        A description of the exported ONNX weights, recorded with
        ``parent="torchscript"``.
    """
    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    weight_reader = torchscript_descr.get_reader()
    model = torch.jit.load(weight_reader)  # pyright: ignore[reportUnknownVariableType]

    # Resolve the device list once; the model and its sample inputs must live
    # on the same device for the conversion below.
    device = get_devices(devices)[0]
    model.to(device)
    model = model.eval()  # pyright: ignore[reportUnknownVariableType]

    torch_sample_inputs = tuple(
        t.to(device=device) for t in get_torch_sample_inputs(model_descr)
    )
    # Convert the TorchScript module into an exported program, so the shared
    # ONNX export helper can operate on a regular (non-scripted) module.
    exported_program = TS2EPConverter(
        model,  # pyright: ignore[reportUnknownArgumentType]
        torch_sample_inputs,
    ).convert()

    return export_to_onnx(
        model_descr,
        exported_program.module(),
        output_path,
        verbose,
        opset_version,
        parent="torchscript",
    )