Coverage for bioimageio/core/backends/_model_adapter.py: 76%
99 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import sys
2import warnings
3from abc import ABC, abstractmethod
4from typing import (
5 Any,
6 List,
7 Optional,
8 Sequence,
9 Tuple,
10 Union,
11 final,
12)
14from numpy.typing import NDArray
15from typing_extensions import assert_never
17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
19from ..common import SupportedWeightsFormat
20from ..digest_spec import get_axes_infos, get_member_ids
21from ..sample import Sample, SampleBlock, SampleBlockWithOrigin
22from ..tensor import Tensor
# Known weight formats in order of priority.
# First match wins: `ModelAdapter.create` falls back to this order when no
# explicit `weight_format_priority_order` is given and walks it front to back,
# so an earlier entry is preferred over a later one.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_hdf5",
)
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        """Cache tensor ids and axis orders of the model's inputs and outputs.

        Args:
            model_description: model description whose `inputs` and `outputs`
                determine the expected tensor ids and axis orders.
        """
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # axis ids per tensor, in the order given by `get_axes_infos`
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        # v0_4 inputs are all treated as required; otherwise the per-input
        # `optional` attribute is used.
        # NOTE(review): this list is not read anywhere in this class;
        # presumably it is consumed by subclasses -- confirm.
        if isinstance(model_description, v0_4.ModelDescr):
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    def __enter__(self) -> "ModelAdapter":
        """Enter a `with` block; the adapter itself is the context value."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Unload the model when leaving a `with` block.

        Returns `None`, so an exception raised inside the block propagates.
        """
        self.unload()

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates model adapter based on the passed spec
        Note: All specific adapters should happen inside this function to prevent different framework
        initializations interfering with each other

        Args:
            model_description: description of the model to create an adapter for.
            devices: forwarded unchanged to the framework specific adapter.
            weight_format_priority_order: weight formats to try, most preferred
                first. Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.

        Returns:
            The adapter of the first weight format (in priority order) that is
            present in the model description and could be set up.

        Raises:
            TypeError: if `model_description` is not a v0_4/v0_5 `ModelDescr`.
            ValueError: if none of the requested weight formats is present.
            Exception: if the priority order holds exactly one format, the
                error raised while creating its adapter is re-raised as is.
            ExceptionGroup: (Python >= 3.11) if several formats were requested
                and no adapter could be created; on older Python versions a
                `ValueError` chained to the collected errors is raised instead.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w for w in weight_format_priority_order if getattr(weights, w) is not None
        ]
        if not weight_format_priority_order_present:
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
            )

        # Backend modules are imported lazily inside each branch so that only
        # the framework of the format being tried gets initialized. A failing
        # backend is recorded and the next format is attempted.
        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        # Reaching this point means every attempted format failed.
        assert errors
        if len(weight_format_priority_order) == 1:
            # a single requested format: surface its error unwrapped
            assert len(errors) == 1
            raise errors[0]

        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            if sys.version_info[:2] >= (3, 11):
                raise ExceptionGroup(msg, errors)
            else:
                raise ValueError(msg) from Exception(errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op that only emits a warning; `devices` is ignored."""
        # stacklevel=2 attributes the warning to the caller of `load`
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization", stacklevel=2
        )

    def forward(
        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> Sample:
        """
        Run forward pass of model to get model predictions

        Inputs are looked up by tensor id and transposed to the axis order of
        the model description; an input missing from `input_sample` is passed
        to the backend as `None`. Outputs returned as `None` are left out of
        the resulting sample.

        Note: sample id and sample stat attributes are passed through
        """
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            # stacklevel=2 attributes the warning to the caller of `forward`
            warnings.warn(
                f"Got unexpected input tensor IDs: {unexpected}", stacklevel=2
            )

        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        output_arrays = self._forward_impl(input_arrays)
        # a backend may return fewer outputs than described, but never more
        assert len(output_arrays) <= len(self._output_ids)
        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            # blocks carry the id of their originating sample as `sample_id`
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
        """framework specific forward implementation"""

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays"""
        # NOTE(review): no implementation is present here; the method
        # implicitly returns None -- confirm whether it is still needed.
# Module-level alias so the factory can be called as a plain function
# instead of through the `ModelAdapter` class.
create_model_adapter = ModelAdapter.create