Coverage for src / bioimageio / core / backends / _model_adapter.py: 71%

105 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 22:06 +0000

1import warnings 

2from abc import ABC, abstractmethod 

3from typing import ( 

4 Any, 

5 List, 

6 Optional, 

7 Sequence, 

8 Tuple, 

9 Union, 

10 final, 

11) 

12 

13from exceptiongroup import ExceptionGroup 

14from numpy.typing import NDArray 

15from typing_extensions import assert_never 

16 

17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

18 

19from ..common import SupportedWeightsFormat 

20from ..digest_spec import get_axes_infos, get_member_ids 

21from ..sample import Sample, SampleBlock 

22from ..tensor import Tensor 

23 

# Known weight formats in order of priority.
# First match wins: `ModelAdapter.create` tries these formats in this order
# and returns the first backend adapter that can be constructed.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_v3",
    "keras_hdf5",
)

34 

35 

class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        """Cache tensor ids, axis orders and input optionality from **model_description**."""
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # axis order per tensor; used in `forward` to transpose sample tensors
        # into the order the backend implementation expects
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 model specs cannot mark inputs as optional
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates model adapter based on the passed spec.

        Note: All specific adapters should happen inside this function to prevent
        different framework initializations interfering with each other.

        Args:
            model_description: loaded v0.4/v0.5 model description.
            devices: optional device hints forwarded to the backend adapter.
            weight_format_priority_order: weight formats to try, in order;
                defaults to ``DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER``.

        Raises:
            TypeError: if **model_description** is not a v0.4/v0.5 model description.
            ValueError: if none of the requested weight formats is present.
            Exception: the single creation error if only one format was requested.
            ExceptionGroup: all creation errors if no adapter could be created.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w
            for w in weight_format_priority_order
            if getattr(weights, w, None) is not None
        ]
        if not weight_format_priority_order_present:
            # fix: the previous message also interpolated the list of present
            # formats, which is provably empty on this branch and read as if
            # formats *were* present
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order})"
                + " is present in the model description."
            )

        # each backend import lives inside its branch so that a broken/missing
        # framework only disables that one format
        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_v3":
                assert not isinstance(weights, v0_4.WeightsDescr), (
                    "keras_v3 weights not supported for v0.4 specs"
                )
                assert weights.keras_v3 is not None
                try:
                    from .keras_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        assert errors
        if len(weight_format_priority_order) == 1:
            assert len(errors) == 1
            raise errors[0]

        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            raise ExceptionGroup(msg, errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op; the model is already loaded on initialization."""
        # fix: emit an actual DeprecationWarning (was the default UserWarning)
        # pointing at the caller via stacklevel=2
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization",
            DeprecationWarning,
            stacklevel=2,
        )

    def forward(self, input_sample: Union[Sample, SampleBlock]) -> Sample:
        """
        Run forward pass of model to get model predictions

        Note: sample id and sample stat attributes are passed through
        """
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        # transpose each (possibly missing) input tensor into the axis order
        # recorded in __init__ and unwrap it to a raw array for the backend
        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        output_arrays = self._forward_impl(input_arrays)
        assert len(output_arrays) <= len(self._output_ids)
        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                # drop outputs the backend did not produce
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]], ...]]:
        """framework specific forward implementation

        Returns one (optional) array per model output; ``forward`` pairs them
        with ``self._output_ids``.
        (fix: the Tuple annotation was ``Tuple[Optional[NDArray[Any]]]``,
        i.e. a 1-tuple; it is now variadic.)
        """

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays

        NOTE(review): body is empty in this module (implicitly returns None);
        presumably implemented/overridden elsewhere — confirm before relying on it.
        """

255 

# Module-level convenience alias for `ModelAdapter.create`.
create_model_adapter = ModelAdapter.create