Coverage for bioimageio/core/model_adapters/_model_adapter.py: 58%
66 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import warnings
2from abc import ABC, abstractmethod
3from typing import List, Optional, Sequence, Tuple, Union, final
5from bioimageio.spec.model import v0_4, v0_5
7from ..tensor import Tensor
# A weight format identifier as defined by either supported model spec version
# (v0_4 or v0_5).
WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat]

# Known weight formats in order of priority.
# `ModelAdapter.create` falls back to this order when the caller does not pass
# `weight_format_priority_order`; the first format for which an adapter can be
# created wins.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_hdf5",
)
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
    ):
        """
        Creates a model adapter based on the passed model description.

        The weight formats are tried in priority order and the first adapter
        that can be created is returned.

        Note: All framework specific adapter imports should happen inside this
        function to prevent different framework initializations interfering
        with each other.

        Args:
            model_description: The model description to create an adapter for.
            devices: Passed on to the weight format specific adapter.
            weight_format_priority_order: Weight formats to try, in order.
                Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`. Formats
                without a weights entry in **model_description** are skipped.

        Returns:
            The first weight format specific model adapter that could be created.

        Raises:
            TypeError: If **model_description** is not a supported `ModelDescr`.
            ValueError: If none of the requested weight formats is present in
                **model_description** or if no adapter could be created for
                any of the present ones.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Tuple[WeightsFormat, Exception]] = []
        requested_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order = [
            w for w in requested_order if getattr(weights, w) is not None
        ]
        if not weight_format_priority_order:
            # explicit error instead of an `assert` (which is stripped under -O)
            raise ValueError(
                "None of the requested weight formats"
                + f" ({', '.join(requested_order)}) is present in the model description."
            )

        for wf in weight_format_priority_order:
            # the redundant `is not None` checks narrow the optional weight entries
            if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
                try:
                    from ._pytorch_model_adapter import PytorchModelAdapter

                    return PytorchModelAdapter(
                        outputs=model_description.outputs,
                        weights=weights.pytorch_state_dict,
                        devices=devices,
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif (
                wf == "tensorflow_saved_model_bundle"
                and weights.tensorflow_saved_model_bundle is not None
            ):
                try:
                    from ._tensorflow_model_adapter import TensorflowModelAdapter

                    return TensorflowModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif wf == "onnx" and weights.onnx is not None:
                try:
                    from ._onnx_model_adapter import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif wf == "torchscript" and weights.torchscript is not None:
                try:
                    from ._torchscript_model_adapter import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    from ._keras_model_adapter import (
                        KerasModelAdapter,
                        keras,  # type: ignore
                    )

                    if keras is None:
                        from ._tensorflow_model_adapter import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            else:
                # weight format is present in the model, but no adapter handles it;
                # record this so that every tried format has a reported reason
                errors.append(
                    (
                        wf,
                        NotImplementedError(
                            f"no model adapter implemented for weight format '{wf}'"
                        ),
                    )
                )

        # reaching this point means every candidate format recorded exactly one error
        if len(weight_format_priority_order) == 1:
            raise ValueError(
                f"The '{weight_format_priority_order[0]}' model adapter could not be created"
                + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
            ) from errors[0][1]

        else:
            error_list = "\n - ".join(
                f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
            )
            raise ValueError(
                "None of the weight format specific model adapters could be created"
                + f" in this environment. Errors are:\n\n{error_list}.\n\n"
            )

    def __enter__(self):
        """Enter a `with` block; the adapter itself is the context value."""
        return self

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        """Unload the model when leaving a `with` block; exceptions propagate."""
        self.unload()

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """
        Deprecated no-op that only emits a warning; **devices** is ignored.
        """
        # stacklevel=2 attributes the warning to the caller of `load`
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization", stacklevel=2
        )

    @abstractmethod
    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """
        Run forward pass of model to get model predictions
        """
        # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """
def get_weight_formats() -> List[str]:
    """
    Return the supported weight format names as a new list,
    ordered like `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`
    """
    return [*DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER]
# Module-level alias so callers can use a plain function name
# instead of `ModelAdapter.create`.
create_model_adapter = ModelAdapter.create