Coverage for src/bioimageio/core/backends/tensorflow_backend.py: 0%
99 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from pathlib import Path
2from typing import Any, List, Optional, Sequence, Tuple, Union
4import numpy as np
5import tensorflow as tf
6from loguru import logger
7from numpy.typing import NDArray
9from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
11from .._model_adapter import LocalModelAdapter
12from ..io import ensure_unzipped
class TensorflowModelAdapter(LocalModelAdapter[None, Any]):
    """Adapter for TensorFlow 1 models

    Loads a `tensorflow_saved_model_bundle` into its own `tf.Graph` and
    `tf.Session` and runs inference through the bundle's default serving
    signature. The session is the "model" object handed back to the base class.
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Remember the weight source and delegate to the base initializer.

        Args:
            model_description: model description providing
                `weights.tensorflow_saved_model_bundle`.
            devices: requested devices; ignored with a warning
                (see `_parse_devices`).

        Raises:
            ValueError: if the description has no
                `tensorflow_saved_model_bundle` weights.
        """
        if model_description.weights.tensorflow_saved_model_bundle is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        if isinstance(model_description, v0_4.ModelDescr):
            self._weight_src = (
                model_description.weights.tensorflow_saved_model_bundle.source
            )
        else:
            self._weight_src = model_description.weights.tensorflow_saved_model_bundle

        # State is set up before calling the base initializer.
        # NOTE(review): presumably the base class may call
        # `_init_model_on_device` from its `__init__` -- confirm.
        self._graph = None
        self._io_names: Optional[Tuple[List[str], List[str]]] = None
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
        """Return the single placeholder device `(None,)`.

        Any explicitly requested devices are ignored with a warning.
        """
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        return (None,)

    def _init_model_on_device(self, device: Optional[str]) -> Any:
        """Load the saved model bundle into a fresh graph and session.

        Stores the graph in `self._graph` and the input/output tensor names of
        the default serving signature in `self._io_names`.

        Args:
            device: unused (device management is not implemented).

        Returns:
            The `tf.Session` holding the loaded model.
        """
        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            self._weight_src, Path("bioimageio_unzipped_tf_weights")
        )

        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )

        self._graph = tf.Graph()
        with self._graph.as_default():
            sess = tf.Session(graph=self._graph)  # pyright: ignore[reportUnknownVariableType]
            try:
                # load the model and the signature
                graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType]
                    sess, [tag], str(weight_file)
                )
                signature = (  # pyright: ignore[reportUnknownVariableType]
                    graph_def.signature_def
                )

                # get the tensors into the graph
                # NOTE(review): `_input_ids` / `_output_ids` are not defined in
                # this class; presumably provided by the base class -- confirm.
                in_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].inputs[key].name
                    for key in self._input_ids
                ]
                out_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].outputs[key].name
                    for key in self._output_ids
                ]
            except Exception:
                # Do not leak the session when loading the bundle or looking up
                # the signature fails.
                sess.close()
                raise

            self._io_names = (in_names, out_names)

        return sess  # pyright: ignore[reportUnknownVariableType]

    def _forward_impl(
        self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the session on `input_arrays`.

        Args:
            device: unused.
            model: the session returned by `_init_model_on_device`.
            input_arrays: inputs, paired positionally with the input tensor
                names stored in `self._io_names`.

        Returns:
            A list of fetched outputs ordered like the stored output names.
        """
        assert self._io_names is not None
        assert self._graph is not None

        in_names, out_names = self._io_names
        in_tf_tensors = [self._graph.get_tensor_by_name(name) for name in in_names]
        out_tf_tensors = [self._graph.get_tensor_by_name(name) for name in out_names]

        # run prediction
        res = model.run(
            dict(zip(out_names, out_tf_tensors)),
            dict(zip(in_tf_tensors, input_arrays)),
        )
        # from dict to list of tensors
        res = [res[out] for out in out_names]

        return res

    def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None:
        """Close the session so it releases the resources it owns.

        Args:
            device: unused.
            model: the object returned by `_init_model_on_device`.
        """
        # Guarded lookup: only call `close` if the model object provides it.
        close = getattr(model, "close", None)
        if callable(close):
            close()

    def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
        """Nothing to do after the model has been deleted."""
        return
class KerasModelAdapter(LocalModelAdapter[None, Any]):
    """Adapter that runs a `tensorflow_saved_model_bundle` via `tf.keras.layers.TFSMLayer`."""

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Remember the weight source and delegate to the base initializer.

        Args:
            model_description: model description providing
                `weights.tensorflow_saved_model_bundle`.
            devices: requested devices; ignored with a warning
                (see `_parse_devices`).

        Raises:
            ValueError: if the description has no
                `tensorflow_saved_model_bundle` weights.
        """
        if model_description.weights.tensorflow_saved_model_bundle is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        if isinstance(model_description, v0_4.ModelDescr):
            self._weight_src = (
                model_description.weights.tensorflow_saved_model_bundle.source
            )
        else:
            self._weight_src = model_description.weights.tensorflow_saved_model_bundle

        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
        """Return the single placeholder device `(None,)`.

        Any explicitly requested devices are ignored with a warning.
        """
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        return (None,)

    def _init_model_on_device(self, device: None) -> Any:
        """Wrap the saved model bundle in a `TFSMLayer`.

        Tries the `"serve"` call endpoint first and falls back to
        `"serving_default"`. If both fail, the fallback error is logged and the
        error from the first attempt is raised.

        Args:
            device: unused (device management is not implemented).

        Returns:
            The created `TFSMLayer`.
        """
        # TODO: check how to load tf weights without unzipping
        weight_file = str(
            ensure_unzipped(self._weight_src, Path("bioimageio_unzipped_tf_weights"))
        )

        try:
            tfsm_layer = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportUnknownVariableType]
                weight_file,
                call_endpoint="serve",
            )
        except Exception as e:
            try:
                tfsm_layer = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportUnknownVariableType]
                    weight_file, call_endpoint="serving_default"
                )
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                # surface the error of the primary ("serve") attempt
                raise e

        return tfsm_layer  # pyright: ignore[reportUnknownVariableType]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Call the layer on `input_arrays` and return its outputs as a list.

        Args:
            device: unused.
            model: the layer returned by `_init_model_on_device`.
            input_arrays: inputs passed positionally to the layer; `None`
                entries are passed through unchanged.

        Returns:
            A list with one entry per output. Entries that are neither `None`
            nor `np.ndarray` are converted with their `.numpy()` method.
        """
        tf_tensor = [
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
        ]
        result = model(*tf_tensor)

        # The output structure depends on the call endpoint, so handle it
        # explicitly instead of asserting a dict (asserts vanish under `-O`).
        if isinstance(result, dict):
            # TODO: Use RDF's `outputs[i].id` here
            outputs = list(  # pyright: ignore[reportUnknownVariableType]
                result.values()  # pyright: ignore[reportUnknownArgumentType]
            )
        elif isinstance(result, (list, tuple)):
            outputs = list(result)  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
        else:
            # a single output
            outputs = [result]

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in outputs  # pyright: ignore[reportUnknownVariableType]
        ]

    def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None:
        """Nothing to do before the model is deleted."""
        return

    def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
        """Nothing to do after the model has been deleted."""
        return
def create_tf_model_adapter(
    model_description: AnyModelDescr, devices: Optional[Sequence[str]] = None
):
    """Create the model adapter matching the installed tensorflow version.

    Compares the tensorflow version recorded with the model's
    `tensorflow_saved_model_bundle` weights against the installed one and logs
    a warning if it is missing, newer, or differs in major/minor version.
    These checks never prevent adapter creation.

    Args:
        model_description: model description providing
            `weights.tensorflow_saved_model_bundle`.
        devices: forwarded to the adapter.

    Returns:
        A `TensorflowModelAdapter` if the installed tensorflow major version is
        1 or lower, otherwise a `KerasModelAdapter`.

    Raises:
        ValueError: if the description has no
            `tensorflow_saved_model_bundle` weights.
    """
    tf_version = v0_5.Version(tf.__version__)  # type: ignore[reportUnknownVariableType]
    weights = model_description.weights.tensorflow_saved_model_bundle
    if weights is None:
        raise ValueError("No `tensorflow_saved_model_bundle` weights found")

    model_tf_version = weights.tensorflow_version
    if model_tf_version is None:
        logger.warning(
            "The model does not specify the tensorflow version. "
            + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
        )
    elif model_tf_version > tf_version:
        logger.warning(
            f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
        )
    elif (model_tf_version.major, model_tf_version.minor) != (
        tf_version.major,
        tf_version.minor,
    ):
        logger.warning(
            "The tensorflow version specified by the model does not match the installed: "
            + f"{model_tf_version} != {tf_version}."
        )

    if tf_version.major <= 1:
        return TensorflowModelAdapter(model_description, devices=devices)
    else:
        return KerasModelAdapter(model_description, devices=devices)