bioimageio.core.model_adapters

DEPRECATED

 1"""DEPRECATED"""
 2
 3from typing import List
 4
 5from .backends._model_adapter import (
 6    DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
 7    ModelAdapter,
 8    create_model_adapter,
 9)
10
11__all__ = [
12    "ModelAdapter",
13    "create_model_adapter",
14    "get_weight_formats",
15]
16
17
def get_weight_formats() -> List[str]:
    """
    Return list of supported weight types

    The result is a new list holding the entries of
    `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER` in their original order.
    """
    return [*DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER]
class ModelAdapter(abc.ABC):
 36class ModelAdapter(ABC):
 37    """
 38    Represents model *without* any preprocessing or postprocessing.
 39
 40    ```
 41    from bioimageio.core import load_description
 42
 43    model = load_description(...)
 44
 45    # option 1:
 46    adapter = ModelAdapter.create(model)
 47    adapter.forward(...)
 48    adapter.unload()
 49
 50    # option 2:
 51    with ModelAdapter.create(model) as adapter:
 52        adapter.forward(...)
 53    ```
 54    """
 55
 56    def __init__(self, model_description: AnyModelDescr):
 57        super().__init__()
 58        self._model_descr = model_description
 59        self._input_ids = get_member_ids(model_description.inputs)
 60        self._output_ids = get_member_ids(model_description.outputs)
 61        self._input_axes = [
 62            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
 63        ]
 64        self._output_axes = [
 65            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
 66        ]
 67        if isinstance(model_description, v0_4.ModelDescr):
 68            self._input_is_optional = [False] * len(model_description.inputs)
 69        else:
 70            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]
 71
 72    @final
 73    @classmethod
 74    def create(
 75        cls,
 76        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 77        *,
 78        devices: Optional[Sequence[str]] = None,
 79        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
 80    ):
 81        """
 82        Creates model adapter based on the passed spec
 83        Note: All specific adapters should happen inside this function to prevent different framework
 84        initializations interfering with each other
 85        """
 86        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 87            raise TypeError(
 88                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 89            )
 90
 91        weights = model_description.weights
 92        errors: List[Exception] = []
 93        weight_format_priority_order = (
 94            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 95            if weight_format_priority_order is None
 96            else weight_format_priority_order
 97        )
 98        # limit weight formats to the ones present
 99        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
100            w for w in weight_format_priority_order if getattr(weights, w) is not None
101        ]
102        if not weight_format_priority_order_present:
103            raise ValueError(
104                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
105            )
106
107        for wf in weight_format_priority_order_present:
108            if wf == "pytorch_state_dict":
109                assert weights.pytorch_state_dict is not None
110                try:
111                    from .pytorch_backend import PytorchModelAdapter
112
113                    return PytorchModelAdapter(
114                        model_description=model_description, devices=devices
115                    )
116                except Exception as e:
117                    errors.append(e)
118            elif wf == "tensorflow_saved_model_bundle":
119                assert weights.tensorflow_saved_model_bundle is not None
120                try:
121                    from .tensorflow_backend import create_tf_model_adapter
122
123                    return create_tf_model_adapter(
124                        model_description=model_description, devices=devices
125                    )
126                except Exception as e:
127                    errors.append(e)
128            elif wf == "onnx":
129                assert weights.onnx is not None
130                try:
131                    from .onnx_backend import ONNXModelAdapter
132
133                    return ONNXModelAdapter(
134                        model_description=model_description, devices=devices
135                    )
136                except Exception as e:
137                    errors.append(e)
138            elif wf == "torchscript":
139                assert weights.torchscript is not None
140                try:
141                    from .torchscript_backend import TorchscriptModelAdapter
142
143                    return TorchscriptModelAdapter(
144                        model_description=model_description, devices=devices
145                    )
146                except Exception as e:
147                    errors.append(e)
148            elif wf == "keras_hdf5":
149                assert weights.keras_hdf5 is not None
150                # keras can either be installed as a separate package or used as part of tensorflow
151                # we try to first import the keras model adapter using the separate package and,
152                # if it is not available, try to load the one using tf
153                try:
154                    try:
155                        from .keras_backend import KerasModelAdapter
156                    except Exception:
157                        from .tensorflow_backend import KerasModelAdapter
158
159                    return KerasModelAdapter(
160                        model_description=model_description, devices=devices
161                    )
162                except Exception as e:
163                    errors.append(e)
164            else:
165                assert_never(wf)
166
167        assert errors
168        if len(weight_format_priority_order) == 1:
169            assert len(errors) == 1
170            raise errors[0]
171
172        else:
173            msg = (
174                "None of the weight format specific model adapters could be created"
175                + " in this environment."
176            )
177            raise ExceptionGroup(msg, errors)
178
179    @final
180    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
181        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
182
183    def forward(
184        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
185    ) -> Sample:
186        """
187        Run forward pass of model to get model predictions
188
189        Note: sample id and stample stat attributes are passed through
190        """
191        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
192        if unexpected:
193            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
194
195        input_arrays = [
196            (
197                None
198                if (a := input_sample.members.get(in_id)) is None
199                else a.transpose(in_order).data.data
200            )
201            for in_id, in_order in zip(self._input_ids, self._input_axes)
202        ]
203        output_arrays = self._forward_impl(input_arrays)
204        assert len(output_arrays) <= len(self._output_ids)
205        output_tensors = [
206            None if a is None else Tensor(a, dims=d)
207            for a, d in zip(output_arrays, self._output_axes)
208        ]
209        return Sample(
210            members={
211                tid: out
212                for tid, out in zip(
213                    self._output_ids,
214                    output_tensors,
215                )
216                if out is not None
217            },
218            stat=input_sample.stat,
219            id=(
220                input_sample.id
221                if isinstance(input_sample, Sample)
222                else input_sample.sample_id
223            ),
224        )
225
226    @abstractmethod
227    def _forward_impl(
228        self, input_arrays: Sequence[Optional[NDArray[Any]]]
229    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
230        """framework specific forward implementation"""
231
232    @abstractmethod
233    def unload(self):
234        """
235        Unload model from any devices, freeing their memory.
236        The moder adapter should be considered unusable afterwards.
237        """
238
239    def _get_input_args_numpy(self, input_sample: Sample):
240        """helper to extract tensor args as transposed numpy arrays"""

Represents model without any preprocessing or postprocessing.

from bioimageio.core import load_description

model = load_description(...)

# option 1:
adapter = ModelAdapter.create(model)
adapter.forward(...)
adapter.unload()

# option 2:
with ModelAdapter.create(model) as adapter:
    adapter.forward(...)
@final
@classmethod
def create( cls, model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
 72    @final
 73    @classmethod
 74    def create(
 75        cls,
 76        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 77        *,
 78        devices: Optional[Sequence[str]] = None,
 79        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
 80    ):
 81        """
 82        Creates model adapter based on the passed spec
 83        Note: All specific adapters should happen inside this function to prevent different framework
 84        initializations interfering with each other
 85        """
 86        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 87            raise TypeError(
 88                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 89            )
 90
 91        weights = model_description.weights
 92        errors: List[Exception] = []
 93        weight_format_priority_order = (
 94            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 95            if weight_format_priority_order is None
 96            else weight_format_priority_order
 97        )
 98        # limit weight formats to the ones present
 99        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
100            w for w in weight_format_priority_order if getattr(weights, w) is not None
101        ]
102        if not weight_format_priority_order_present:
103            raise ValueError(
104                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
105            )
106
107        for wf in weight_format_priority_order_present:
108            if wf == "pytorch_state_dict":
109                assert weights.pytorch_state_dict is not None
110                try:
111                    from .pytorch_backend import PytorchModelAdapter
112
113                    return PytorchModelAdapter(
114                        model_description=model_description, devices=devices
115                    )
116                except Exception as e:
117                    errors.append(e)
118            elif wf == "tensorflow_saved_model_bundle":
119                assert weights.tensorflow_saved_model_bundle is not None
120                try:
121                    from .tensorflow_backend import create_tf_model_adapter
122
123                    return create_tf_model_adapter(
124                        model_description=model_description, devices=devices
125                    )
126                except Exception as e:
127                    errors.append(e)
128            elif wf == "onnx":
129                assert weights.onnx is not None
130                try:
131                    from .onnx_backend import ONNXModelAdapter
132
133                    return ONNXModelAdapter(
134                        model_description=model_description, devices=devices
135                    )
136                except Exception as e:
137                    errors.append(e)
138            elif wf == "torchscript":
139                assert weights.torchscript is not None
140                try:
141                    from .torchscript_backend import TorchscriptModelAdapter
142
143                    return TorchscriptModelAdapter(
144                        model_description=model_description, devices=devices
145                    )
146                except Exception as e:
147                    errors.append(e)
148            elif wf == "keras_hdf5":
149                assert weights.keras_hdf5 is not None
150                # keras can either be installed as a separate package or used as part of tensorflow
151                # we try to first import the keras model adapter using the separate package and,
152                # if it is not available, try to load the one using tf
153                try:
154                    try:
155                        from .keras_backend import KerasModelAdapter
156                    except Exception:
157                        from .tensorflow_backend import KerasModelAdapter
158
159                    return KerasModelAdapter(
160                        model_description=model_description, devices=devices
161                    )
162                except Exception as e:
163                    errors.append(e)
164            else:
165                assert_never(wf)
166
167        assert errors
168        if len(weight_format_priority_order) == 1:
169            assert len(errors) == 1
170            raise errors[0]
171
172        else:
173            msg = (
174                "None of the weight format specific model adapters could be created"
175                + " in this environment."
176            )
177            raise ExceptionGroup(msg, errors)

Creates a model adapter based on the passed spec. Note: All backend-specific adapter imports should happen inside this function to prevent different framework initializations from interfering with each other.

@final
def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
179    @final
180    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
181        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
183    def forward(
184        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
185    ) -> Sample:
186        """
187        Run forward pass of model to get model predictions
188
189        Note: sample id and stample stat attributes are passed through
190        """
191        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
192        if unexpected:
193            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")
194
195        input_arrays = [
196            (
197                None
198                if (a := input_sample.members.get(in_id)) is None
199                else a.transpose(in_order).data.data
200            )
201            for in_id, in_order in zip(self._input_ids, self._input_axes)
202        ]
203        output_arrays = self._forward_impl(input_arrays)
204        assert len(output_arrays) <= len(self._output_ids)
205        output_tensors = [
206            None if a is None else Tensor(a, dims=d)
207            for a, d in zip(output_arrays, self._output_axes)
208        ]
209        return Sample(
210            members={
211                tid: out
212                for tid, out in zip(
213                    self._output_ids,
214                    output_tensors,
215                )
216                if out is not None
217            },
218            stat=input_sample.stat,
219            id=(
220                input_sample.id
221                if isinstance(input_sample, Sample)
222                else input_sample.sample_id
223            ),
224        )

Run forward pass of model to get model predictions

Note: sample id and sample stat attributes are passed through

@abstractmethod
def unload(self):
    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.

        Abstract: every concrete adapter has to provide its own implementation.
        """

Unload model from any devices, freeing their memory. The model adapter should be considered unusable afterwards.

@final
@classmethod
def create_model_adapter( model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
 72    @final
 73    @classmethod
 74    def create(
 75        cls,
 76        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 77        *,
 78        devices: Optional[Sequence[str]] = None,
 79        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
 80    ):
 81        """
 82        Creates model adapter based on the passed spec
 83        Note: All specific adapters should happen inside this function to prevent different framework
 84        initializations interfering with each other
 85        """
 86        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 87            raise TypeError(
 88                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 89            )
 90
 91        weights = model_description.weights
 92        errors: List[Exception] = []
 93        weight_format_priority_order = (
 94            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 95            if weight_format_priority_order is None
 96            else weight_format_priority_order
 97        )
 98        # limit weight formats to the ones present
 99        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
100            w for w in weight_format_priority_order if getattr(weights, w) is not None
101        ]
102        if not weight_format_priority_order_present:
103            raise ValueError(
104                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
105            )
106
107        for wf in weight_format_priority_order_present:
108            if wf == "pytorch_state_dict":
109                assert weights.pytorch_state_dict is not None
110                try:
111                    from .pytorch_backend import PytorchModelAdapter
112
113                    return PytorchModelAdapter(
114                        model_description=model_description, devices=devices
115                    )
116                except Exception as e:
117                    errors.append(e)
118            elif wf == "tensorflow_saved_model_bundle":
119                assert weights.tensorflow_saved_model_bundle is not None
120                try:
121                    from .tensorflow_backend import create_tf_model_adapter
122
123                    return create_tf_model_adapter(
124                        model_description=model_description, devices=devices
125                    )
126                except Exception as e:
127                    errors.append(e)
128            elif wf == "onnx":
129                assert weights.onnx is not None
130                try:
131                    from .onnx_backend import ONNXModelAdapter
132
133                    return ONNXModelAdapter(
134                        model_description=model_description, devices=devices
135                    )
136                except Exception as e:
137                    errors.append(e)
138            elif wf == "torchscript":
139                assert weights.torchscript is not None
140                try:
141                    from .torchscript_backend import TorchscriptModelAdapter
142
143                    return TorchscriptModelAdapter(
144                        model_description=model_description, devices=devices
145                    )
146                except Exception as e:
147                    errors.append(e)
148            elif wf == "keras_hdf5":
149                assert weights.keras_hdf5 is not None
150                # keras can either be installed as a separate package or used as part of tensorflow
151                # we try to first import the keras model adapter using the separate package and,
152                # if it is not available, try to load the one using tf
153                try:
154                    try:
155                        from .keras_backend import KerasModelAdapter
156                    except Exception:
157                        from .tensorflow_backend import KerasModelAdapter
158
159                    return KerasModelAdapter(
160                        model_description=model_description, devices=devices
161                    )
162                except Exception as e:
163                    errors.append(e)
164            else:
165                assert_never(wf)
166
167        assert errors
168        if len(weight_format_priority_order) == 1:
169            assert len(errors) == 1
170            raise errors[0]
171
172        else:
173            msg = (
174                "None of the weight format specific model adapters could be created"
175                + " in this environment."
176            )
177            raise ExceptionGroup(msg, errors)

Creates a model adapter based on the passed spec. Note: All backend-specific adapter imports should happen inside this function to prevent different framework initializations from interfering with each other.

def get_weight_formats() -> List[str]:
def get_weight_formats() -> List[str]:
    """
    Return list of supported weight types

    The result is a new list holding the entries of
    `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER` in their original order.
    """
    return [*DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER]

Return list of supported weight types