Coverage for src / bioimageio / core / weight_converters / pytorch_to_onnx.py: 91%
11 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from pathlib import Path
2from typing import Optional, Sequence
4from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
6from ..backends.pytorch_backend import load_torch_model
7from ._utils_torch_onnx import export_to_onnx
def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
    devices: Optional[Sequence[str]] = None,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the PyTorch state_dict format to the ONNX format.

    Args:
        model_descr:
            The model description object that contains the model and its weights.
            Its `weights.pytorch_state_dict` entry is used as the conversion source.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 18.
        devices:
            Optional sequence of device names, forwarded unchanged to
            `load_torch_model` when loading the PyTorch model. Defaults to None.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch state_dict format.

    Returns:
        A description of the exported ONNX weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # Load the architecture with its state_dict applied, then switch to eval mode
    # so the export captures inference behaviour (e.g. dropout / batch-norm).
    model = load_torch_model(
        state_dict_weights_descr, load_state=True, devices=devices
    ).eval()

    # `parent` names the weights format this ONNX entry was derived from
    # (NOTE(review): presumably recorded in the returned OnnxWeightsDescr — confirm
    # against export_to_onnx).
    return export_to_onnx(
        model_descr,
        model,
        output_path,
        verbose,
        opset_version,
        parent="pytorch_state_dict",
    )