Coverage for bioimageio/core/backends/_model_adapter.py: 76%

99 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import sys 

2import warnings 

3from abc import ABC, abstractmethod 

4from typing import ( 

5 Any, 

6 List, 

7 Optional, 

8 Sequence, 

9 Tuple, 

10 Union, 

11 final, 

12) 

13 

14from numpy.typing import NDArray 

15from typing_extensions import assert_never 

16 

17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

18 

19from ..common import SupportedWeightsFormat 

20from ..digest_spec import get_axes_infos, get_member_ids 

21from ..sample import Sample, SampleBlock, SampleBlockWithOrigin 

22from ..tensor import Tensor 

23 

# Known weight formats in order of priority.
# First match wins: ``ModelAdapter.create`` falls back to this order whenever
# the caller does not pass an explicit ``weight_format_priority_order``.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_hdf5",
)

33 

34 

class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        """Cache per-tensor metadata (ids, axis orders, optionality) from the spec.

        Args:
            model_description: a v0.4 or v0.5 bioimage.io model description.
        """
        super().__init__()
        self._model_descr = model_description
        # member ids of the model's input/output tensors, in spec order
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # axis id order per tensor; used to transpose tensors before/after
        # handing raw arrays to the backend in ``forward``
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 descriptions have no per-input ``optional`` flag
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates model adapter based on the passed spec
        Note: All specific adapters should happen inside this function to prevent different framework
        initializations interfering with each other

        Args:
            model_description: model description to create an adapter for.
            devices: device identifiers forwarded to the backend adapter
                (semantics are backend specific).
            weight_format_priority_order: weight formats to try, in order;
                defaults to ``DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER``.

        Raises:
            TypeError: if *model_description* is not a v0.4/v0.5 model description.
            ValueError: if none of the requested weight formats is present in
                the description, or (on Python < 3.11) if no adapter could be
                created for any present format.
            ExceptionGroup: on Python >= 3.11, collecting all per-format
                creation failures when more than one format was requested.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w for w in weight_format_priority_order if getattr(weights, w) is not None
        ]
        if not weight_format_priority_order_present:
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
            )

        # Try each present format in priority order; the first adapter that
        # constructs successfully wins. Backend imports happen lazily inside
        # each branch so that a missing framework only fails that one format.
        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                # exhaustiveness check: every SupportedWeightsFormat is handled above
                assert_never(wf)

        # reaching this point means every attempted format failed
        assert errors
        if len(weight_format_priority_order) == 1:
            assert len(errors) == 1
            raise errors[0]

        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            if sys.version_info[:2] >= (3, 11):
                # ExceptionGroup is only available from Python 3.11 on
                raise ExceptionGroup(msg, errors)
            else:
                raise ValueError(msg) from Exception(errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        # kept for backwards compatibility only; emits a deprecation warning
        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")

    def forward(
        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> Sample:
        """
        Run forward pass of model to get model predictions

        Note: sample id and sample stat attributes are passed through
        """
        # warn (but do not fail) on inputs the model description does not declare
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        # per declared input: transpose to the spec's axis order and unwrap to a
        # raw array; missing (optional) inputs are passed through as None
        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        output_arrays = self._forward_impl(input_arrays)
        # backends may return fewer outputs than declared, but never more
        assert len(output_arrays) <= len(self._output_ids)
        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None  # drop outputs the backend did not produce
            },
            stat=input_sample.stat,
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
        """framework specific forward implementation"""

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays"""
        # NOTE(review): body consists of the docstring only, so this currently
        # returns None — looks like an unfinished stub; confirm intended use.

243 

244 

# Module-level convenience alias for ``ModelAdapter.create``.
create_model_adapter = ModelAdapter.create