Coverage for src/bioimageio/core/weight_converters/_utils_torch_onnx.py: 70%

88 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

"""Helper to export either TorchScript or PytorchStateDict weights to ONNX."""

2 

3from collections import defaultdict 

4from itertools import chain 

5from pathlib import Path 

6from typing import TYPE_CHECKING, DefaultDict, Dict, List, Literal, Tuple, Union 

7 

8import torch 

9from bioimageio.spec.model.v0_5 import ( 

10 BatchAxis, 

11 FileDescr, 

12 InputAxis, 

13 ModelDescr, 

14 OnnxWeightsDescr, 

15 ParameterizedSize, 

16 SizeReference, 

17) 

18from loguru import logger 

19from typing_extensions import assert_never 

20 

21from .. import __version__ 

22from ..digest_spec import get_member_id, get_test_input_sample 

23from ..proc_setup import get_pre_and_postprocessing 

24 

25if TYPE_CHECKING: 

26 from torch.export.dynamic_shapes import ( 

27 _DimHint as DimHint, # pyright: ignore[reportPrivateUsage] 

28 ) 

29 

30 

def get_torch_sample_inputs(model_descr: ModelDescr) -> Tuple[torch.Tensor, ...]:
    """Build preprocessed torch input tensors from the model's test input sample.

    The test sample is passed to the model's preprocessing, with the sample
    itself used as `dataset_for_initial_statistics`. Preprocessing is applied
    to the sample object before its member arrays are read.

    Returns:
        One tensor per entry of `model_descr.inputs`, in that order.
        Each tensor is created with `torch.from_numpy` from the member's
        underlying array (`.data.data`).
    """
    test_sample = get_test_input_sample(model_descr)
    pre_and_post = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[test_sample]
    )
    pre_and_post.pre(test_sample)

    tensors: List[torch.Tensor] = []
    for input_descr in model_descr.inputs:
        member = test_sample.members[get_member_id(input_descr)]
        tensors.append(torch.from_numpy(member.data.data))

    return tuple(tensors)

41 

42 

def _get_dynamic_axes_noop(model_descr: ModelDescr):
    """Placeholder `get_dynamic_axes` used when exporting with ``dynamo=True``.

    The dynamo exporter receives its dynamism via `get_dynamic_shapes`, so no
    ``dynamic_axes`` are produced here; `model_descr` is ignored.
    """
    _ = model_descr  # intentionally unused
    return None

47 

48 

def _get_dynamic_axes_impl(model_descr: ModelDescr):
    """Collect ``dynamic_axes`` for the legacy (``dynamo=False``) ONNX export.

    Every input and output axis whose size is not a fixed ``int`` is marked
    dynamic and named after its axis id.

    Returns:
        A ``defaultdict`` mapping ``str(tensor id)`` to ``{axis index: str(axis id)}``.
        Tensors without any non-int axis do not appear as keys.
    """
    axes_by_tensor: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
    for tensor_descr in chain(model_descr.inputs, model_descr.outputs):
        for axis_index, axis in enumerate(tensor_descr.axes):
            if isinstance(axis.size, int):
                continue  # fixed size -> not dynamic

            axes_by_tensor[str(tensor_descr.id)][axis_index] = str(axis.id)

    return axes_by_tensor

58 

59 

try:
    from torch.export import Dim

    # `Dim.STATIC` only exists in newer torch versions; None also marks a static dim
    STATIC_DIM = Dim.STATIC if hasattr(Dim, "STATIC") else None
    TensorDim = Union[Dim, "DimHint", None]

except Exception as e:
    use_dynamo = False
    logger.info(f"Not using torch dynamo for ONNX export due to:\n{e}")

    def _get_dynamic_shapes_noop(model_descr: ModelDescr):
        """noop for dynamo=False which uses `get_dynamic_axes` instead"""

        return None

    get_dynamic_shapes = _get_dynamic_shapes_noop
    get_dynamic_axes = _get_dynamic_axes_impl
else:
    use_dynamo = True
    logger.info("Using torch dynamo for ONNX export")

    def _derive_dim(
        dim_ref: Union[int, TensorDim],
        scale_ratio: float,
        offset: int,
        fallback_name: str,
    ) -> Union[int, TensorDim]:
        """Express ``scale_ratio * dim_ref + offset`` as a torch.export dim.

        torch.export can only derive dims from a root `Dim` with positive
        *integer* coefficients (e.g. ``dim * 2 + 1``); a float factor (even
        ``1.0``) raises ``NotImplementedError``. Hence:
        - a fixed (int) reference yields a static dim,
        - an integral `scale_ratio` yields a derived dim,
        - any other ratio falls back to `Dim.AUTO` (if available) or to an
          independent dynamic `Dim` named `fallback_name`.
        """
        if dim_ref is None or isinstance(dim_ref, int):
            # referenced axis has a fixed size, so this axis is fixed as well
            return STATIC_DIM

        if scale_ratio <= 0 or not float(scale_ratio).is_integer():
            logger.warning(
                f"Cannot express size of axis '{fallback_name}' as a derived"
                + f" torch.export dim (non-integer scale ratio {scale_ratio});"
                + " falling back to an independent dynamic dim."
            )
            auto_dim = getattr(Dim, "AUTO", None)
            if auto_dim is not None:
                return auto_dim

            return Dim(fallback_name, min=1)

        factor = int(scale_ratio)
        # reuse the root dim itself for a 1:1 reference to express equality
        derived = dim_ref if factor == 1 else dim_ref * factor
        if offset:
            derived = derived + offset

        return derived

    def _get_dynamic_shapes_impl(model_descr: ModelDescr):
        """Get dynamic shapes for torch dynamo export

        Returns one ``{axis index: dim}`` dict per entry of
        `model_descr.inputs`, in that order.
        """
        # dynamic shapes as list to match the source code which may have
        # different arg names than the tensor ids in the model description
        dynamic_shapes: List[Dict[int, Union[int, TensorDim]]] = []
        # (tensor id, axis id) -> (axis, axis index, dims dict of *that* tensor)
        potential_ref_axes: Dict[
            Tuple[str, str],
            Tuple[InputAxis, int, Dict[int, Union[int, TensorDim]]],
        ] = {}
        # add fixed sizes (as int) and dynamic dims from batch axes and
        # parameterized input sizes
        for d in model_descr.inputs:
            dynamic_tensor_dims: Dict[int, Union[int, TensorDim]] = {}
            for i, ax in enumerate(d.axes):
                dim_name = f"{d.id}_{ax.id}"
                if isinstance(ax.size, int):
                    dim = ax.size
                elif isinstance(ax, BatchAxis):
                    dim = Dim("batch", min=1)
                elif isinstance(ax.size, ParameterizedSize):
                    dim = Dim(dim_name, min=ax.size.min)
                elif isinstance(ax.size, SizeReference):
                    continue  # handled below
                else:
                    assert_never(ax.size)

                dynamic_tensor_dims[i] = dim
                potential_ref_axes[(str(d.id), str(ax.id))] = (
                    ax,
                    i,
                    dynamic_tensor_dims,
                )

            dynamic_shapes.append(dynamic_tensor_dims)

        # add dynamic dims from size references
        for d, dynamic_tensor_dims in zip(model_descr.inputs, dynamic_shapes):
            for i, ax in enumerate(d.axes):
                if not isinstance(ax.size, SizeReference):
                    continue  # handled above

                ref_key = (str(ax.size.tensor_id), str(ax.size.axis_id))
                ax_ref, i_ref, ref_tensor_dims = potential_ref_axes[ref_key]
                # look up the dim in the *referenced* tensor, not the current one
                dynamic_tensor_dims[i] = _derive_dim(
                    ref_tensor_dims[i_ref],
                    ax_ref.scale / ax.scale,
                    ax.size.offset,
                    f"{d.id}_{ax.id}",
                )

        return dynamic_shapes

    get_dynamic_shapes = _get_dynamic_shapes_impl
    get_dynamic_axes = _get_dynamic_axes_noop

131 

132 

def export_to_onnx(
    model_descr: ModelDescr,
    model: torch.nn.Module,
    output_path: Path,
    verbose: bool,
    opset_version: int,
    parent: Literal["torchscript", "pytorch_state_dict"],
) -> OnnxWeightsDescr:
    """Export `model` to ONNX and describe the result as ONNX weights.

    Args:
        model_descr: Model description providing the test inputs, tensor ids
            (used as ONNX input/output names) and axes (for dynamic dims).
        model: Torch module to export.
        output_path: Destination of the exported ONNX file.
        verbose: Forwarded to `torch.onnx.export`.
        opset_version: ONNX opset to export with; also recorded in the
            returned description.
        parent: Weights format this ONNX export was converted from.

    Returns:
        An `OnnxWeightsDescr` pointing at the absolute `output_path` and, when
        exporting with dynamo, at the external data file ``<output_path>.data``.

    Raises:
        FileNotFoundError: If weights are saved externally (dynamo export) but
            the expected ``<output_path>.data`` file does not exist.
    """
    inputs_torch = get_torch_sample_inputs(model_descr)

    # run the model once on the test inputs before exporting;
    # the outputs are not used further here
    with torch.no_grad():
        _ = model(*inputs_torch)

    # dynamo exports store the weights in a separate external data file
    save_weights_externally = use_dynamo
    _ = torch.onnx.export(
        model,
        inputs_torch,
        str(output_path),
        dynamo=use_dynamo,
        external_data=save_weights_externally,
        input_names=[str(d.id) for d in model_descr.inputs],
        output_names=[str(d.id) for d in model_descr.outputs],
        dynamic_axes=get_dynamic_axes(model_descr),
        dynamic_shapes=get_dynamic_shapes(model_descr),
        verbose=verbose,
        opset_version=opset_version,
    )

    if save_weights_externally:
        external_data_path = output_path.with_suffix(
            output_path.suffix + ".data"
        ).absolute()
        if not external_data_path.exists():
            raise FileNotFoundError(
                f"Expected external data file at {external_data_path} not found."
            )
        external_data_descr = FileDescr(source=external_data_path)
    else:
        external_data_descr = None

    return OnnxWeightsDescr(
        source=output_path.absolute(),
        external_data=external_data_descr,
        parent=parent,
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}, dynamo={use_dynamo}.",
    )