Coverage for src / bioimageio / core / weight_converters / _utils_torch_onnx.py: 72%

92 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1"""helper to export both TorchScript or PytorchStateDict to ONNX""" 

2 

3from collections import defaultdict 

4from itertools import chain 

5from pathlib import Path 

6from typing import ( 

7 TYPE_CHECKING, 

8 DefaultDict, 

9 Dict, 

10 List, 

11 Literal, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Union, 

16) 

17 

18import torch 

19from loguru import logger 

20from torch.export import ExportedProgram 

21from typing_extensions import assert_never 

22 

23from bioimageio.spec.model.v0_5 import ( 

24 BatchAxis, 

25 FileDescr, 

26 InputAxis, 

27 ModelDescr, 

28 OnnxWeightsDescr, 

29 ParameterizedSize, 

30 SizeReference, 

31) 

32 

33from .. import __version__ 

34from ..backends.pytorch_backend import get_devices 

35from ..digest_spec import get_member_id, get_test_input_sample 

36from ..proc_setup import get_pre_and_postprocessing 

37 

38if TYPE_CHECKING: 

39 from torch.export.dynamic_shapes import ( 

40 _DimHint as DimHint, # pyright: ignore[reportPrivateUsage] 

41 ) 

42 

43 

def get_torch_sample_inputs(model_descr: ModelDescr) -> Tuple[torch.Tensor, ...]:
    """Return the model's preprocessed test input sample as torch tensors.

    The test input sample of `model_descr` is loaded and run through the
    model's preprocessing, with the sample itself passed as
    `dataset_for_initial_statistics`.

    Returns:
        One tensor (created with `torch.from_numpy`) per entry of
        `model_descr.inputs`, in that order.
    """
    test_sample = get_test_input_sample(model_descr)
    processing = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[test_sample]
    )
    # preprocessing is applied to `test_sample`; its members are read afterwards
    processing.pre(test_sample)

    tensors: List[torch.Tensor] = []
    for input_descr in model_descr.inputs:
        member = test_sample.members[get_member_id(input_descr)]
        tensors.append(torch.from_numpy(member.data.data))

    return tuple(tensors)

54 

55 

def _get_dynamic_axes_noop(model_descr: ModelDescr):
    """Return no `dynamic_axes`; used when exporting with `dynamo=True`.

    The dynamo-based export is configured through `dynamic_shapes` (see
    `get_dynamic_shapes`) instead, so `model_descr` is ignored here.
    """
    del model_descr  # accepted only to share the signature of `_get_dynamic_axes_impl`
    return None

60 

61 

def _get_dynamic_axes_impl(model_descr: ModelDescr):
    """Build `dynamic_axes` for the legacy ONNX export (`dynamo=False`).

    Every input and output axis whose size is not a plain ``int`` is marked
    as dynamic and named after its axis id.

    Returns:
        A ``defaultdict`` mapping tensor id (as ``str``) to
        ``{axis index: axis id (as str)}``. Tensors without any dynamic axis
        do not appear as keys.
    """
    dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
    for tensor_descr in chain(model_descr.inputs, model_descr.outputs):
        variable_axes = {
            axis_index: str(axis.id)
            for axis_index, axis in enumerate(tensor_descr.axes)
            if not isinstance(axis.size, int)
        }
        if variable_axes:
            # `update` merges with an existing entry, should an id occur twice
            dynamic_axes[str(tensor_descr.id)].update(variable_axes)

    return dynamic_axes

71 

72 

# Select the ONNX export flavor: the dynamo exporter (configured through
# `dynamic_shapes`) if `torch.export.Dim` is usable, else the legacy exporter
# (configured through `dynamic_axes`). Exactly one of `get_dynamic_shapes` /
# `get_dynamic_axes` returns a real value; the other is a noop returning None.
try:
    from torch.export import Dim

    # hint for a static dimension; older torch versions use None for that
    STATIC_DIM = Dim.STATIC if hasattr(Dim, "STATIC") else None
    TensorDim = Union[Dim, "DimHint", None]

except Exception as e:
    use_dynamo = False
    logger.info(f"Not using torch dynamo for ONNX export due to:\n{e}")

    def _get_dynamic_shapes_noop(model_descr: ModelDescr):
        """noop for dynamo=False which uses `get_dynamic_axes` instead"""

        return None

    get_dynamic_shapes = _get_dynamic_shapes_noop
    get_dynamic_axes = _get_dynamic_axes_impl
else:
    use_dynamo = True
    logger.info("Using torch dynamo for ONNX export")

    # hint letting torch infer a dimension itself (if this torch version has it)
    AUTO_DIM = getattr(Dim, "AUTO", None)

    def _get_size_reference_dim(
        dim_ref: Union[int, TensorDim],
        scale_factor: float,
        offset: int,
        dim_name: str,
    ) -> Union[int, TensorDim]:
        """Get the dim of an axis of size `ref_size * scale_factor + offset`.

        Args:
            dim_ref: dim of the referenced axis, either a fixed `int` size or
                a torch `Dim`.
            scale_factor: referenced axis scale / referencing axis scale.
            offset: size offset of the size reference.
            dim_name: name of the referencing axis (only used for logging).
        """
        if isinstance(dim_ref, int):
            # fixed referenced size (torch `Dim`s are not `int` instances)
            return STATIC_DIM

        # torch derived dims only support increasing linear functions with
        # integer coefficients; tolerate float rounding of the ratio (e.g. 0.3 / 0.1)
        int_factor = round(scale_factor)
        if int_factor >= 1 and abs(scale_factor - int_factor) < 1e-6:
            dim = dim_ref if int_factor == 1 else int_factor * dim_ref
            return dim + offset if offset else dim

        logger.warning(
            f"Cannot express size of axis '{dim_name}' as {scale_factor} *"
            + f" <referenced size> + {offset} with integer coefficients;"
            + f" using {AUTO_DIM} for it."
        )
        return AUTO_DIM

    def _get_dynamic_shapes_impl(model_descr: ModelDescr):
        """Get dynamic shapes for torch dynamo export

        Returns:
            One dict per model input (same order as `model_descr.inputs`)
            mapping axis index to
            - the size (`int`) for axes with a fixed size,
            - a `Dim` named "batch" (shared by all batch axes) for batch axes
              without a fixed size,
            - a `Dim` named "<tensor id>_<axis id>" for parameterized sizes,
            - a derived dim (or a static/auto hint) for size references.

        Raises:
            KeyError: if a size reference points to an axis not resolved in
                the first pass (e.g. an output axis or another size reference).
        """
        # dynamic shapes as list to match the source code which may have
        # different arg names than the tensor ids in the model description
        dynamic_shapes: List[Dict[int, Union[int, TensorDim]]] = []
        # (tensor id, axis id) -> (axis, dim of that axis) for possible size references
        potential_ref_axes: Dict[
            Tuple[str, str], Tuple[InputAxis, Union[int, TensorDim]]
        ] = {}
        batch_dim = Dim("batch", min=1)

        # first pass: fixed sizes, batch axes and parameterized sizes
        for d in model_descr.inputs:
            dynamic_tensor_dims: Dict[int, Union[int, TensorDim]] = {}
            for i, ax in enumerate(d.axes):
                dim: Union[int, TensorDim]
                if isinstance(ax.size, int):
                    dim = ax.size
                elif isinstance(ax, BatchAxis):
                    dim = batch_dim
                elif isinstance(ax.size, ParameterizedSize):
                    dim = Dim(f"{d.id}_{ax.id}", min=ax.size.min)
                elif isinstance(ax.size, SizeReference):
                    continue  # resolved in the second pass
                else:
                    assert_never(ax.size)

                dynamic_tensor_dims[i] = dim
                potential_ref_axes[(str(d.id), str(ax.id))] = (ax, dim)

            dynamic_shapes.append(dynamic_tensor_dims)

        # second pass: size references, resolved against the referenced axis'
        # own dim (which may belong to another input tensor)
        for d, dynamic_tensor_dims in zip(model_descr.inputs, dynamic_shapes):
            for i, ax in enumerate(d.axes):
                if not isinstance(ax.size, SizeReference):
                    continue  # resolved in the first pass

                ax_ref, dim_ref = potential_ref_axes[
                    (str(ax.size.tensor_id), str(ax.size.axis_id))
                ]
                dynamic_tensor_dims[i] = _get_size_reference_dim(
                    dim_ref,
                    scale_factor=ax_ref.scale / ax.scale,
                    offset=ax.size.offset,
                    dim_name=f"{d.id}_{ax.id}",
                )

        return dynamic_shapes

    get_dynamic_shapes = _get_dynamic_shapes_impl
    get_dynamic_axes = _get_dynamic_axes_noop

144 

145 

def export_to_onnx(
    model_descr: ModelDescr,
    model: Union[ExportedProgram, torch.nn.Module],
    output_path: Path,
    verbose: bool,
    opset_version: int,
    parent: Literal["torchscript", "pytorch_state_dict"],
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> OnnxWeightsDescr:
    """Export `model` to ONNX at `output_path` and describe the result.

    Args:
        model_descr: Model description providing the test inputs, the
            input/output names and the dynamic axes/shapes.
        model: Model to export. An `ExportedProgram` is run via `.module()`
            for the pre-export sanity check.
        output_path: Destination of the ONNX file.
        verbose: Forwarded to `torch.onnx.export`.
        opset_version: ONNX opset to export to. It is also recorded in the
            returned description.
        parent: Recorded as `parent` of the returned weights description
            (presumably the weights format the conversion started from).
        devices: Passed to `get_devices`; the sample inputs are moved to the
            first resulting device.

    Returns:
        The ONNX weights description. When exporting with dynamo, it
        references the external data file next to `output_path`.

    Raises:
        FileNotFoundError: If the expected external data file was not written.
    """
    # TODO: be more thorough about the device?
    inputs_torch = get_torch_sample_inputs(model_descr)
    devices = get_devices(devices)
    inputs_torch = tuple(t.to(devices[0]) for t in inputs_torch)
    # with dynamo, weights are saved to "<output file name>.data"
    save_weights_externally = use_dynamo
    with torch.no_grad():
        # sanity check that the model runs on the sample inputs before exporting;
        # an `ExportedProgram` cannot be called directly, so run its module
        runnable = model.module() if isinstance(model, ExportedProgram) else model
        _ = runnable(*inputs_torch)

        _ = torch.onnx.export(
            model,
            inputs_torch,
            str(output_path),
            dynamo=use_dynamo,
            external_data=save_weights_externally,
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            # only one of these is not None, depending on `use_dynamo`
            dynamic_axes=get_dynamic_axes(model_descr),
            dynamic_shapes=get_dynamic_shapes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    if save_weights_externally:
        external_data_path = output_path.with_name(
            output_path.name + ".data"
        ).absolute()
        if not external_data_path.exists():
            raise FileNotFoundError(
                f"Expected external data file at {external_data_path} not found."
            )
        external_data_descr = FileDescr(source=external_data_path)
    else:
        external_data_descr = None

    return OnnxWeightsDescr(
        source=output_path.absolute(),
        external_data=external_data_descr,
        parent=parent,
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}, dynamo={use_dynamo}.",
    )