bioimageio.core.model_adapters
DEPRECATED
"""DEPRECATED"""

from typing import List

from .backends._model_adapter import (
    DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
    ModelAdapter,
    create_model_adapter,
)

__all__ = [
    "ModelAdapter",
    "create_model_adapter",
    "get_weight_formats",
]


def get_weight_formats() -> List[str]:
    """Return the supported weight formats as a new list.

    The entries are copied from `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`, so
    mutating the returned list does not affect that constant.
    """
    return [*DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER]
class
ModelAdapter(abc.ABC):
class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2 (`unload()` is called automatically when the block exits):
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    def __init__(self, model_description: AnyModelDescr):
        """Cache tensor ids, axis orders and input optionality of the model.

        Args:
            model_description: The model description this adapter wraps.
        """
        super().__init__()
        self._model_descr = model_description
        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)
        # one tuple of axis ids per input/output tensor, in description order
        self._input_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.inputs
        ]
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(t)) for t in model_description.outputs
        ]
        if isinstance(model_description, v0_4.ModelDescr):
            # v0_4 descriptions are handled as having only required inputs
            self._input_is_optional = [False] * len(model_description.inputs)
        else:
            self._input_is_optional = [ipt.optional for ipt in model_description.inputs]

    def __enter__(self) -> "ModelAdapter":
        """Enter a `with` block; the adapter itself is the managed object."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Unload the model when the `with` block exits.

        Returns None, so an exception raised inside the block propagates.
        """
        self.unload()

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates model adapter based on the passed spec
        Note: All framework specific adapter imports should happen inside this
        function to prevent different framework initializations interfering
        with each other

        Args:
            model_description: Description of the model to create an adapter for.
            devices: Passed through unchanged to the weight format specific adapter.
            weight_format_priority_order: Weight formats to try, in order.
                Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.

        Returns:
            The adapter of the first weight format that could be initialized.

        Raises:
            TypeError: If `model_description` is not a v0_4/v0_5 `ModelDescr`.
            ValueError: If none of the requested weight formats is present.
            Exception: The original error, if exactly one weight format was
                requested and its adapter could not be created.
            ExceptionGroup: All collected errors otherwise.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w for w in weight_format_priority_order if getattr(weights, w) is not None
        ]
        if not weight_format_priority_order_present:
            # Report what the model actually provides; the filtered list above
            # is necessarily empty here and carries no information.
            available = [
                w
                for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
                if getattr(weights, w, None) is not None
            ]
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present (available: {available})"
            )

        # Try each present format in priority order; a failing backend only
        # records its error so that the next format still gets a chance.
        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                assert_never(wf)

        # Reaching this point means every attempted backend failed.
        assert errors
        if len(weight_format_priority_order) == 1:
            # a single requested format: surface its error unwrapped
            assert len(errors) == 1
            raise errors[0]

        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            raise ExceptionGroup(msg, errors)

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op that only emits a warning; `devices` is ignored."""
        warnings.warn("Deprecated. ModelAdapter is loaded on initialization")

    def forward(
        self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> Sample:
        """
        Run forward pass of model to get model predictions

        Note: sample id and sample stat attributes are passed through

        Args:
            input_sample: Sample (or sample block) holding the input tensors.
                Members with ids unknown to the model trigger a warning and
                are otherwise ignored.

        Returns:
            A sample with one member per output array that is not None.
        """
        unexpected = [mid for mid in input_sample.members if mid not in self._input_ids]
        if unexpected:
            warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

        # One entry per model input, in description order. Members missing from
        # the sample are passed on as None; present ones are transposed to the
        # axis order of the model description.
        input_arrays = [
            (
                None
                if (a := input_sample.members.get(in_id)) is None
                else a.transpose(in_order).data.data
            )
            for in_id, in_order in zip(self._input_ids, self._input_axes)
        ]
        output_arrays = self._forward_impl(input_arrays)
        assert len(output_arrays) <= len(self._output_ids)
        output_tensors = [
            None if a is None else Tensor(a, dims=d)
            for a, d in zip(output_arrays, self._output_axes)
        ]
        return Sample(
            members={
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    output_tensors,
                )
                if out is not None
            },
            stat=input_sample.stat,
            # sample blocks expose the id of their sample as `sample_id`
            id=(
                input_sample.id
                if isinstance(input_sample, Sample)
                else input_sample.sample_id
            ),
        )

    @abstractmethod
    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> Union[List[Optional[NDArray[Any]]], Tuple[Optional[NDArray[Any]]]]:
        """framework specific forward implementation"""

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

    def _get_input_args_numpy(self, input_sample: Sample):
        """helper to extract tensor args as transposed numpy arrays"""
Represents model without any preprocessing or postprocessing.
from bioimageio.core import load_description
model = load_description(...)
# option 1:
adapter = ModelAdapter.create(model)
adapter.forward(...)
adapter.unload()
# option 2:
with ModelAdapter.create(model) as adapter:
adapter.forward(...)
@final
@classmethod
def
create( cls, model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
@final
@classmethod
def create(
    cls,
    model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """
    Creates model adapter based on the passed spec
    Note: All framework specific adapter imports should happen inside this
    function to prevent different framework initializations interfering
    with each other

    Args:
        model_description: Description of the model to create an adapter for.
        devices: Passed through unchanged to the weight format specific adapter.
        weight_format_priority_order: Weight formats to try, in order.
            Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.

    Returns:
        The adapter of the first weight format that could be initialized.

    Raises:
        TypeError: If `model_description` is not a v0_4/v0_5 `ModelDescr`.
        ValueError: If none of the requested weight formats is present.
        Exception: The original error, if exactly one weight format was
            requested and its adapter could not be created.
        ExceptionGroup: All collected errors otherwise.
    """
    if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise TypeError(
            f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
        )

    weights = model_description.weights
    errors: List[Exception] = []
    weight_format_priority_order = (
        DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
        if weight_format_priority_order is None
        else weight_format_priority_order
    )
    # limit weight formats to the ones present
    weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
        w for w in weight_format_priority_order if getattr(weights, w) is not None
    ]
    if not weight_format_priority_order_present:
        # Report what the model actually provides; the filtered list above
        # is necessarily empty here and carries no information.
        available = [
            w
            for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if getattr(weights, w, None) is not None
        ]
        raise ValueError(
            f"None of the specified weight formats ({weight_format_priority_order}) is present (available: {available})"
        )

    # Try each present format in priority order; a failing backend only
    # records its error so that the next format still gets a chance.
    for wf in weight_format_priority_order_present:
        if wf == "pytorch_state_dict":
            assert weights.pytorch_state_dict is not None
            try:
                from .pytorch_backend import PytorchModelAdapter

                return PytorchModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "tensorflow_saved_model_bundle":
            assert weights.tensorflow_saved_model_bundle is not None
            try:
                from .tensorflow_backend import create_tf_model_adapter

                return create_tf_model_adapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "onnx":
            assert weights.onnx is not None
            try:
                from .onnx_backend import ONNXModelAdapter

                return ONNXModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "torchscript":
            assert weights.torchscript is not None
            try:
                from .torchscript_backend import TorchscriptModelAdapter

                return TorchscriptModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "keras_hdf5":
            assert weights.keras_hdf5 is not None
            # keras can either be installed as a separate package or used as part of tensorflow
            # we try to first import the keras model adapter using the separate package and,
            # if it is not available, try to load the one using tf
            try:
                try:
                    from .keras_backend import KerasModelAdapter
                except Exception:
                    from .tensorflow_backend import KerasModelAdapter

                return KerasModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        else:
            assert_never(wf)

    # Reaching this point means every attempted backend failed.
    assert errors
    if len(weight_format_priority_order) == 1:
        # a single requested format: surface its error unwrapped
        assert len(errors) == 1
        raise errors[0]

    else:
        msg = (
            "None of the weight format specific model adapters could be created"
            + " in this environment."
        )
        raise ExceptionGroup(msg, errors)
Creates a model adapter based on the passed spec. Note: All framework-specific adapter imports should happen inside this function to prevent different framework initializations from interfering with each other.
def
forward( self, input_sample: Union[bioimageio.core.Sample, bioimageio.core.sample.SampleBlock, bioimageio.core.sample.SampleBlockWithOrigin]) -> bioimageio.core.Sample:
def forward(
    self, input_sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
) -> Sample:
    """
    Run the model on `input_sample` and return its predictions as a sample.

    Members of `input_sample` whose ids the model does not know trigger a
    warning and are ignored. Model inputs missing from the sample are handed
    to the framework specific implementation as None. Outputs that are None
    are left out of the result.

    Note: sample id and sample stat attributes are passed through
    """
    known_input_ids = self._input_ids
    unexpected = [mid for mid in input_sample.members if mid not in known_input_ids]
    if unexpected:
        warnings.warn(f"Got unexpected input tensor IDs: {unexpected}")

    # one entry per model input, transposed to the described axis order
    input_arrays = []
    for in_id, in_order in zip(known_input_ids, self._input_axes):
        member = input_sample.members.get(in_id)
        if member is None:
            input_arrays.append(None)
        else:
            input_arrays.append(member.transpose(in_order).data.data)

    output_arrays = self._forward_impl(input_arrays)
    assert len(output_arrays) <= len(self._output_ids)

    # wrap every produced array with the axis ids of its output description
    output_tensors = []
    for array, dims in zip(output_arrays, self._output_axes):
        output_tensors.append(None if array is None else Tensor(array, dims=dims))

    members = {}
    for tid, out in zip(self._output_ids, output_tensors):
        if out is not None:
            members[tid] = out

    stat = input_sample.stat
    # sample blocks expose the id of their sample as `sample_id`
    if isinstance(input_sample, Sample):
        sample_id = input_sample.id
    else:
        sample_id = input_sample.sample_id

    return Sample(members=members, stat=stat, id=sample_id)
Run forward pass of model to get model predictions
Note: sample id and sample stat attributes are passed through
@abstractmethod
def
unload(self):
@abstractmethod
def unload(self):
    """Release the model from every device it occupies, freeing their memory.

    Once this has been called the model adapter must be treated as unusable.
    """
Unload model from any devices, freeing their memory. The model adapter should be considered unusable afterwards.
@final
@classmethod
def
create_model_adapter( model_description: Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: Optional[Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_saved_model_bundle', 'torchscript']]] = None):
@final
@classmethod
def create(
    cls,
    model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """
    Creates model adapter based on the passed spec
    Note: All framework specific adapter imports should happen inside this
    function to prevent different framework initializations interfering
    with each other

    Args:
        model_description: Description of the model to create an adapter for.
        devices: Passed through unchanged to the weight format specific adapter.
        weight_format_priority_order: Weight formats to try, in order.
            Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.

    Returns:
        The adapter of the first weight format that could be initialized.

    Raises:
        TypeError: If `model_description` is not a v0_4/v0_5 `ModelDescr`.
        ValueError: If none of the requested weight formats is present.
        Exception: The original error, if exactly one weight format was
            requested and its adapter could not be created.
        ExceptionGroup: All collected errors otherwise.
    """
    if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise TypeError(
            f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
        )

    weights = model_description.weights
    errors: List[Exception] = []
    weight_format_priority_order = (
        DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
        if weight_format_priority_order is None
        else weight_format_priority_order
    )
    # limit weight formats to the ones present
    weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
        w for w in weight_format_priority_order if getattr(weights, w) is not None
    ]
    if not weight_format_priority_order_present:
        # Report what the model actually provides; the filtered list above
        # is necessarily empty here and carries no information.
        available = [
            w
            for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if getattr(weights, w, None) is not None
        ]
        raise ValueError(
            f"None of the specified weight formats ({weight_format_priority_order}) is present (available: {available})"
        )

    # Try each present format in priority order; a failing backend only
    # records its error so that the next format still gets a chance.
    for wf in weight_format_priority_order_present:
        if wf == "pytorch_state_dict":
            assert weights.pytorch_state_dict is not None
            try:
                from .pytorch_backend import PytorchModelAdapter

                return PytorchModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "tensorflow_saved_model_bundle":
            assert weights.tensorflow_saved_model_bundle is not None
            try:
                from .tensorflow_backend import create_tf_model_adapter

                return create_tf_model_adapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "onnx":
            assert weights.onnx is not None
            try:
                from .onnx_backend import ONNXModelAdapter

                return ONNXModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "torchscript":
            assert weights.torchscript is not None
            try:
                from .torchscript_backend import TorchscriptModelAdapter

                return TorchscriptModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        elif wf == "keras_hdf5":
            assert weights.keras_hdf5 is not None
            # keras can either be installed as a separate package or used as part of tensorflow
            # we try to first import the keras model adapter using the separate package and,
            # if it is not available, try to load the one using tf
            try:
                try:
                    from .keras_backend import KerasModelAdapter
                except Exception:
                    from .tensorflow_backend import KerasModelAdapter

                return KerasModelAdapter(
                    model_description=model_description, devices=devices
                )
            except Exception as e:
                errors.append(e)
        else:
            assert_never(wf)

    # Reaching this point means every attempted backend failed.
    assert errors
    if len(weight_format_priority_order) == 1:
        # a single requested format: surface its error unwrapped
        assert len(errors) == 1
        raise errors[0]

    else:
        msg = (
            "None of the weight format specific model adapters could be created"
            + " in this environment."
        )
        raise ExceptionGroup(msg, errors)
Creates a model adapter based on the passed spec. Note: All framework-specific adapter imports should happen inside this function to prevent different framework initializations from interfering with each other.
def
get_weight_formats() -> List[str]:
def get_weight_formats() -> List[str]:
    """Return the supported weight formats as a new list.

    The entries are copied from `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`, so
    mutating the returned list does not affect that constant.
    """
    return [*DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER]
Return list of supported weight types