Coverage for src/bioimageio/core/_model_adapter.py: 87%
118 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import gc
2import warnings
3from abc import ABC, abstractmethod
4from queue import LifoQueue
5from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Union
7from exceptiongroup import ExceptionGroup
8from loguru import logger
9from numpy.typing import NDArray
10from typing_extensions import TypeVar
12from bioimageio.spec import ValidationSummary
13from bioimageio.spec.model import AnyModelDescr, v0_4
15from ._sample_serializer import SampleSerializer, SerializedSampleBlockType
16from .common import PerMember
17from .digest_spec import get_axes_infos, get_member_ids
18from .sample import Sample
19from .tensor import Tensor
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = create_model_adapter(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with create_model_adapter(model) as adapter:
        adapter.forward(...)
    ```

    Note:
        The model is loaded at the end of `__init__` (via `self.load()`), so
        subclasses must set up any state their `load()` relies on *before*
        calling `super().__init__`.
    """

    def __init__(
        self, model_description: AnyModelDescr, devices: Optional[Sequence[str]]
    ):
        """Collect tensor IDs and axes from `model_description`, then load the model.

        Args:
            model_description: The bioimageio model description to adapt.
            devices: Requested devices, stored as `self._devices`; how they are
                interpreted (and what `None` means) is up to the concrete subclass.
        """
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # per tensor: tuple of axis ids in the order returned by `get_axes_infos`
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 descriptions are handled as having only required inputs
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

        self._devices = devices
        self.load()

    @property
    def model_descr(self) -> AnyModelDescr:
        """The model description this adapter was created from."""
        return self._model_descr

    @abstractmethod
    def load(self) -> None:
        """Load the model.

        Implementations should call `super().load()` once loading succeeded;
        this marks the adapter as loaded.
        """
        self._loaded = True

    @abstractmethod
    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]:
        """Run the model on `inputs` (no pre- or postprocessing is applied)."""

    @abstractmethod
    def unload(self):
        """Unload model from any devices, freeing their memory.

        Implementations should call `super().unload()`; this marks the adapter
        as not loaded.

        Note:
            The model adapter should be considered unusable afterwards.
        """
        self._loaded = False

    def close(self):
        """Close the model adapter, freeing any resources (calls `unload()`).

        Note:
            The model adapter should be considered unusable afterwards.
        """
        self.unload()

    def __enter__(self) -> "ModelAdapter":
        """Enable `with create_model_adapter(...) as adapter:` usage."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Close the adapter when leaving the `with` block; exceptions propagate."""
        self.close()
# Device handle type: what `_parse_devices` yields and `_init_model_on_device` consumes.
DeviceType = TypeVar("DeviceType")
# Loaded-model type: what `_init_model_on_device` returns and `_forward_impl` receives.
ModelType = TypeVar("ModelType")
class LocalModelAdapter(ModelAdapter, ABC, Generic[DeviceType, ModelType]):
    """Model adapter running the model locally, one model instance per device.

    Model instances are kept in a LIFO pool (`self._model_queue`). Each
    `forward` call takes one instance out of the pool and puts it back
    afterwards; a call blocks until an instance is available.
    """

    def load(self) -> None:
        """Initialize one model instance per requested device.

        Devices are initialized in reverse priority order, so the
        highest-priority device ends up on top of the LIFO pool. Failures on
        individual devices are only logged as a warning as long as at least one
        device succeeds.

        Raises:
            Exception: The original error if the only attempted device failed.
            ExceptionGroup: If initialization failed on all of several devices.
        """
        devices = self._devices
        self._model_queue: LifoQueue[Tuple[DeviceType, ModelType]] = LifoQueue()
        parsed_devices = self._parse_devices(devices)
        assert parsed_devices, "`_parse_devices` must not return an empty sequence"
        # prioritize devices by order specified by user
        device_exceptions: Dict[str, Exception] = {}
        self._initialized_devices: List[str] = []
        for d in parsed_devices[::-1]:
            try:
                model = self._init_model_on_device(d)
            except Exception as e:
                device_exceptions[str(d)] = e
            else:
                self._model_queue.put((d, model))
                # insert at front to keep the user-specified (priority) order
                self._initialized_devices.insert(0, str(d))

        if self._model_queue.empty():
            if len(device_exceptions) == 1:
                raise next(iter(device_exceptions.values()))
            else:
                raise ExceptionGroup(
                    "Failed to initialize model on any of the requested devices.",
                    # reverse again to report errors in user-specified order
                    list(device_exceptions.values())[::-1],
                )

        if device_exceptions:
            logger.warning(
                "Failed to initialize model on some of the requested devices. Successfully initialized on {}, but got the following errors for other devices: {}",
                self._initialized_devices,
                device_exceptions,
            )

        super().load()

    @abstractmethod
    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Sequence[DeviceType]:
        """Parse devices.

        Note:
            - May not return an empty sequence.
            - The order of devices in the returned sequence determines the priority of device usage in the forward pass.
              The first device has the highest priority, the last device has the lowest priority.
        """

    @abstractmethod
    def _init_model_on_device(self, device: DeviceType) -> ModelType:
        """Create a model instance on `device` (errors are collected by `load`)."""

    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]:
        """Run the forward pass of the model to get model predictions.

        Each input tensor is transposed to the axis order of the model
        description and passed to `_forward_impl` as an array; missing inputs
        are passed as `None`. Input IDs not in the model description trigger a
        warning and are ignored. Outputs beyond the number described are dropped
        with a warning, and `None` outputs are omitted from the returned mapping.

        Raises:
            RuntimeError: If the model is not loaded.
        """
        if not self._loaded:
            raise RuntimeError("Model must be `.load()`ed before calling forward()")

        unexpected = [mid for mid in inputs if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        input_arrays = [
            (
                None
                if (a := inputs.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        logger.debug(
            "NN input shapes: {}",
            [a.shape if a is not None else None for a in input_arrays],
        )
        # borrow a model instance from the pool; always return it, even on error
        device, model = self._model_queue.get()
        try:
            output_arrays = self._forward_impl(device, model, input_arrays)
        finally:
            self._model_queue.put((device, model))

        logger.debug(
            "NN output shapes: {}",
            [a.shape if a is not None else None for a in output_arrays],
        )
        if len(output_arrays) > len(self._output_ids):
            warnings.warn(
                f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored."
            )
            output_arrays = output_arrays[: len(self._output_ids)]

        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return {
            tid: out
            for tid, out in zip(
                self._output_ids,
                output_tensors,
            )
            if out is not None
        }

    @abstractmethod
    def _forward_impl(
        self,
        device: DeviceType,
        model: ModelType,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]], ...]]:
        """Framework-specific forward implementation.

        `input_arrays` follow the input order of the model description; the
        returned arrays are matched to the described outputs by position.
        """

    def unload(self):
        """Release all model instances and free device memory.

        Waits for each model instance to be back in the pool (an in-flight
        `forward` call holds one), runs the framework-specific cleanup hooks
        (errors there are logged, not raised) and finally runs the garbage
        collector. A repeated call (e.g. `unload()` followed by `close()`)
        skips the pool instead of blocking forever on the emptied queue.
        """
        n_models = len(self._initialized_devices)
        # forget the devices before draining: each pooled model must be taken
        # out exactly once, and a second `unload` must not wait on an empty queue
        self._initialized_devices = []
        for _ in range(n_models):
            device, model = self._model_queue.get()
            try:
                self._cleanup_pre_model_deletion(device, model)
            except Exception as e:
                logger.warning(
                    "Got error during pre-deletion cleanup on device {}: {}", device, e
                )
            finally:
                del model
                try:
                    self._cleanup_post_model_deletion(device)
                except Exception as e:
                    logger.warning(
                        "Got error during post-deletion cleanup on device {}: {}", device, e
                    )

        _ = gc.collect()  # deallocate memory
        super().unload()

    @abstractmethod
    def _cleanup_pre_model_deletion(self, device: DeviceType, model: ModelType) -> None:
        """Clean up before model reference deletion."""

    @abstractmethod
    def _cleanup_post_model_deletion(self, device: DeviceType) -> None:
        """Clean up after model reference deletion."""
class RemoteModelAdapter(ModelAdapter, ABC, Generic[SerializedSampleBlockType]):
    """Model adapter to use a remote service for model inference."""

    def __init__(
        self,
        model_description: AnyModelDescr,
        server: str,
        sample_serializer: SampleSerializer[SerializedSampleBlockType],
    ):
        """Create the adapter and load the model.

        Args:
            model_description: The bioimageio model description to adapt.
            server: Identifier of the remote service, exposed via `server`
                (presumably a URL/address — confirm with concrete subclasses).
            sample_serializer: Converts samples to and from the serialized
                blocks passed to/returned by `_forward_impl`.
        """
        # `ModelAdapter.__init__` calls `self.load()`; assign these first so a
        # subclass's `load()` can already use the server and the serializer.
        self._server = server
        self._serializer = sample_serializer
        super().__init__(model_description, devices=None)

    @property
    def server(self) -> str:
        """The remote service this adapter sends inference requests to."""
        return self._server

    def forward(
        self, inputs: PerMember[Optional[Tensor]]
    ) -> PerMember[Optional[Tensor]]:
        """Serialize the non-`None` inputs, run `_forward_impl`, deserialize the result.

        The inputs are wrapped in a `Sample` with empty `stat` and no `id`; the
        members of the deserialized output sample are returned.
        """
        serialized_input = self._serializer.serialize_sample(
            Sample(
                members={k: v for k, v in inputs.items() if v is not None},
                stat={},
                id=None,
            )
        )
        serialized_output = self._forward_impl(serialized_input)
        return self._serializer.deserialize_sample(serialized_output).members

    @abstractmethod
    def _forward_impl(
        self, serialized_input_sample: Iterable[SerializedSampleBlockType]
    ) -> Iterable[SerializedSampleBlockType]:
        """Send the serialized input sample to the service and return its serialized output."""

    @abstractmethod
    def test(self) -> Optional[ValidationSummary]:
        """Run the bioimageio model test."""