Coverage for src/bioimageio/core/backends/__init__.py: 59%
68 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from typing import (
2 List,
3 Optional,
4 Sequence,
5 Tuple,
6 Union,
7)
9from exceptiongroup import ExceptionGroup
10from typing_extensions import assert_never
12from bioimageio.spec.model import v0_4, v0_5
14from ..common import SupportedWeightsFormat
# Weight formats tried by `create_model_adapter` when the caller gives no
# explicit order. Highest priority comes first: the earliest format here that
# the model provides (and whose backend can be set up) is the one used.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",  # .pytorch_backend
    "tensorflow_saved_model_bundle",  # .tensorflow_backend
    "torchscript",  # .torchscript_backend
    "onnx",  # .onnx_backend
    "keras_v3",  # .keras_backend
    "keras_hdf5",  # .keras_backend, falling back to .tensorflow_backend
)
def create_model_adapter(
    model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Create a model adapter for `model_description`.

    The requested weight formats are filtered down to those present in
    `model_description.weights`. They are then tried in priority order, and the
    adapter of the first format whose backend imports and instantiates without
    raising is returned.

    Args:
        model_description: The v0.4 or v0.5 model description to wrap.
        devices: Passed unchanged to the chosen backend's adapter.
        weight_format_priority_order: Weight formats to try, highest priority
            first. Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.
            Repeated entries are tried only once.

    Returns:
        The adapter created for the first weight format that succeeded.

    Raises:
        TypeError: If `model_description` is not a v0.4 or v0.5 `ModelDescr`.
        ValueError: If none of the requested weight formats is present in the
            model description.
        Exception: If exactly one weight format was requested and its adapter
            could not be created, that exception is re-raised unchanged.
        ExceptionGroup: If several weight formats were requested and no adapter
            could be created; it bundles the per-format exceptions.
    """
    if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise TypeError(
            f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
        )

    weights = model_description.weights
    errors: List[Exception] = []
    # dict.fromkeys drops duplicates but keeps first-occurrence order, so a
    # repeated format is attempted (and reported) only once.
    weight_format_priority_order = tuple(
        dict.fromkeys(
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
    )
    # limit weight formats to the ones present in the model description
    weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
        w for w in weight_format_priority_order if getattr(weights, w, None) is not None
    ]
    if not weight_format_priority_order_present:
        # Tell the caller which supported formats the model does provide, so
        # they can choose a valid priority order.
        available = [
            w
            for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if getattr(weights, w, None) is not None
        ]
        raise ValueError(
            f"None of the specified weight formats ({list(weight_format_priority_order)})"
            + f" is present (supported weight formats present: {available})"
        )

    for wf in weight_format_priority_order_present:
        # Backends are imported lazily inside each branch, so a missing optional
        # dependency only disqualifies its own format; the error is collected
        # and the next format is tried.
        if wf == "pytorch_state_dict":
            # guaranteed by the presence filter above; narrows the type
            assert weights.pytorch_state_dict is not None
            try:
                from .pytorch_backend import PytorchModelAdapter

                return PytorchModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "tensorflow_saved_model_bundle":
            assert weights.tensorflow_saved_model_bundle is not None
            try:
                from .tensorflow_backend import create_tf_model_adapter

                return create_tf_model_adapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "onnx":
            assert weights.onnx is not None
            try:
                from .onnx_backend import ONNXModelAdapter

                return ONNXModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "torchscript":
            assert weights.torchscript is not None
            try:
                from .torchscript_backend import TorchscriptModelAdapter

                return TorchscriptModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "keras_hdf5":
            assert weights.keras_hdf5 is not None
            # keras can either be installed as a separate package or used as part of tensorflow
            # we try to first import the keras model adapter using the separate package and,
            # if it is not available, try to load the one using tf
            try:
                try:
                    from .keras_backend import KerasModelAdapter
                except Exception:
                    from .tensorflow_backend import KerasModelAdapter

                return KerasModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "keras_v3":
            # presumably v0.4 weights never expose `keras_v3` (so the presence
            # filter excludes it); this assert narrows the type for checkers
            assert not isinstance(weights, v0_4.WeightsDescr), (
                "keras_v3 weights not supported for v0.4 specs"
            )
            assert weights.keras_v3 is not None
            try:
                from .keras_backend import KerasModelAdapter

                return KerasModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        else:
            # Exhaustiveness check: an unhandled format raises here rather than
            # being silently skipped.
            assert_never(wf)

    # Every present format was tried and each one recorded an error.
    assert errors
    if len(weight_format_priority_order) == 1:
        # A single requested format: surface its exception directly.
        assert len(errors) == 1
        raise errors[0]

    msg = (
        "None of the weight format specific model adapters could be created"
        + " in this environment."
    )
    raise ExceptionGroup(msg, errors)