Coverage for src / bioimageio / core / backends / tensorflow_backend.py: 58%
77 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1from pathlib import Path
2from typing import Any, Optional, Sequence, Union
4import numpy as np
5import tensorflow as tf
6from loguru import logger
7from numpy.typing import NDArray
9from bioimageio.core.io import ensure_unzipped
10from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
12from ._model_adapter import ModelAdapter
class TensorflowModelAdapter(ModelAdapter):
    """Model adapter for TensorFlow 1.x `tensorflow_saved_model_bundle` weights.

    The SavedModel bundle is unzipped once at construction time; each forward
    pass loads it into a fresh `tf.Graph` / `tf.Session` (see the TODO on
    `_forward_impl`).
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Prepare the SavedModel bundle of `model_description` for loading.

        Args:
            model_description: model description whose
                `weights.tensorflow_saved_model_bundle` is used.
            devices: ignored (a warning is logged); device management is not
                implemented for this backend.

        Raises:
            ValueError: if there are no `tensorflow_saved_model_bundle` weights.
        """
        super().__init__(model_description=model_description)

        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )
        # Path (as str) of the unzipped SavedModel, handed to the TF1 loader on every forward pass.
        self._network = str(weight_file)

    # TODO currently we reload the model every time. it would be better to keep the graph and session
    # alive in between of forward passes (but then the sessions need to be properly opened / closed)
    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run a single prediction with a freshly loaded graph and session.

        Args:
            input_arrays: one entry per input id, in the order of
                `self._input_ids`. `None` entries are not fed, so the
                corresponding placeholder keeps whatever default the graph
                defines (if none, TensorFlow reports the missing feed).

        Returns:
            list of output arrays, ordered like `self._output_ids`.
        """
        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )

        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(graph=graph) as sess:  # pyright: ignore[reportUnknownVariableType]
                # load the model and pick the serving signature
                meta_graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType]
                    sess, [tag], self._network
                )
                signature = meta_graph_def.signature_def[  # pyright: ignore[reportUnknownVariableType]
                    signature_key
                ]

                # map the model's tensor ids to graph tensor names
                in_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature.inputs[key].name for key in self._input_ids
                ]
                out_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature.outputs[key].name for key in self._output_ids
                ]

                # Only feed inputs that were actually provided; a `None` value is
                # not a meaningful feed and must not reach `sess.run`.
                feed_dict = {
                    graph.get_tensor_by_name(name): array  # pyright: ignore[reportUnknownArgumentType]
                    for name, array in zip(  # pyright: ignore[reportUnknownVariableType]
                        in_names, input_arrays  # pyright: ignore[reportUnknownArgumentType]
                    )
                    if array is not None
                }
                fetches = {
                    name: graph.get_tensor_by_name(name)  # pyright: ignore[reportUnknownArgumentType]
                    for name in out_names  # pyright: ignore[reportUnknownVariableType]
                }

                # run prediction
                res = sess.run(  # pyright: ignore[reportUnknownVariableType]
                    fetches, feed_dict
                )

        # from dict to list of arrays, in output id order
        return [  # pyright: ignore[reportUnknownVariableType]
            res[name] for name in out_names  # pyright: ignore[reportUnknownVariableType]
        ]

    def unload(self) -> None:
        """Log that unloading is unsupported; nothing is released (the model is loaded per forward pass)."""
        logger.warning(
            "Device management is not implemented for tensorflow 1, cannot unload model"
        )
class KerasModelAdapter(ModelAdapter):
    """Model adapter for TensorFlow>=2 SavedModel weights, wrapped in `keras.layers.TFSMLayer`."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Unzip the SavedModel bundle and load it as a `TFSMLayer`.

        Args:
            model_description: model description whose
                `weights.tensorflow_saved_model_bundle` is used.
            devices: ignored (a warning is logged); device management is not
                implemented for this backend.

        Raises:
            ValueError: if there are no `tensorflow_saved_model_bundle` weights.
            Exception: the error from loading with `call_endpoint="serve"`,
                re-raised if the `"serving_default"` fallback fails as well.
        """
        if model_description.weights.tensorflow_saved_model_bundle is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        super().__init__(model_description=model_description)
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            model_description.weights.tensorflow_saved_model_bundle.source,
            Path("bioimageio_unzipped_tf_weights"),
        )

        # Try the "serve" endpoint first and fall back to "serving_default";
        # if both fail, surface the error of the first attempt.
        try:
            self._network = tf.keras.layers.TFSMLayer(
                weight_file,
                call_endpoint="serve",
            )
        except Exception as e:
            try:
                self._network = tf.keras.layers.TFSMLayer(
                    weight_file, call_endpoint="serving_default"
                )
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                raise e

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run a prediction through the loaded `TFSMLayer`.

        Args:
            input_arrays: one entry per model input, passed positionally;
                `None` entries are passed through as `None`.

        Returns:
            list of numpy arrays (or `None`), one per output.
        """
        tf_tensors = [
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
        ]

        result = self._network(*tf_tensors)  # pyright: ignore[reportUnknownVariableType]
        outputs = self._outputs_as_list(result)  # pyright: ignore[reportUnknownArgumentType]

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in outputs  # pyright: ignore[reportUnknownVariableType]
        ]

    def _outputs_as_list(self, result: Any) -> list[Any]:
        """Normalise an endpoint result (dict, list/tuple or single tensor) to a list of outputs."""
        if isinstance(result, dict):
            # Prefer ordering by the described output ids when the endpoint's
            # keys match them; otherwise keep the endpoint's own key order.
            output_keys = [str(oid) for oid in self._output_ids]
            if all(key in result for key in output_keys):
                return [result[key] for key in output_keys]
            return list(result.values())  # pyright: ignore[reportUnknownArgumentType]
        if isinstance(result, (list, tuple)):
            return list(result)  # pyright: ignore[reportUnknownArgumentType]
        # a single output tensor
        return [result]

    def unload(self) -> None:
        """Log that unloading is unsupported for this adapter; nothing is released."""
        logger.warning(
            "Device management is not implemented for tensorflow>=2 models"
            + f" using `{self.__class__.__name__}`, cannot unload model"
        )
def create_tf_model_adapter(
    model_description: AnyModelDescr, devices: Optional[Sequence[str]]
) -> ModelAdapter:
    """Create the TensorFlow adapter matching the installed TensorFlow major version.

    Logs a warning if the model's declared tensorflow version is missing, newer
    than the installed one, or differs in major/minor version.

    Args:
        model_description: model description with `tensorflow_saved_model_bundle` weights.
        devices: forwarded to the adapter (currently ignored there with a warning).

    Returns:
        a `TensorflowModelAdapter` for tensorflow<=1, otherwise a `KerasModelAdapter`.

    Raises:
        ValueError: if there are no `tensorflow_saved_model_bundle` weights.
    """
    tf_version = v0_5.Version(tf.__version__)  # type: ignore[reportUnknownVariableType]
    weights = model_description.weights.tensorflow_saved_model_bundle
    if weights is None:
        raise ValueError("No `tensorflow_saved_model_bundle` weights found")

    model_tf_version = weights.tensorflow_version
    if model_tf_version is None:
        logger.warning(
            "The model does not specify the tensorflow version. "
            + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
        )
    elif model_tf_version > tf_version:
        logger.warning(
            f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
        )
    elif (model_tf_version.major, model_tf_version.minor) != (
        tf_version.major,
        tf_version.minor,
    ):
        logger.warning(
            "The tensorflow version specified by the model does not match the installed: "
            + f"{model_tf_version} != {tf_version}."
        )

    # TF1 needs the graph/session based adapter; TF2+ loads via keras TFSMLayer.
    if tf_version.major <= 1:
        return TensorflowModelAdapter(
            model_description=model_description, devices=devices
        )
    else:
        return KerasModelAdapter(model_description=model_description, devices=devices)