Coverage for src/bioimageio/core/weight_converters/_utils_torch_onnx.py: 70%
88 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
"""Helpers to export TorchScript or PytorchStateDict models to ONNX."""
3from collections import defaultdict
4from itertools import chain
5from pathlib import Path
6from typing import TYPE_CHECKING, DefaultDict, Dict, List, Literal, Tuple, Union
8import torch
9from bioimageio.spec.model.v0_5 import (
10 BatchAxis,
11 FileDescr,
12 InputAxis,
13 ModelDescr,
14 OnnxWeightsDescr,
15 ParameterizedSize,
16 SizeReference,
17)
18from loguru import logger
19from typing_extensions import assert_never
21from .. import __version__
22from ..digest_spec import get_member_id, get_test_input_sample
23from ..proc_setup import get_pre_and_postprocessing
25if TYPE_CHECKING:
26 from torch.export.dynamic_shapes import (
27 _DimHint as DimHint, # pyright: ignore[reportPrivateUsage]
28 )
def get_torch_sample_inputs(model_descr: ModelDescr) -> Tuple[torch.Tensor, ...]:
    """Return the preprocessed test inputs of `model_descr` as torch tensors.

    The test input sample is passed through the model's preprocessing (set up
    with `dataset_for_initial_statistics=[sample]`), which modifies the sample in
    place. The data of each input member is then converted with
    `torch.from_numpy`, following the order of `model_descr.inputs`.
    """
    test_sample = get_test_input_sample(model_descr)
    pre_and_post = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[test_sample]
    )
    pre_and_post.pre(test_sample)

    def _as_torch(input_descr) -> torch.Tensor:
        # look up the (preprocessed) sample member belonging to this input
        member = test_sample.members[get_member_id(input_descr)]
        return torch.from_numpy(member.data.data)

    return tuple(map(_as_torch, model_descr.inputs))
def _get_dynamic_axes_noop(model_descr: ModelDescr):
    """Provide no `dynamic_axes` for the dynamo-based ONNX export.

    With dynamo=True the dynamic dimensions are passed via `get_dynamic_shapes`
    instead; `model_descr` is only accepted to keep the signature uniform.
    """
    result = None
    return result
def _get_dynamic_axes_impl(model_descr: ModelDescr):
    """Collect `dynamic_axes` for the (old) ONNX export with dynamo=False.

    Returns:
        A `defaultdict` mapping each tensor id (as `str`) of the model's inputs
        and outputs to `{axis index: axis id (as str)}`, covering every axis
        whose size is not a plain `int`. Tensors without such axes get no entry.
    """
    all_tensor_descrs = chain(model_descr.inputs, model_descr.outputs)
    dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
    for tensor_descr in all_tensor_descrs:
        tensor_key = str(tensor_descr.id)
        for axis_index, axis in enumerate(tensor_descr.axes):
            if isinstance(axis.size, int):
                continue  # fixed size: not declared as dynamic

            dynamic_axes[tensor_key][axis_index] = str(axis.id)

    return dynamic_axes
try:
    from torch.export import Dim

    STATIC_DIM = Dim.STATIC if hasattr(Dim, "STATIC") else None
    TensorDim = Union[Dim, "DimHint", None]

except Exception as e:
    # `torch.export.Dim` (or its typing) is not usable:
    # fall back to the (old) ONNX export with dynamo=False and `dynamic_axes`
    use_dynamo = False
    logger.info(f"Not using torch dynamo for ONNX export due to:\n{e}")

    def _get_dynamic_shapes_noop(model_descr: ModelDescr):
        """noop for dynamo=False which uses `get_dynamic_axes` instead"""

        return None

    get_dynamic_shapes = _get_dynamic_shapes_noop
    get_dynamic_axes = _get_dynamic_axes_impl
else:
    use_dynamo = True
    logger.info("Using torch dynamo for ONNX export")

    def _derive_dim(
        dim_ref: Dim, scale_ratio: float, offset: int, fallback_name: str
    ) -> TensorDim:
        """Express `scale_ratio * dim_ref + offset` as a torch.export dim.

        torch.export derives dims only with integer coefficients
        (e.g. `2 * dim + 1`), so the float `scale_ratio` is converted to a
        positive int if it is (close to) integral. Otherwise the relation
        cannot be expressed and a dim unrelated to `dim_ref` is returned
        (`Dim.AUTO` if available, else a new `Dim` named `fallback_name`).
        """
        factor = round(scale_ratio)
        if factor < 1 or abs(scale_ratio - factor) > 1e-6:
            logger.warning(
                f"Cannot express size reference with scale ratio {scale_ratio}"
                + " as an integer multiple of the referenced dimension;"
                + f" treating dimension '{fallback_name}' as independent."
            )
            return Dim.AUTO if hasattr(Dim, "AUTO") else Dim(fallback_name)

        derived = dim_ref if factor == 1 else dim_ref * factor
        offset = int(offset)  # derived dim arithmetic requires exact `int`
        return derived if offset == 0 else derived + offset

    def _get_dynamic_shapes_impl(model_descr: ModelDescr):
        """Get dynamic shapes for torch dynamo export.

        Returns:
            One dict per model input (in the order of `model_descr.inputs`)
            mapping axis indices to an `int` (fixed size), a `Dim` (batch or
            parameterized size), or a dim derived from a size reference
            (`STATIC_DIM` if the referenced axis has a fixed size).
        """
        # dynamic shapes as list to match the source code which may have
        # different arg names than the tensor ids in the model description
        dynamic_shapes: List[Dict[int, Union[int, TensorDim]]] = []
        # axis and resolved dim of every non-size-reference input axis, keyed by
        # "{tensor_id}_{axis_id}", so that size references to *any* input tensor
        # (not only the referencing one) can be resolved below
        potential_ref_axes: Dict[str, Tuple[InputAxis, Union[int, TensorDim]]] = {}
        # add dynamic dims from parameterized input sizes (and fixed sizes as int)
        for d in model_descr.inputs:
            dynamic_tensor_dims: Dict[int, Union[int, TensorDim]] = {}
            for i, ax in enumerate(d.axes):
                dim_name = f"{d.id}_{ax.id}"
                if isinstance(ax.size, int):
                    dim = ax.size
                elif isinstance(ax, BatchAxis):
                    dim = Dim("batch", min=1)
                elif isinstance(ax.size, ParameterizedSize):
                    dim = Dim(dim_name, min=ax.size.min)
                elif isinstance(ax.size, SizeReference):
                    continue  # handled below
                else:
                    assert_never(ax.size)

                dynamic_tensor_dims[i] = dim
                potential_ref_axes[dim_name] = (ax, dim)

            dynamic_shapes.append(dynamic_tensor_dims)

        # add dynamic dims from size references
        for d, dynamic_tensor_dims in zip(model_descr.inputs, dynamic_shapes):
            for i, ax in enumerate(d.axes):
                if not isinstance(ax.size, SizeReference):
                    continue  # handled above

                dim_name_ref = f"{ax.size.tensor_id}_{ax.size.axis_id}"
                # the referenced axis may belong to another input tensor, so its
                # dim is looked up by name, not by index into *this* tensor's dims
                ax_ref, dim_ref = potential_ref_axes[dim_name_ref]
                if isinstance(dim_ref, int):
                    # the referenced axis has a fixed size, hence so does this one
                    dim = STATIC_DIM
                else:
                    # size = ref_size * ref_scale / scale + offset
                    dim = _derive_dim(
                        dim_ref,
                        ax_ref.scale / ax.scale,
                        ax.size.offset,
                        fallback_name=f"{d.id}_{ax.id}",
                    )

                dynamic_tensor_dims[i] = dim

        return dynamic_shapes

    get_dynamic_shapes = _get_dynamic_shapes_impl
    get_dynamic_axes = _get_dynamic_axes_noop
def export_to_onnx(
    model_descr: ModelDescr,
    model: torch.nn.Module,
    output_path: Path,
    verbose: bool,
    opset_version: int,
    parent: Literal["torchscript", "pytorch_state_dict"],
) -> OnnxWeightsDescr:
    """Export `model` to ONNX at `output_path` and describe the result.

    The export uses the preprocessed test inputs of `model_descr` as example
    inputs. With torch dynamo (`use_dynamo`), dynamic dims are given via
    `dynamic_shapes` and the weights are saved to an external data file next to
    `output_path`. Otherwise `dynamic_axes` are used and no external data file
    is expected.

    Args:
        model_descr: model description providing test inputs and tensor ids.
        model: torch module to export.
        output_path: path of the ONNX file to write.
        verbose: forwarded to `torch.onnx.export`.
        opset_version: ONNX opset version to export with.
        parent: weights format the ONNX weights are converted from.

    Returns:
        ONNX weights description referencing the written file(s).

    Raises:
        FileNotFoundError: if weights were to be saved externally but the
            expected external data file does not exist after the export.
    """
    inputs_torch = get_torch_sample_inputs(model_descr)

    save_weights_externally = use_dynamo
    with torch.no_grad():
        # presumably a sanity check that the model runs on the sample inputs
        # before exporting; the outputs themselves are not used
        _ = model(*inputs_torch)

        _ = torch.onnx.export(
            model,
            inputs_torch,
            str(output_path),
            dynamo=use_dynamo,
            external_data=save_weights_externally,
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            dynamic_axes=get_dynamic_axes(model_descr),
            dynamic_shapes=get_dynamic_shapes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    if save_weights_externally:
        # external data is expected at "<file name>.data", e.g. model.onnx.data
        external_data_path = output_path.with_name(
            output_path.name + ".data"
        ).absolute()
        if not external_data_path.exists():
            raise FileNotFoundError(
                f"Expected external data file at {external_data_path} not found."
            )
        external_data_descr = FileDescr(source=external_data_path)
    else:
        external_data_descr = None

    return OnnxWeightsDescr(
        source=output_path.absolute(),
        external_data=external_data_descr,
        parent=parent,
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}, dynamo={use_dynamo}.",
    )