Coverage for src/bioimageio/core/_model_adapter.py: 87%

118 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import gc 

2import warnings 

3from abc import ABC, abstractmethod 

4from queue import LifoQueue 

5from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Union 

6 

7from exceptiongroup import ExceptionGroup 

8from loguru import logger 

9from numpy.typing import NDArray 

10from typing_extensions import TypeVar 

11 

12from bioimageio.spec import ValidationSummary 

13from bioimageio.spec.model import AnyModelDescr, v0_4 

14 

15from ._sample_serializer import SampleSerializer, SerializedSampleBlockType 

16from .common import PerMember 

17from .digest_spec import get_axes_infos, get_member_ids 

18from .sample import Sample 

19from .tensor import Tensor 

20 

21 

# Self-type for `ModelAdapter.__enter__`, so that `with SubAdapter(...) as a`
# keeps the concrete subclass type for static type checkers.
_ModelAdapterT = TypeVar("_ModelAdapterT", bound="ModelAdapter")


class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = create_model_adapter(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with create_model_adapter(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(
        self, model_description: AnyModelDescr, devices: Optional[Sequence[str]]
    ):
        """Store the model description, derive member ids/axes and load the model.

        Args:
            model_description: Description of the model to adapt.
            devices: Requested devices, stored as given for use by `load()`.

        Note:
            `load()` is called at the end of initialization.
        """
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0_4 inputs are all treated as required here
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

        self._devices = devices
        # Make sure the flag exists even if a subclass' `load()` fails to
        # call `super().load()`; `load()` below sets it to True on success.
        self._loaded = False
        self.load()

    def __enter__(self: _ModelAdapterT) -> _ModelAdapterT:
        """Enter a `with` block; returns the adapter itself."""
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Close the adapter when leaving a `with` block.

        Exceptions raised inside the block are not suppressed.
        """
        self.close()

    @property
    def model_descr(self) -> AnyModelDescr:
        """The model description this adapter was created from."""
        return self._model_descr

    @abstractmethod
    def load(self) -> None:
        """Load the model.

        Subclasses implement the actual loading and are expected to call
        `super().load()` to mark the adapter as loaded.
        """
        self._loaded = True

    @abstractmethod
    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]: ...

    @abstractmethod
    def unload(self):
        """Unload model from any devices, freeing their memory.

        Note:
            The model adapter should be considered unusable afterwards.
        """
        self._loaded = False

    def close(self):
        """Close the model adapter, freeing any resources.

        Note:
            The model adapter should be considered unusable afterwards.
        """
        self.unload()

92 

93 

# Type variables parametrizing `LocalModelAdapter`:
# the device handle type and the loaded model type used by a subclass.
DeviceType = TypeVar("DeviceType")
ModelType = TypeVar("ModelType")


class LocalModelAdapter(ModelAdapter, ABC, Generic[DeviceType, ModelType]):
    """Model adapter that holds one model instance per initialized device.

    Initialized `(device, model)` pairs are kept in a `LifoQueue`; `forward()`
    takes one pair from the queue for the duration of a forward pass and puts
    it back afterwards.
    """

    def load(self) -> None:
        """Initialize the model on each requested device.

        Devices that fail to initialize are skipped with a warning as long as
        at least one device succeeds.

        Raises:
            ValueError: If `_parse_devices` returns no devices.
            Exception: The original error if initialization failed on every
                device and only one distinct device error was recorded.
            ExceptionGroup: If initialization failed on every device and
                several device errors were recorded.
        """
        devices = self._devices
        self._model_queue: LifoQueue[Tuple[DeviceType, ModelType]] = LifoQueue()
        parsed_devices = self._parse_devices(devices)
        if not parsed_devices:
            # explicit check instead of `assert`, which is stripped under `-O`
            raise ValueError(
                f"`_parse_devices` returned no devices for requested devices {devices!r}"
            )

        # prioritize devices by order specified by user:
        # iterate in reverse so the first device is put last and therefore
        # retrieved first from the LIFO queue
        device_exceptions: Dict[str, Exception] = {}
        self._initialized_devices: List[str] = []
        for d in parsed_devices[::-1]:
            try:
                model = self._init_model_on_device(d)
            except Exception as e:
                device_exceptions[str(d)] = e
            else:
                self._model_queue.put((d, model))
                self._initialized_devices.insert(0, str(d))

        if self._model_queue.empty():
            if len(device_exceptions) == 1:
                raise next(iter(device_exceptions.values()))
            else:
                raise ExceptionGroup(
                    "Failed to initialize model on any of the requested devices.",
                    list(device_exceptions.values())[::-1],
                )

        if device_exceptions:
            logger.warning(
                "Failed to initialize model on some of the requested devices. Successfully initialized on {}, but got the following errors for other devices: {}",
                self._initialized_devices,
                device_exceptions,
            )

        super().load()

    @abstractmethod
    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Sequence[DeviceType]:
        """Parse devices

        Note:
            - Must not return an empty sequence.
            - The order of devices in the returned sequence determines the priority of device usage in the forward pass.
              The first device has highest priority, the last device has lowest priority.
        """

    @abstractmethod
    def _init_model_on_device(self, device: DeviceType) -> ModelType: ...

    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]:
        """
        Run forward pass of model to get model predictions

        Args:
            inputs: Input tensors by member id. Inputs that are missing or
                `None` are passed on as `None`.

        Returns:
            Output tensors by member id; outputs that are `None` are omitted.

        Raises:
            RuntimeError: If the model is not loaded.

        Note:
            Warns about unexpected input ids and about outputs exceeding the
            number of outputs in the model description (extra outputs are
            ignored).
        """
        if not self._loaded:
            raise RuntimeError("Model must be `.load()`ed before calling forward()")

        unexpected = [mid for mid in inputs if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        # bring each given input into the axis order of the model description
        input_arrays = [
            (
                None
                if (a := inputs.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        logger.debug(
            "NN input shapes: {}",
            [a.shape if a is not None else None for a in input_arrays],
        )
        # blocks until a (device, model) pair is available
        device, model = self._model_queue.get()
        try:
            output_arrays = self._forward_impl(device, model, input_arrays)
        finally:
            # always hand the model back, also if the forward pass failed
            self._model_queue.put((device, model))

        logger.debug(
            "NN output shapes: {}",
            [a.shape if a is not None else None for a in output_arrays],
        )
        if len(output_arrays) > len(self._output_ids):
            warnings.warn(
                f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored."
            )
            output_arrays = output_arrays[: len(self._output_ids)]

        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return {
            tid: out
            for tid, out in zip(
                self._output_ids,
                output_tensors,
            )
            if out is not None
        }

    @abstractmethod
    def _forward_impl(
        self,
        device: DeviceType,
        model: ModelType,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]], ...]]:
        """framework specific forward implementation"""

    def unload(self):
        """Remove all models from the queue and run the cleanup hooks.

        Cleanup errors are logged as warnings and do not abort unloading.
        Calling `unload()` (or `close()`) again without a new `load()` does
        not touch the queue again.
        """
        # Claim the models to release up front: `LifoQueue.get()` blocks on an
        # empty queue, so a repeated call must not try to fetch them again.
        n_models = len(self._initialized_devices)
        self._initialized_devices = []
        for _ in range(n_models):
            device, model = self._model_queue.get()
            try:
                self._cleanup_pre_model_deletion(device, model)
            except Exception as e:
                logger.warning(
                    "Got error during pre-deletion cleanup on device {}: {}", device, e
                )
            finally:
                del model

            try:
                self._cleanup_post_model_deletion(device)
            except Exception as e:
                logger.warning(
                    "Got error during post-deletion cleanup on device {}: {}", device, e
                )

        _ = gc.collect()  # deallocate memory
        super().unload()

    @abstractmethod
    def _cleanup_pre_model_deletion(self, device: DeviceType, model: ModelType) -> None:
        """Clean up before model reference deletion"""

    @abstractmethod
    def _cleanup_post_model_deletion(self, device: DeviceType) -> None:
        """Clean up after model reference deletion"""

240 

241 

class RemoteModelAdapter(ModelAdapter, ABC, Generic[SerializedSampleBlockType]):
    """Model adapter to use a remote service for model inference."""

    def __init__(
        self,
        model_description: AnyModelDescr,
        server: str,
        sample_serializer: SampleSerializer[SerializedSampleBlockType],
    ):
        """Initialize the adapter for the given server and serializer.

        Args:
            model_description: Description of the model to adapt.
            server: Server to use for inference, exposed via `server`.
            sample_serializer: Serializer used to serialize inputs and
                deserialize outputs of the forward pass.
        """
        # Assign before `super().__init__()`: the base initializer calls
        # `self.load()`, which must be able to access these attributes.
        self._server = server
        self._serializer = sample_serializer
        super().__init__(model_description, devices=None)

    @property
    def server(self) -> str:
        """The server given at initialization."""
        return self._server

    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]:
        """Run the forward pass through `_forward_impl` on a serialized sample.

        Inputs that are `None` are dropped; the remaining tensors are wrapped
        in a `Sample` (empty `stat`, `id=None`) and serialized. The members of
        the deserialized output sample are returned.
        """
        serialized_input = self._serializer.serialize_sample(
            Sample(
                members={k: v for k, v in inputs.items() if v is not None},
                stat={},
                id=None,
            )
        )
        serialized_output = self._forward_impl(serialized_input)
        return self._serializer.deserialize_sample(serialized_output).members

    @abstractmethod
    def _forward_impl(
        self, serialized_input_sample: Iterable[SerializedSampleBlockType]
    ) -> Iterable[SerializedSampleBlockType]: ...

    @abstractmethod
    def test(self) -> Optional[ValidationSummary]:
        """Run the bioimageio model test."""