Coverage for src / bioimageio / core / backends / _model_adapter.py: 71%
105 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1import warnings
2from abc import ABC, abstractmethod
3from typing import (
4 Any,
5 List,
6 Optional,
7 Sequence,
8 Tuple,
9 Union,
10 final,
11)
13from exceptiongroup import ExceptionGroup
14from numpy.typing import NDArray
15from typing_extensions import assert_never
17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
19from ..common import SupportedWeightsFormat
20from ..digest_spec import get_axes_infos, get_member_ids
21from ..sample import Sample, SampleBlock
22from ..tensor import Tensor
# Weight formats this module knows how to load, from most to least preferred.
# By default, ModelAdapter.create walks this sequence, skips formats the model
# description does not provide, and returns the adapter of the first remaining
# format whose adapter can be created in the current environment.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict", "tensorflow_saved_model_bundle", "torchscript",
    "onnx", "keras_v3", "keras_hdf5",
)
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2 (the adapter is unloaded when the `with` block is left):
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        super().__init__()
        self._model_descr = model_description
        # tensor ids and axis orders, in the order the model description lists them;
        # `forward` relies on these lists being aligned index by index
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        # v0.4 input descriptions carry no `optional` flag: treat all inputs as required
        if isinstance(model_description, v0_4.ModelDescr):
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    def __enter__(self) -> "ModelAdapter":
        """Enter a `with` block; returns the adapter itself."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Unload the model when leaving a `with` block.

        Returns None, so any exception raised inside the block propagates.
        """
        self.unload()

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates a model adapter based on the passed model description.

        Weight formats are tried in the order of `weight_format_priority_order`
        (default: `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`), skipping formats that
        are absent from the description. The first adapter that can be created
        is returned.

        Note: All framework specific adapter imports should happen inside this
        function to prevent different framework initializations from
        interfering with each other.

        Args:
            model_description: v0.4 or v0.5 model description.
            devices: forwarded unchanged to the framework specific adapter.
            weight_format_priority_order: weight formats to try, in order.

        Raises:
            TypeError: if `model_description` is not a v0.4/v0.5 model description.
            ValueError: if none of the requested weight formats is present.
            Exception: the adapter creation error itself, if
                `weight_format_priority_order` has exactly one entry.
            ExceptionGroup: all adapter creation errors, otherwise.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w
            for w in weight_format_priority_order
            if getattr(weights, w, None) is not None
        ]
        if not weight_format_priority_order_present:
            # The filtered list above is necessarily empty here, so report the
            # known formats the description actually provides instead.
            available_formats = [
                w
                for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
                if getattr(weights, w, None) is not None
            ]
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present (available: {available_formats})"
            )

        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_v3":
                assert not isinstance(weights, v0_4.WeightsDescr), (
                    "keras_v3 weights not supported for v0.4 specs"
                )
                assert weights.keras_v3 is not None
                try:
                    from .keras_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        # Every successful branch returns, and every failed attempt appended
        # its exception, so reaching this point implies at least one error.
        assert errors
        if len(weight_format_priority_order) == 1:
            # exactly one format was requested: surface its error unwrapped
            assert len(errors) == 1
            raise errors[0]
        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            raise ExceptionGroup(msg, errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op kept for backward compatibility; only emits a warning."""
        # stacklevel=2 attributes the warning to the caller's line, not this one
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization", stacklevel=2
        )

    def forward(self, input_sample: Union[Sample, SampleBlock]) -> Sample:
        """
        Run forward pass of model to get model predictions

        Each declared input is transposed to the axis order listed in the model
        description and passed to `_forward_impl` as an array. A declared input
        missing from `input_sample` is passed as `None`. Members of
        `input_sample` that are not declared inputs only trigger a warning.
        Outputs that `_forward_impl` returns as `None` are omitted from the
        returned sample.

        Note: sample id and sample stat attributes are passed through
        """
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}", stacklevel=2)

        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        output_arrays = self._forward_impl(input_arrays)
        # an implementation may return fewer outputs than declared, never more
        assert len(output_arrays) <= len(self._output_ids)
        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
        """framework specific forward implementation

        `input_arrays` is aligned with the declared inputs; an entry is `None`
        for an input that was not provided. The returned sequence is aligned
        with the declared outputs and may use `None` for outputs not produced.
        """

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays"""
        # NOTE(review): this method has no implementation here and returns None;
        # `forward` performs the extraction inline — confirm whether it is still used.
# Functional entry point: a module-level alias of the `ModelAdapter.create`
# classmethod, so callers can build an adapter without naming the class.
create_model_adapter = ModelAdapter.create