bioimageio.core.model_adapters

1from ._model_adapter import ModelAdapter, create_model_adapter, get_weight_formats
2
3__all__ = [
4    "ModelAdapter",
5    "create_model_adapter",
6    "get_weight_formats",
7]
class ModelAdapter(abc.ABC):
 23class ModelAdapter(ABC):
 24    """
 25    Represents model *without* any preprocessing or postprocessing.
 26
 27    ```
 28    from bioimageio.core import load_description
 29
 30    model = load_description(...)
 31
 32    # option 1:
 33    adapter = ModelAdapter.create(model)
 34    adapter.forward(...)
 35    adapter.unload()
 36
 37    # option 2:
 38    with ModelAdapter.create(model) as adapter:
 39        adapter.forward(...)
 40    ```
 41    """
 42
 43    @final
 44    @classmethod
 45    def create(
 46        cls,
 47        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 48        *,
 49        devices: Optional[Sequence[str]] = None,
 50        weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
 51    ):
 52        """
 53        Create a model adapter for `model_description`, trying weight formats
 54        in `weight_format_priority_order` (a default order if `None`) and
 55        returning the first adapter that can be instantiated.

           Note: all weight-format-specific adapter imports happen inside this
           function to prevent different framework initializations from
           interfering with each other.

           Raises `TypeError` for an unsupported model description type and
           `ValueError` if no adapter could be created in this environment.
 56        """
 57        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 58            raise TypeError(
 59                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 60            )
 61
 62        weights = model_description.weights
 63        errors: List[Tuple[WeightsFormat, Exception]] = []
 64        weight_format_priority_order = (
 65            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 66            if weight_format_priority_order is None
 67            else weight_format_priority_order
 68        )
 69        # limit weight formats to the ones present
 70        weight_format_priority_order = [
 71            w for w in weight_format_priority_order if getattr(weights, w) is not None
 72        ]
 73
 74        for wf in weight_format_priority_order:
 75            if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
 76                try:
 77                    from ._pytorch_model_adapter import PytorchModelAdapter
 78
 79                    return PytorchModelAdapter(
 80                        outputs=model_description.outputs,
 81                        weights=weights.pytorch_state_dict,
 82                        devices=devices,
 83                    )
 84                except Exception as e:
 85                    errors.append((wf, e))
 86            elif (
 87                wf == "tensorflow_saved_model_bundle"
 88                and weights.tensorflow_saved_model_bundle is not None
 89            ):
 90                try:
 91                    from ._tensorflow_model_adapter import TensorflowModelAdapter
 92
 93                    return TensorflowModelAdapter(
 94                        model_description=model_description, devices=devices
 95                    )
 96                except Exception as e:
 97                    errors.append((wf, e))
 98            elif wf == "onnx" and weights.onnx is not None:
 99                try:
100                    from ._onnx_model_adapter import ONNXModelAdapter
101
102                    return ONNXModelAdapter(
103                        model_description=model_description, devices=devices
104                    )
105                except Exception as e:
106                    errors.append((wf, e))
107            elif wf == "torchscript" and weights.torchscript is not None:
108                try:
109                    from ._torchscript_model_adapter import TorchscriptModelAdapter
110
111                    return TorchscriptModelAdapter(
112                        model_description=model_description, devices=devices
113                    )
114                except Exception as e:
115                    errors.append((wf, e))
116            elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
117                # keras can either be installed as a separate package or used as part of tensorflow
118                # we try to first import the keras model adapter using the separate package and,
119                # if it is not available, try to load the one using tf
120                try:
121                    from ._keras_model_adapter import (
122                        KerasModelAdapter,
123                        keras,  # type: ignore
124                    )
125
126                    if keras is None:
127                        from ._tensorflow_model_adapter import KerasModelAdapter
128
129                    return KerasModelAdapter(
130                        model_description=model_description, devices=devices
131                    )
132                except Exception as e:
133                    errors.append((wf, e))
134
135        assert errors
136        if len(weight_format_priority_order) == 1:
137            assert len(errors) == 1
138            raise ValueError(
139                f"The '{weight_format_priority_order[0]}' model adapter could not be created"
140                + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
141            )
142
143        else:
144            error_list = "\n - ".join(
145                f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
146            )
147            raise ValueError(
148                "None of the weight format specific model adapters could be created"
149                + f" in this environment. Errors are:\n\n{error_list}.\n\n"
150            )
151
152    @final
153    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
           """Deprecated no-op: the model is loaded on initialization."""
154        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
155
156    @abstractmethod
157    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
158        """
159        Run forward pass of model to get model predictions.

           Input and output entries may be `None` (see the `Optional[Tensor]`
           annotations).
160        """
161        # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl
162
163    @abstractmethod
164    def unload(self):
165        """
166        Unload model from any devices, freeing their memory.
167        The model adapter should be considered unusable afterwards.
168        """

Represents model without any preprocessing or postprocessing.

from bioimageio.core import load_description

model = load_description(...)

# option 1:
adapter = ModelAdapter.create(model)
adapter.forward(...)
adapter.unload()

# option 2:
with ModelAdapter.create(model) as adapter:
    adapter.forward(...)
@final
@classmethod
def create( cls, model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
 43    @final
 44    @classmethod
 45    def create(
 46        cls,
 47        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 48        *,
 49        devices: Optional[Sequence[str]] = None,
 50        weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
 51    ):
 52        """
 53        Create a model adapter for `model_description`, trying weight formats
 54        in `weight_format_priority_order` (a default order if `None`) and
 55        returning the first adapter that can be instantiated.

           Note: all weight-format-specific adapter imports happen inside this
           function to prevent different framework initializations from
           interfering with each other.

           Raises `TypeError` for an unsupported model description type and
           `ValueError` if no adapter could be created in this environment.
 56        """
 57        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 58            raise TypeError(
 59                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 60            )
 61
 62        weights = model_description.weights
 63        errors: List[Tuple[WeightsFormat, Exception]] = []
 64        weight_format_priority_order = (
 65            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 66            if weight_format_priority_order is None
 67            else weight_format_priority_order
 68        )
 69        # limit weight formats to the ones present
 70        weight_format_priority_order = [
 71            w for w in weight_format_priority_order if getattr(weights, w) is not None
 72        ]
 73
 74        for wf in weight_format_priority_order:
 75            if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
 76                try:
 77                    from ._pytorch_model_adapter import PytorchModelAdapter
 78
 79                    return PytorchModelAdapter(
 80                        outputs=model_description.outputs,
 81                        weights=weights.pytorch_state_dict,
 82                        devices=devices,
 83                    )
 84                except Exception as e:
 85                    errors.append((wf, e))
 86            elif (
 87                wf == "tensorflow_saved_model_bundle"
 88                and weights.tensorflow_saved_model_bundle is not None
 89            ):
 90                try:
 91                    from ._tensorflow_model_adapter import TensorflowModelAdapter
 92
 93                    return TensorflowModelAdapter(
 94                        model_description=model_description, devices=devices
 95                    )
 96                except Exception as e:
 97                    errors.append((wf, e))
 98            elif wf == "onnx" and weights.onnx is not None:
 99                try:
100                    from ._onnx_model_adapter import ONNXModelAdapter
101
102                    return ONNXModelAdapter(
103                        model_description=model_description, devices=devices
104                    )
105                except Exception as e:
106                    errors.append((wf, e))
107            elif wf == "torchscript" and weights.torchscript is not None:
108                try:
109                    from ._torchscript_model_adapter import TorchscriptModelAdapter
110
111                    return TorchscriptModelAdapter(
112                        model_description=model_description, devices=devices
113                    )
114                except Exception as e:
115                    errors.append((wf, e))
116            elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
117                # keras can either be installed as a separate package or used as part of tensorflow
118                # we try to first import the keras model adapter using the separate package and,
119                # if it is not available, try to load the one using tf
120                try:
121                    from ._keras_model_adapter import (
122                        KerasModelAdapter,
123                        keras,  # type: ignore
124                    )
125
126                    if keras is None:
127                        from ._tensorflow_model_adapter import KerasModelAdapter
128
129                    return KerasModelAdapter(
130                        model_description=model_description, devices=devices
131                    )
132                except Exception as e:
133                    errors.append((wf, e))
134
135        assert errors
136        if len(weight_format_priority_order) == 1:
137            assert len(errors) == 1
138            raise ValueError(
139                f"The '{weight_format_priority_order[0]}' model adapter could not be created"
140                + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
141            )
142
143        else:
144            error_list = "\n - ".join(
145                f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
146            )
147            raise ValueError(
148                "None of the weight format specific model adapters could be created"
149                + f" in this environment. Errors are:\n\n{error_list}.\n\n"
150            )

Creates a model adapter based on the passed spec. Note: all weight-format-specific adapter imports happen inside this function to prevent different framework initializations from interfering with each other.

@final
def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
152    @final
153    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
           """Deprecated no-op: the model is loaded on initialization."""
154        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
@abstractmethod
def forward( self, *input_tensors: Optional[bioimageio.core.Tensor]) -> List[Optional[bioimageio.core.Tensor]]:
156    @abstractmethod
157    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
158        """
159        Run forward pass of model to get model predictions.

           Input and output entries may be `None` (see the `Optional[Tensor]`
           annotations).
160        """
161        # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl

Run forward pass of model to get model predictions

@abstractmethod
def unload(self):
163    @abstractmethod
164    def unload(self):
165        """
166        Unload model from any devices, freeing their memory.
167        The model adapter should be considered unusable afterwards.
168        """

Unload model from any devices, freeing their memory. The model adapter should be considered unusable afterwards.

@final
@classmethod
def create_model_adapter( model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
 43    @final
 44    @classmethod
 45    def create(
 46        cls,
 47        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
 48        *,
 49        devices: Optional[Sequence[str]] = None,
 50        weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
 51    ):
 52        """
 53        Create a model adapter for `model_description`, trying weight formats
 54        in `weight_format_priority_order` (a default order if `None`) and
 55        returning the first adapter that can be instantiated.

           Note: all weight-format-specific adapter imports happen inside this
           function to prevent different framework initializations from
           interfering with each other.

           Raises `TypeError` for an unsupported model description type and
           `ValueError` if no adapter could be created in this environment.
 56        """
 57        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
 58            raise TypeError(
 59                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
 60            )
 61
 62        weights = model_description.weights
 63        errors: List[Tuple[WeightsFormat, Exception]] = []
 64        weight_format_priority_order = (
 65            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
 66            if weight_format_priority_order is None
 67            else weight_format_priority_order
 68        )
 69        # limit weight formats to the ones present
 70        weight_format_priority_order = [
 71            w for w in weight_format_priority_order if getattr(weights, w) is not None
 72        ]
 73
 74        for wf in weight_format_priority_order:
 75            if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
 76                try:
 77                    from ._pytorch_model_adapter import PytorchModelAdapter
 78
 79                    return PytorchModelAdapter(
 80                        outputs=model_description.outputs,
 81                        weights=weights.pytorch_state_dict,
 82                        devices=devices,
 83                    )
 84                except Exception as e:
 85                    errors.append((wf, e))
 86            elif (
 87                wf == "tensorflow_saved_model_bundle"
 88                and weights.tensorflow_saved_model_bundle is not None
 89            ):
 90                try:
 91                    from ._tensorflow_model_adapter import TensorflowModelAdapter
 92
 93                    return TensorflowModelAdapter(
 94                        model_description=model_description, devices=devices
 95                    )
 96                except Exception as e:
 97                    errors.append((wf, e))
 98            elif wf == "onnx" and weights.onnx is not None:
 99                try:
100                    from ._onnx_model_adapter import ONNXModelAdapter
101
102                    return ONNXModelAdapter(
103                        model_description=model_description, devices=devices
104                    )
105                except Exception as e:
106                    errors.append((wf, e))
107            elif wf == "torchscript" and weights.torchscript is not None:
108                try:
109                    from ._torchscript_model_adapter import TorchscriptModelAdapter
110
111                    return TorchscriptModelAdapter(
112                        model_description=model_description, devices=devices
113                    )
114                except Exception as e:
115                    errors.append((wf, e))
116            elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
117                # keras can either be installed as a separate package or used as part of tensorflow
118                # we try to first import the keras model adapter using the separate package and,
119                # if it is not available, try to load the one using tf
120                try:
121                    from ._keras_model_adapter import (
122                        KerasModelAdapter,
123                        keras,  # type: ignore
124                    )
125
126                    if keras is None:
127                        from ._tensorflow_model_adapter import KerasModelAdapter
128
129                    return KerasModelAdapter(
130                        model_description=model_description, devices=devices
131                    )
132                except Exception as e:
133                    errors.append((wf, e))
134
135        assert errors
136        if len(weight_format_priority_order) == 1:
137            assert len(errors) == 1
138            raise ValueError(
139                f"The '{weight_format_priority_order[0]}' model adapter could not be created"
140                + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
141            )
142
143        else:
144            error_list = "\n - ".join(
145                f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
146            )
147            raise ValueError(
148                "None of the weight format specific model adapters could be created"
149                + f" in this environment. Errors are:\n\n{error_list}.\n\n"
150            )

Creates a model adapter based on the passed spec. Note: all weight-format-specific adapter imports happen inside this function to prevent different framework initializations from interfering with each other.

def get_weight_formats() -> List[str]:
171def get_weight_formats() -> List[str]:
172    """
173    Return the list of supported weight formats, in default priority order.
174    """
175    return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)

Return the list of supported weight formats, in default priority order.