Coverage for src/bioimageio/core/backends/_model_adapter.py: 73%
110 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import warnings
2from abc import ABC, abstractmethod
3from typing import (
4 Any,
5 List,
6 Optional,
7 Sequence,
8 Tuple,
9 Union,
10 final,
11)
13from exceptiongroup import ExceptionGroup
14from loguru import logger
15from numpy.typing import NDArray
16from typing_extensions import assert_never
18from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
20from ..common import SupportedWeightsFormat
21from ..digest_spec import get_axes_infos, get_member_ids
22from ..sample import Sample, SampleBlock
23from ..tensor import Tensor
# Weight formats this module knows how to turn into a model adapter, listed
# from most to least preferred. `ModelAdapter.create` walks this order (after
# dropping formats the model description has no weights for) and returns the
# adapter of the first format that can be constructed.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_v3",
    "keras_hdf5",
)
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        """Cache tensor ids, axis orders and input optionality of a model.

        Args:
            model_description: the (v0.4 or v0.5) model description this
                adapter runs.
        """
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # Per tensor: tuple of axis ids. `forward` transposes inputs into this
        # order before `_forward_impl` and labels outputs with it afterwards.
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 inputs are all treated as required; `optional` is only
            # read from v0.5 input descriptions.
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ) -> "ModelAdapter":
        """
        Create a model adapter for the given model description.

        Note: all framework specific adapter modules are imported inside this
        function (and only when their weight format is tried) to prevent
        different framework initializations from interfering with each other.

        Args:
            model_description: v0.4 or v0.5 model description.
            devices: forwarded unchanged to the framework specific adapter.
            weight_format_priority_order: weight formats to try, in order.
                Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`. Formats for
                which `model_description.weights` has no entry are skipped.

        Returns:
            The adapter of the first weight format whose adapter could be created.

        Raises:
            TypeError: if `model_description` is neither a v0.4 nor a v0.5
                model description.
            ValueError: if none of the requested weight formats is present.
            Exception: the adapter's own error, if exactly one weight format
                was requested and its adapter could not be created.
            ExceptionGroup: all collected adapter errors, otherwise.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w
            for w in weight_format_priority_order
            if getattr(weights, w, None) is not None
        ]
        if not weight_format_priority_order_present:
            # Report what the model *does* provide; the filtered list above is
            # always empty here and would not help the caller.
            available = [
                w
                for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
                if getattr(weights, w, None) is not None
            ]
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present (available: {available})"
            )

        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_v3":
                assert not isinstance(weights, v0_4.WeightsDescr), (
                    "keras_v3 weights not supported for v0.4 specs"
                )
                assert weights.keras_v3 is not None
                try:
                    from .keras_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        # Every tried format either returned or appended an error.
        assert errors
        if len(weight_format_priority_order) == 1:
            # A single requested format: surface its error directly.
            assert len(errors) == 1
            raise errors[0]

        msg = (
            "None of the weight format specific model adapters could be created"
            + " in this environment."
        )
        raise ExceptionGroup(msg, errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op that only emits a warning; the adapter is loaded on initialization."""
        # stacklevel=2 attributes the warning to the caller, not to this module
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization", stacklevel=2
        )

    def forward(self, input_sample: Union[Sample, SampleBlock]) -> Sample:
        """
        Run forward pass of model to get model predictions

        Note: sample id and sample stat attributes are passed through

        Args:
            input_sample: input tensors keyed by member id. Members whose id is
                not a model input trigger a warning and are ignored; model
                inputs missing from the sample are passed to `_forward_impl`
                as `None`.

        Returns:
            A `Sample` with one member per output that `_forward_impl` returned
            as non-None, carrying over the input's `stat` and id.
        """
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(
                f"Got unexpected input tensor IDs: {unexpected}", stacklevel=2
            )

        # Transpose each provided input into the axis order declared for the
        # model and unwrap it to the raw array `_forward_impl` expects.
        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        logger.debug(
            "NN input shapes: {}",
            [a.shape if a is not None else None for a in input_arrays],
        )
        output_arrays = self._forward_impl(input_arrays)
        logger.debug(
            "NN output shapes: {}",
            [a.shape if a is not None else None for a in output_arrays],
        )
        if len(output_arrays) > len(self._output_ids):
            warnings.warn(
                f"Model produced more outputs ({len(output_arrays)}) than specified in the model description ({len(self._output_ids)}). Extra outputs will be ignored."
            )
            output_arrays = output_arrays[: len(self._output_ids)]

        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
        """framework specific forward implementation

        Receives one (possibly `None`) array per model input, in the order of
        the model description's inputs, and returns one (possibly `None`)
        array per output.
        """

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def __enter__(self) -> "ModelAdapter":
        """Return the adapter itself so it can be used in a `with` statement."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Unload the model when leaving a `with` block; exceptions propagate."""
        self.unload()

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays"""
        # NOTE(review): the body consists only of this docstring, so the method
        # currently returns None; `forward` does not call it.
# Module-level function alias of the `ModelAdapter.create` factory classmethod,
# so callers can create an adapter without referencing the class.
create_model_adapter = ModelAdapter.create