Coverage for src/bioimageio/core/weight_converters/torchscript_to_onnx.py: 0%
25 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from pathlib import Path
2from typing import Optional, Sequence, Union
4import torch.jit
5from loguru import logger
6from torch._export.converter import TS2EPConverter
8from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
10from ..backends.pytorch_backend import get_devices
11from ._utils_torch_onnx import export_to_onnx, get_torch_sample_inputs
def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> OnnxWeightsDescr:
    """
    Convert the torchscript weights of a model to the ONNX format.

    The torchscript module is loaded, moved to the first resolved device, put
    into evaluation mode, converted to an exported program with
    `TS2EPConverter` using the model's sample inputs, and finally handed to
    `export_to_onnx`.

    Args:
        model_descr (ModelDescr):
            The model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 18.
        devices (Optional[Sequence[Union[str, torch.device]]], optional):
            Candidate devices, resolved through `get_devices`; only the first
            resolved device is used for the model and the sample inputs.
            Defaults to None (device selection is then left to `get_devices`).

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        A description of the exported ONNX weights.
    """
    ts_weights = model_descr.weights.torchscript
    if ts_weights is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    reader = ts_weights.get_reader()
    scripted = torch.jit.load(reader)  # pyright: ignore[reportUnknownVariableType]

    # Only the first resolved device is used for both the module and its inputs.
    target_device = get_devices(devices)[0]
    scripted.to(target_device)
    scripted = scripted.eval()  # pyright: ignore[reportUnknownVariableType]

    sample_inputs = tuple(
        tensor.to(device=target_device)
        for tensor in get_torch_sample_inputs(model_descr)
    )

    graph_module = (
        TS2EPConverter(
            scripted,  # pyright: ignore[reportUnknownArgumentType]
            sample_inputs,
        )
        .convert()
        .module()
    )

    # Best effort: a failure to switch the converted module to evaluation mode
    # is logged and the export proceeds with the module as-is.
    try:
        graph_module = graph_module.eval()
    except Exception as err:
        logger.warning(
            "Failed to set TS2EPConverter program to evaluation mode: {}", err
        )

    return export_to_onnx(
        model_descr,
        graph_module,
        output_path,
        verbose,
        opset_version,
        parent="torchscript",
    )