Coverage for src / bioimageio / core / weight_converters / _utils_torch_onnx.py: 72%
92 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
"""Helpers to export models given as TorchScript or PytorchStateDict weights to ONNX."""
3from collections import defaultdict
4from itertools import chain
5from pathlib import Path
6from typing import (
7 TYPE_CHECKING,
8 DefaultDict,
9 Dict,
10 List,
11 Literal,
12 Optional,
13 Sequence,
14 Tuple,
15 Union,
16)
18import torch
19from loguru import logger
20from torch.export import ExportedProgram
21from typing_extensions import assert_never
23from bioimageio.spec.model.v0_5 import (
24 BatchAxis,
25 FileDescr,
26 InputAxis,
27 ModelDescr,
28 OnnxWeightsDescr,
29 ParameterizedSize,
30 SizeReference,
31)
33from .. import __version__
34from ..backends.pytorch_backend import get_devices
35from ..digest_spec import get_member_id, get_test_input_sample
36from ..proc_setup import get_pre_and_postprocessing
38if TYPE_CHECKING:
39 from torch.export.dynamic_shapes import (
40 _DimHint as DimHint, # pyright: ignore[reportPrivateUsage]
41 )
def get_torch_sample_inputs(model_descr: ModelDescr) -> Tuple[torch.Tensor, ...]:
    """Return the model's preprocessed test inputs as torch tensors.

    The test input sample of `model_descr` is loaded and the model's
    preprocessing is applied to it, with initial statistics taken from this
    single sample. The data of each input member is then wrapped with
    `torch.from_numpy`, in the order of `model_descr.inputs`.
    """
    test_sample = get_test_input_sample(model_descr)
    pre_and_post = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[test_sample]
    )
    # the return value is not used; the preprocessed data is read back from
    # `test_sample` below
    pre_and_post.pre(test_sample)

    arrays = []
    for input_descr in model_descr.inputs:
        member = test_sample.members[get_member_id(input_descr)]
        arrays.append(member.data.data)

    return tuple(map(torch.from_numpy, arrays))
def _get_dynamic_axes_noop(model_descr: ModelDescr):
    """Placeholder `get_dynamic_axes` used when exporting with ``dynamo=True``.

    The dynamo exporter receives its dynamic dimensions via
    ``get_dynamic_shapes`` instead, so no ``dynamic_axes`` are produced.
    """
    del model_descr  # unused; kept for signature parity with the real impl
    return None
def _get_dynamic_axes_impl(model_descr: ModelDescr):
    """Collect ``dynamic_axes`` for the legacy ONNX export (``dynamo=False``).

    Every input and output axis whose size is not a fixed ``int`` is listed
    as ``{str(tensor id): {axis index: str(axis id)}}``. Tensors without any
    such axis get no entry.
    """
    dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
    all_tensor_descrs = [*model_descr.inputs, *model_descr.outputs]
    for tensor_descr in all_tensor_descrs:
        variable_axes = [
            (axis_index, axis)
            for axis_index, axis in enumerate(tensor_descr.axes)
            if not isinstance(axis.size, int)
        ]
        for axis_index, axis in variable_axes:
            dynamic_axes[str(tensor_descr.id)][axis_index] = str(axis.id)

    return dynamic_axes
try:
    from torch.export import Dim

    # newer torch versions provide an explicit marker for static dims;
    # fall back to `None` otherwise
    STATIC_DIM = Dim.STATIC if hasattr(Dim, "STATIC") else None
    TensorDim = Union[Dim, "DimHint", None]

except Exception as e:
    # deliberately broad: any failure while setting up dynamo's dynamic-shape
    # helpers falls back to the legacy (dynamo=False) exporter
    use_dynamo = False
    logger.info(f"Not using torch dynamo for ONNX export due to:\n{e}")

    def _get_dynamic_shapes_noop(model_descr: ModelDescr):
        """noop for dynamo=False which uses `get_dynamic_axes` instead"""

        return None

    get_dynamic_shapes = _get_dynamic_shapes_noop
    get_dynamic_axes = _get_dynamic_axes_impl
else:
    use_dynamo = True
    logger.info("Using torch dynamo for ONNX export")

    def _get_dynamic_shapes_impl(model_descr: ModelDescr):
        """Get dynamic shapes for torch dynamo export.

        Returns a list with one ``{axis index: dim}`` dict per model input, in
        the order of ``model_descr.inputs``. A list is used instead of a dict
        keyed by tensor id because the model's forward arguments may be named
        differently than the tensor ids in the model description.

        Per axis the dim is:
        - the ``int`` size for fixed-size axes,
        - ``Dim("batch", min=1)`` for batch axes without a fixed size,
        - ``Dim("<tensor id>_<axis id>", min=<size.min>)`` for parameterized
          axes,
        - for size references: ``ref_dim * (ref_scale / scale) + offset``
          derived from the referenced input axis, or ``STATIC_DIM`` if the
          referenced axis has a fixed size. If the scale ratio is not a
          positive integer, an independent dim is used and a warning is logged.

        Raises:
            KeyError: if a size reference points to an axis that is not a
                (non-size-reference) input axis.
        """
        dynamic_shapes: List[Dict[int, Union[int, TensorDim]]] = []
        # axes that may be referenced by a `SizeReference`, keyed by
        # "<tensor id>_<axis id>", together with their resolved dim
        potential_ref_axes: Dict[str, Tuple[InputAxis, Union[int, TensorDim]]] = {}

        # first pass: fixed, batch and parameterized sizes
        for d in model_descr.inputs:
            dynamic_tensor_dims: Dict[int, Union[int, TensorDim]] = {}
            for i, ax in enumerate(d.axes):
                dim_name = f"{d.id}_{ax.id}"
                if isinstance(ax.size, int):
                    dim = ax.size
                elif isinstance(ax, BatchAxis):
                    dim = Dim("batch", min=1)
                elif isinstance(ax.size, ParameterizedSize):
                    dim = Dim(dim_name, min=ax.size.min)
                elif isinstance(ax.size, SizeReference):
                    continue  # resolved in the second pass
                else:
                    assert_never(ax.size)

                dynamic_tensor_dims[i] = dim
                potential_ref_axes[dim_name] = (ax, dim)

            dynamic_shapes.append(dynamic_tensor_dims)

        # second pass: size references. The referenced dim is looked up by
        # tensor id AND axis id, since it may belong to another input tensor.
        for d, dynamic_tensor_dims in zip(model_descr.inputs, dynamic_shapes):
            for i, ax in enumerate(d.axes):
                if not isinstance(ax.size, SizeReference):
                    continue  # handled in the first pass

                dim_name_ref = f"{ax.size.tensor_id}_{ax.size.axis_id}"
                ax_ref, dim_ref = potential_ref_axes[dim_name_ref]
                if isinstance(dim_ref, int):
                    # the referenced size is fixed, hence so is this one
                    dim = STATIC_DIM
                else:
                    scale_ratio = ax_ref.scale / ax.scale
                    factor = round(scale_ratio)
                    if factor >= 1 and abs(scale_ratio - factor) < 1e-6:
                        # torch only derives dims via positive *integer*
                        # coefficients and integer offsets; skip no-op steps
                        dim = dim_ref
                        if factor != 1:
                            dim = dim * factor
                        if ax.size.offset:
                            dim = dim + ax.size.offset
                    else:
                        dim_name = f"{d.id}_{ax.id}"
                        logger.warning(
                            f"Cannot express size of axis '{dim_name}' as an integer"
                            + f" multiple of '{dim_name_ref}' (scale ratio"
                            + f" {scale_ratio}); exporting it as an independent"
                            + " dynamic dimension."
                        )
                        dim = Dim(dim_name, min=1)

                dynamic_tensor_dims[i] = dim

        return dynamic_shapes

    get_dynamic_shapes = _get_dynamic_shapes_impl
    get_dynamic_axes = _get_dynamic_axes_noop
def export_to_onnx(
    model_descr: ModelDescr,
    model: Union[ExportedProgram, torch.nn.Module],
    output_path: Path,
    verbose: bool,
    opset_version: int,
    parent: Literal["torchscript", "pytorch_state_dict"],
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> OnnxWeightsDescr:
    """Export `model` to ONNX and describe the result as ONNX weights.

    Args:
        model_descr: Model description. Its preprocessed test inputs are used
            as example inputs and its tensor ids as ONNX input/output names.
        model: The PyTorch module (or exported program) to convert.
        output_path: Where to write the ONNX file. When exporting with dynamo,
            the weights are expected in ``<output_path>.data`` next to it.
        verbose: Forwarded to ``torch.onnx.export``.
        opset_version: Target ONNX opset. It is forwarded to
            ``torch.onnx.export`` and recorded in the returned description.
        parent: Weights format the ONNX weights were converted from.
        devices: Candidate devices, resolved via ``get_devices``. The example
            inputs are moved to the first one.

    Returns:
        An ``OnnxWeightsDescr`` pointing to the written file(s).

    Raises:
        FileNotFoundError: If external weight data was expected (dynamo
            export) but not found next to `output_path`.
    """
    # TODO: be more thorough about the device? (only the inputs are moved;
    #       `model` is not moved here)
    inputs_torch = get_torch_sample_inputs(model_descr)
    devices = get_devices(devices)
    inputs_torch = tuple(t.to(devices[0]) for t in inputs_torch)
    save_weights_externally = use_dynamo

    # No eager forward pass is run before exporting. Its outputs would be
    # unused, and an `ExportedProgram` cannot be called directly (it requires
    # `.module()`), so such a pass would only cost time and reject
    # ExportedProgram inputs.
    with torch.no_grad():
        _ = torch.onnx.export(
            model,
            inputs_torch,
            str(output_path),
            dynamo=use_dynamo,
            external_data=save_weights_externally,
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            # depending on `use_dynamo`, one of these two is always None
            dynamic_axes=get_dynamic_axes(model_descr),
            dynamic_shapes=get_dynamic_shapes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    if save_weights_externally:
        external_data_path = output_path.with_suffix(
            output_path.suffix + ".data"
        ).absolute()
        if not external_data_path.exists():
            raise FileNotFoundError(
                f"Expected external data file at {external_data_path} not found."
            )
        external_data_descr = FileDescr(source=external_data_path)
    else:
        external_data_descr = None

    return OnnxWeightsDescr(
        source=output_path.absolute(),
        external_data=external_data_descr,
        parent=parent,
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}, dynamo={use_dynamo}.",
    )