Coverage for src / bioimageio / core / backends / _model_adapter.py: 73%

110 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import warnings 

2from abc import ABC, abstractmethod 

3from typing import ( 

4 Any, 

5 List, 

6 Optional, 

7 Sequence, 

8 Tuple, 

9 Union, 

10 final, 

11) 

12 

13from exceptiongroup import ExceptionGroup 

14from loguru import logger 

15from numpy.typing import NDArray 

16from typing_extensions import assert_never 

17 

18from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

19 

20from ..common import SupportedWeightsFormat 

21from ..digest_spec import get_axes_infos, get_member_ids 

22from ..sample import Sample, SampleBlock 

23from ..tensor import Tensor 

24 

# Weight formats handled by `ModelAdapter.create`, from most to least preferred.
# When no explicit order is given, `create` walks this tuple, skips formats the
# model description does not provide, and returns the first adapter that can be
# constructed successfully in the current environment.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_v3",
    "keras_hdf5",
)

35 

36 

class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        super().__init__()
        self._model_descr = model_description
        # Tensor ids and per-tensor axis ids in model order; `forward` uses them
        # to transpose inputs into the model's layout and to label outputs.
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        # v0.4 inputs are treated as always required; newer descriptions carry
        # an explicit `optional` flag per input.
        if isinstance(model_description, v0_4.ModelDescr):
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    def __enter__(self) -> "ModelAdapter":
        """Support `with ModelAdapter.create(...) as adapter:`; returns the adapter itself."""
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Unload the model when leaving the `with` block.

        Returns None, so any exception raised inside the block propagates.
        """
        self.unload()

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates model adapter based on the passed spec.

        Note: All framework-specific adapter imports happen inside this function
        to prevent different framework initializations from interfering with
        each other.

        Args:
            model_description: v0.4 or v0.5 model description to create an adapter for.
            devices: forwarded unchanged to the weight-format specific adapter.
            weight_format_priority_order: weight formats to try, in order.
                Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`. Formats that
                the description does not contain are skipped.

        Returns:
            The adapter of the first weight format whose adapter could be created.

        Raises:
            TypeError: if `model_description` is not a v0.4/v0.5 `ModelDescr`.
            ValueError: if none of the requested weight formats is present.
            Exception: the adapter's own error, if exactly one format was requested.
            ExceptionGroup: all adapter creation errors, otherwise.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w
            for w in weight_format_priority_order
            if getattr(weights, w, None) is not None
        ]
        if not weight_format_priority_order_present:
            # Report what the description actually offers. The filtered list
            # above is necessarily empty here, so it would tell the user nothing.
            available = [
                w
                for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
                if getattr(weights, w, None) is not None
            ]
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present (supported formats present: {available})"
            )

        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_v3":
                assert not isinstance(weights, v0_4.WeightsDescr), (
                    "keras_v3 weights not supported for v0.4 specs"
                )
                assert weights.keras_v3 is not None
                try:
                    from .keras_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        # Every present format was attempted and failed (each failure recorded).
        assert errors
        if len(weight_format_priority_order) == 1:
            assert len(errors) == 1
            raise errors[0]

        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            raise ExceptionGroup(msg, errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op: the adapter already loads the model on initialization."""
        # stacklevel=2 attributes the warning to the caller of `load`.
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization", stacklevel=2
        )

    def forward(self, input_sample: Union[Sample, SampleBlock]) -> Sample:
        """
        Run forward pass of model to get model predictions.

        Inputs are looked up by the model's input ids and transposed into the
        axis order declared in the model description. Missing inputs are passed
        to `_forward_impl` as None. Input ids unknown to the model trigger a
        warning. Outputs beyond the number declared in the description are
        dropped with a warning, and None outputs are omitted from the result.

        Note: sample id and sample stat attributes are passed through.
        """
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        logger.debug(
            "NN input shapes: {}",
            [a.shape if a is not None else None for a in input_arrays],
        )
        output_arrays = self._forward_impl(input_arrays)
        logger.debug(
            "NN output shapes: {}",
            [a.shape if a is not None else None for a in output_arrays],
        )
        if len(output_arrays) > len(self._output_ids):
            warnings.warn(
                f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored."
            )
            output_arrays = output_arrays[: len(self._output_ids)]

        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
        """framework specific forward implementation"""

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays"""
        # NOTE(review): body is empty and returns None; appears to be an
        # unimplemented/unused stub. Confirm before relying on it.

268 

269 

# Module-level functional alias for `ModelAdapter.create`: the same bound
# classmethod, so it accepts the same arguments and returns the same adapter.
create_model_adapter = ModelAdapter.create