Coverage for src/bioimageio/core/backends/_model_adapter.py: 77%
97 statements
coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
import warnings
from abc import ABC, abstractmethod
from typing import (
    Any,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    final,
)
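
# `ExceptionGroup` is built into Python 3.11+; the `exceptiongroup` package
# provides the backport (PEP 654) for older interpreters.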
from exceptiongroup import ExceptionGroup
from numpy.typing import NDArray
from typing_extensions import assert_never

from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5

from ..common import SupportedWeightsFormat
from ..digest_spec import get_axes_infos, get_member_ids
from ..sample import Sample, SampleBlock, SampleBlockWithOrigin
from ..tensor import Tensor

# Known weight formats in order of priority.
# The first match wins.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_hdf5",
)
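
# Illustrative sketch (not part of the original module): the priority order can
# be overridden per call, e.g. to only try ONNX and torchscript weights:
#
#   adapter = ModelAdapter.create(
#       model_descr,  # hypothetical, previously loaded v0_4/v0_5 model description
#       weight_format_priority_order=("onnx", "torchscript"),
#   )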

class ModelAdapter(ABC):
    """
    Represents a model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # per-tensor axis orders, used in `forward` to transpose inputs to the
        # backend's expected layout and to wrap outputs back into tensors
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        # v0_4 model descriptions have no notion of optional inputs
        if isinstance(model_description, v0_4.ModelDescr):
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates a model adapter based on the passed model description.

        Note: All weight-format-specific imports and adapter instantiations should
        happen inside this function to prevent different framework initializations
        from interfering with each other.
        """
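        # The loop below tries each *available* weight format in priority order;
        # the first adapter that can be constructed is returned, while the
        # exceptions of formats that failed are collected in `errors`.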
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w for w in weight_format_priority_order if getattr(weights, w) is not None
        ]
        if not weight_format_priority_order_present:
            raise ValueError(
                "None of the specified weight formats"
                + f" ({weight_format_priority_order}) is present in the model description."
            )

        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # Keras can either be installed as a separate package or used as
                # part of TensorFlow. We first try to import the Keras model
                # adapter from the separate package and, if that is not available,
                # fall back to the TensorFlow-based one.
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        assert errors
        if len(weight_format_priority_order) == 1:
            assert len(errors) == 1
            raise errors[0]
        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            raise ExceptionGroup(msg, errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")

    def forward(
        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> Sample:
        """
        Run the forward pass of the model to get model predictions.

        Note: Sample id and sample stat attributes are passed through.
        """
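        # Usage sketch (illustrative): `sample` is a Sample whose member ids match
        # this model's input ids; the returned Sample carries the outputs under
        # the model's output ids and reuses the input sample's stat and id.
        #
        #   prediction = adapter.forward(sample)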
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        # transpose each provided input to the axis order expected by the backend
        # and unwrap it to a plain numpy array (None for missing inputs)
        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        output_arrays = self._forward_impl(input_arrays)
        assert len(output_arrays) <= len(self._output_ids)
        # wrap the backend's arrays back into tensors with the declared output axes
        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]], ...]]:
        """Framework-specific forward implementation."""
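
    # Subclass sketch (illustrative, not an actual backend): an implementation
    # receives one optional numpy array per input id and returns at most one
    # optional array per output id, in the order of self._output_ids.
    #
    #   class MyBackendAdapter(ModelAdapter):
    #       def _forward_impl(self, input_arrays):
    #           return [my_run(a) for a in input_arrays]  # `my_run` is hypothetical
    #
    #       def unload(self):
    #           pass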

    @abstractmethod
    def unload(self):
        """
        Unload the model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """Helper to extract tensor args as transposed numpy arrays."""
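

# Convenience alias: `create_model_adapter(...)` is equivalent to
# `ModelAdapter.create(...)`.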
create_model_adapter = ModelAdapter.create