Coverage for bioimageio/core/weight_converters/pytorch_to_onnx.py: 96%
24 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from pathlib import Path
3import torch
5from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
7from .. import __version__
8from ..backends.pytorch_backend import load_torch_model
9from ..digest_spec import get_member_id, get_test_inputs
10from ..proc_setup import get_pre_and_postprocessing
11from ._utils_onnx import get_dynamic_axes
def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 15,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the PyTorch state_dict format to the ONNX format.

    The network is rebuilt from the model's `pytorch_state_dict` weights entry.
    The model's test inputs, after the model's preprocessing has been applied,
    serve as example inputs for `torch.onnx.export`. The exported graph is
    written to `output_path`.

    Args:
        model_descr:
            The model description object that contains the model and its weights.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, will print out detailed information during the ONNX export
            process. Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 15.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch
            state_dict format.

    Returns:
        A descriptor object that contains information about the exported ONNX
        weights (parent: "pytorch_state_dict").
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # Example inputs for export: the model's test sample with the model's
    # preprocessing applied. Preprocessing statistics are initialized from
    # this single sample.
    sample = get_test_inputs(model_descr)
    procs = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[sample]
    )
    # `procs.pre` is called for its effect on `sample`; its return value is
    # not used.
    procs.pre(sample)
    inputs_torch = [
        torch.from_numpy(sample.members[get_member_id(ipt)].data.data)
        for ipt in model_descr.inputs
    ]

    model = load_torch_model(state_dict_weights_descr, load_state=True)
    # Export in inference mode. No separate forward pass is run beforehand:
    # its outputs would be unused, and in training mode such a pass can update
    # buffers (e.g. BatchNorm running statistics) even under `no_grad`,
    # altering the weights that get exported.
    model.eval()
    with torch.no_grad():
        _ = torch.onnx.export(
            model,
            tuple(inputs_torch),
            str(output_path),
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            dynamic_axes=get_dynamic_axes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    return OnnxWeightsDescr(
        source=output_path,
        parent="pytorch_state_dict",
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}.",
    )