bioimageio.core.model_adapters

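DEPRECATED: this module only re-exports `ModelAdapter` and `create_model_adapter` from `bioimageio.core.backends._model_adapter` and provides `get_weight_formats`.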

 1"""DEPRECATED"""
 2
 3from typing import List
 4
 5from .backends._model_adapter import (
 6    DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
 7    ModelAdapter,
 8    create_model_adapter,
 9)
10
11__all__ = [
12    "ModelAdapter",
13    "create_model_adapter",
14    "get_weight_formats",
15]
16
17
18def get_weight_formats() -> List[str]:
19    """
20    Return list of supported weight types
21    """
22    return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)
class ModelAdapter(abc.ABC):
 36class ModelAdapter(ABC):
 37    """
 38    Represents model *without* any preprocessing or postprocessing.
 39
 40    ```
 41    from bioimageio.core import load_description
 42
 43    model = load_description(...)
 44
 45    # option 1:
 46    adapter = ModelAdapter.create(model)
 47    adapter.forward(...)
 48    adapter.unload()
 49
 50    # option 2:
 51    with ModelAdapter.create(model) as adapter:
 52        adapter.forward(...)
 53    ```
 54    """
 55
 56    def __init__(self, model_description: AnyModelDescr):
 57        super().__init__()
 58        self._model_descr = model_description
 59        self._input_ids = get_member_ids(model_description.inputs)
 60        self._output_ids = get_member_ids(model_description.outputs)
 61        self._input_axes = [
 62            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
 63        ]
 64        self._output_axes = [
 65            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
 66        ]
 67        if isinstance(model_description, v0_4.ModelDescr):
 68            self._input_is_optional = [False] * len(model_description.inputs)
 69        else:
 70            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]
 71
 72    @final
 73    @classmethod
 74    def create(
 75        cls,
 76        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 77        *,
 78        devices: Optional[Sequence[str]] = None,
 79        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
 80    ):
 81        """
 82        Creates model adapter based on the passed spec
 83        Note: All specific adapters should happen inside this function to prevent different framework
 84        initializations interfering with each other
 85        """
 86        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 87            raise TypeError(
 88                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 89            )
 90
 91        weights = model_description.weights
 92        errors: List[Exception] = []
 93        weight_format_priority_order = (
 94            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 95            if weight_format_priority_order is None
 96            else weight_format_priority_order
 97        )
 98        # limit weight formats to the ones present
 99        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
100            w for w in weight_format_priority_order if getattr(weights, w) is not None
101        ]
102        if not weight_format_priority_order_present:
103            raise ValueError(
104                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
105            )
106
107        for wf in weight_format_priority_order_present:
108            if wf == "pytorch_state_dict":
109                assert weights.pytorch_state_dict is not None
110                try:
111                    from .pytorch_backend import PytorchModelAdapter
112
113                    return PytorchModelAdapter(
114                        model_description=model_description, devices=devices
115                    )
116                except Exception as e:
117                    errors.append(e)
118            elif wf == "tensorflow_saved_model_bundle":
119                assert weights.tensorflow_saved_model_bundle is not None
120                try:
121                    from .tensorflow_backend import create_tf_model_adapter
122
123                    return create_tf_model_adapter(
124                        model_description=model_description, devices=devices
125                    )
126                except Exception as e:
127                    errors.append(e)
128            elif wf == "onnx":
129                assert weights.onnx is not None
130                try:
131                    from .onnx_backend import ONNXModelAdapter
132
133                    return ONNXModelAdapter(
134                        model_description=model_description, devices=devices
135                    )
136                except Exception as e:
137                    errors.append(e)
138            elif wf == "torchscript":
139                assert weights.torchscript is not None
140                try:
141                    from .torchscript_backend import TorchscriptModelAdapter
142
143                    return TorchscriptModelAdapter(
144                        model_description=model_description, devices=devices
145                    )
146                except Exception as e:
147                    errors.append(e)
148            elif wf == "keras_hdf5":
149                assert weights.keras_hdf5 is not None
150                # keras can either be installed as a separate package or used as part of tensorflow
151                # we try to first import the keras model adapter using the separate package and,
152                # if it is not available, try to load the one using tf
153                try:
154                    try:
155                        from .keras_backend import KerasModelAdapter
156                    except Exception:
157                        from .tensorflow_backend import KerasModelAdapter
158
159                    return KerasModelAdapter(
160                        model_description=model_description, devices=devices
161                    )
162                except Exception as e:
163                    errors.append(e)
164            else:
165                assert_never(wf)
166
167        assert errors
168        if len(weight_format_priority_order) == 1:
169            assert len(errors) == 1
170            raise errors[0]
171
172        else:
173            msg = (
174                "None of the weight format specific model adapters could be created"
175                + " in this environment."
176            )
177            if sys.version_info[:2] >= (3, 11):
178                raise ExceptionGroup(msg, errors)
179            else:
180                raise ValueError(msg) from Exception(errors)
181
182    @final
183    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
184        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
185
186    def forward(
187        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
188    ) -> Sample:
189        """
190        Run forward pass of model to get model predictions
191
192        Note: sample id and sample stat attributes are passed through
193        """
194        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
195        if unexpected:
196            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
197
198        input_arrays = [
199            (
200                None
201                if (a := input_sample.members.get(in_id)) is None
202                else a.transpose(in_order).data.data
203            )
204            for in_id, in_order in zip(self._input_ids, self._input_axes)
205        ]
206        output_arrays = self._forward_impl(input_arrays)
207        assert len(output_arrays) <= len(self._output_ids)
208        output_tensors = [
209            None if a is None else Tensor(a, dims=d)
210            for a, d in zip(output_arrays, self._output_axes)
211        ]
212        return Sample(
213            members={
214                tid: out
215                for tid, out in zip(
216                    self._output_ids,
217                    output_tensors,
218                )
219                if out is not None
220            },
221            stat=input_sample.stat,
222            id=(
223                input_sample.id
224                if isinstance(input_sample, Sample)
225                else input_sample.sample_id
226            ),
227        )
228
229    @abstractmethod
230    def _forward_impl(
231        self, input_arrays: Sequence[Optional[NDArray[Any]]]
232    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
233        """framework specific forward implementation"""
234
235    @abstractmethod
236    def unload(self):
237        """
238        Unload model from any devices, freeing their memory.
239        The model adapter should be considered unusable afterwards.
240        """
241
242    def _get_input_args_numpy(self, input_sample: Sample):
243        """helper to extract tensor args as transposed numpy arrays"""

Represents a model *without* any preprocessing or postprocessing.

```
from bioimageio.core import load_description

model = load_description(...)

# option 1:
adapter = ModelAdapter.create(model)
adapter.forward(...)
adapter.unload()

# option 2:
with ModelAdapter.create(model) as adapter:
    adapter.forward(...)
```
@final
@classmethod
def create( cls, model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
 72    @final
 73    @classmethod
 74    def create(
 75        cls,
 76        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 77        *,
 78        devices: Optional[Sequence[str]] = None,
 79        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
 80    ):
 81        """
 82        Creates model adapter based on the passed spec
 83        Note: All specific adapters should happen inside this function to prevent different framework
 84        initializations interfering with each other
 85        """
 86        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 87            raise TypeError(
 88                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 89            )
 90
 91        weights = model_description.weights
 92        errors: List[Exception] = []
 93        weight_format_priority_order = (
 94            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 95            if weight_format_priority_order is None
 96            else weight_format_priority_order
 97        )
 98        # limit weight formats to the ones present
 99        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
100            w for w in weight_format_priority_order if getattr(weights, w) is not None
101        ]
102        if not weight_format_priority_order_present:
103            raise ValueError(
104                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
105            )
106
107        for wf in weight_format_priority_order_present:
108            if wf == "pytorch_state_dict":
109                assert weights.pytorch_state_dict is not None
110                try:
111                    from .pytorch_backend import PytorchModelAdapter
112
113                    return PytorchModelAdapter(
114                        model_description=model_description, devices=devices
115                    )
116                except Exception as e:
117                    errors.append(e)
118            elif wf == "tensorflow_saved_model_bundle":
119                assert weights.tensorflow_saved_model_bundle is not None
120                try:
121                    from .tensorflow_backend import create_tf_model_adapter
122
123                    return create_tf_model_adapter(
124                        model_description=model_description, devices=devices
125                    )
126                except Exception as e:
127                    errors.append(e)
128            elif wf == "onnx":
129                assert weights.onnx is not None
130                try:
131                    from .onnx_backend import ONNXModelAdapter
132
133                    return ONNXModelAdapter(
134                        model_description=model_description, devices=devices
135                    )
136                except Exception as e:
137                    errors.append(e)
138            elif wf == "torchscript":
139                assert weights.torchscript is not None
140                try:
141                    from .torchscript_backend import TorchscriptModelAdapter
142
143                    return TorchscriptModelAdapter(
144                        model_description=model_description, devices=devices
145                    )
146                except Exception as e:
147                    errors.append(e)
148            elif wf == "keras_hdf5":
149                assert weights.keras_hdf5 is not None
150                # keras can either be installed as a separate package or used as part of tensorflow
151                # we try to first import the keras model adapter using the separate package and,
152                # if it is not available, try to load the one using tf
153                try:
154                    try:
155                        from .keras_backend import KerasModelAdapter
156                    except Exception:
157                        from .tensorflow_backend import KerasModelAdapter
158
159                    return KerasModelAdapter(
160                        model_description=model_description, devices=devices
161                    )
162                except Exception as e:
163                    errors.append(e)
164            else:
165                assert_never(wf)
166
167        assert errors
168        if len(weight_format_priority_order) == 1:
169            assert len(errors) == 1
170            raise errors[0]
171
172        else:
173            msg = (
174                "None of the weight format specific model adapters could be created"
175                + " in this environment."
176            )
177            if sys.version_info[:2] >= (3, 11):
178                raise ExceptionGroup(msg, errors)
179            else:
180                raise ValueError(msg) from Exception(errors)

Creates a model adapter based on the passed model description. Note: all weight-format-specific adapters should be imported and created inside this function to prevent different framework initializations from interfering with each other.

@final
def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
182    @final
183    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
184        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
186    def forward(
187        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
188    ) -> Sample:
189        """
190        Run forward pass of model to get model predictions
191
192        Note: sample id and sample stat attributes are passed through
193        """
194        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
195        if unexpected:
196            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
197
198        input_arrays = [
199            (
200                None
201                if (a := input_sample.members.get(in_id)) is None
202                else a.transpose(in_order).data.data
203            )
204            for in_id, in_order in zip(self._input_ids, self._input_axes)
205        ]
206        output_arrays = self._forward_impl(input_arrays)
207        assert len(output_arrays) <= len(self._output_ids)
208        output_tensors = [
209            None if a is None else Tensor(a, dims=d)
210            for a, d in zip(output_arrays, self._output_axes)
211        ]
212        return Sample(
213            members={
214                tid: out
215                for tid, out in zip(
216                    self._output_ids,
217                    output_tensors,
218                )
219                if out is not None
220            },
221            stat=input_sample.stat,
222            id=(
223                input_sample.id
224                if isinstance(input_sample, Sample)
225                else input_sample.sample_id
226            ),
227        )

Run forward pass of model to get model predictions

Note: the input sample's id and stat attributes are passed through to the output sample.

@abstractmethod
def unload(self):
235    @abstractmethod
236    def unload(self):
237        """
238        Unload model from any devices, freeing their memory.
239        The model adapter should be considered unusable afterwards.
240        """

Unload the model from any devices, freeing their memory. The model adapter should be considered unusable afterwards.

@final
@classmethod
def create_model_adapter( model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
 72    @final
 73    @classmethod
 74    def create(
 75        cls,
 76        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 77        *,
 78        devices: Optional[Sequence[str]] = None,
 79        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
 80    ):
 81        """
 82        Creates model adapter based on the passed spec
 83        Note: All specific adapters should happen inside this function to prevent different framework
 84        initializations interfering with each other
 85        """
 86        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 87            raise TypeError(
 88                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 89            )
 90
 91        weights = model_description.weights
 92        errors: List[Exception] = []
 93        weight_format_priority_order = (
 94            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 95            if weight_format_priority_order is None
 96            else weight_format_priority_order
 97        )
 98        # limit weight formats to the ones present
 99        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
100            w for w in weight_format_priority_order if getattr(weights, w) is not None
101        ]
102        if not weight_format_priority_order_present:
103            raise ValueError(
104                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
105            )
106
107        for wf in weight_format_priority_order_present:
108            if wf == "pytorch_state_dict":
109                assert weights.pytorch_state_dict is not None
110                try:
111                    from .pytorch_backend import PytorchModelAdapter
112
113                    return PytorchModelAdapter(
114                        model_description=model_description, devices=devices
115                    )
116                except Exception as e:
117                    errors.append(e)
118            elif wf == "tensorflow_saved_model_bundle":
119                assert weights.tensorflow_saved_model_bundle is not None
120                try:
121                    from .tensorflow_backend import create_tf_model_adapter
122
123                    return create_tf_model_adapter(
124                        model_description=model_description, devices=devices
125                    )
126                except Exception as e:
127                    errors.append(e)
128            elif wf == "onnx":
129                assert weights.onnx is not None
130                try:
131                    from .onnx_backend import ONNXModelAdapter
132
133                    return ONNXModelAdapter(
134                        model_description=model_description, devices=devices
135                    )
136                except Exception as e:
137                    errors.append(e)
138            elif wf == "torchscript":
139                assert weights.torchscript is not None
140                try:
141                    from .torchscript_backend import TorchscriptModelAdapter
142
143                    return TorchscriptModelAdapter(
144                        model_description=model_description, devices=devices
145                    )
146                except Exception as e:
147                    errors.append(e)
148            elif wf == "keras_hdf5":
149                assert weights.keras_hdf5 is not None
150                # keras can either be installed as a separate package or used as part of tensorflow
151                # we try to first import the keras model adapter using the separate package and,
152                # if it is not available, try to load the one using tf
153                try:
154                    try:
155                        from .keras_backend import KerasModelAdapter
156                    except Exception:
157                        from .tensorflow_backend import KerasModelAdapter
158
159                    return KerasModelAdapter(
160                        model_description=model_description, devices=devices
161                    )
162                except Exception as e:
163                    errors.append(e)
164            else:
165                assert_never(wf)
166
167        assert errors
168        if len(weight_format_priority_order) == 1:
169            assert len(errors) == 1
170            raise errors[0]
171
172        else:
173            msg = (
174                "None of the weight format specific model adapters could be created"
175                + " in this environment."
176            )
177            if sys.version_info[:2] >= (3, 11):
178                raise ExceptionGroup(msg, errors)
179            else:
180                raise ValueError(msg) from Exception(errors)

Creates a model adapter based on the passed model description. Note: all weight-format-specific adapters should be imported and created inside this function to prevent different framework initializations from interfering with each other.

def get_weight_formats() -> List[str]:
19def get_weight_formats() -> List[str]:
20    """
21    Return list of supported weight types
22    """
23    return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)

Return the list of supported weight formats.