Coverage for bioimageio/core/backends/tensorflow_backend.py: 58%
77 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from pathlib import Path
2from typing import Any, Optional, Sequence, Union
4import numpy as np
5import tensorflow as tf # pyright: ignore[reportMissingTypeStubs]
6from loguru import logger
7from numpy.typing import NDArray
9from bioimageio.core.io import ensure_unzipped
10from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
12from ._model_adapter import ModelAdapter
class TensorflowModelAdapter(ModelAdapter):
    """Run `tensorflow_saved_model_bundle` weights with the TensorFlow 1 graph/session API.

    The weights are unzipped with `ensure_unzipped` (target directory
    `bioimageio_unzipped_tf_weights`) on construction. On every forward pass
    the saved model is loaded into a fresh `tf.Graph` and `tf.Session`.
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Prepare the saved model bundle of `model_description` for inference.

        Args:
            model_description: model description that must provide
                `tensorflow_saved_model_bundle` weights.
            devices: not supported by this backend; a warning is logged and
                the value is ignored.

        Raises:
            ValueError: if no `tensorflow_saved_model_bundle` weights are found.
        """
        super().__init__(model_description=model_description)

        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )
        # the TF1 loader is given the unzipped bundle location as a plain string
        self._network = str(weight_file)

    # TODO currently we reload the model every time. it would be better to keep the graph and session
    # alive in between of forward passes (but then the sessions need to be properly opened / closed)
    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Load the saved model into a new session and run a single prediction.

        Args:
            input_arrays: one entry per model input, matched positionally to
                `self._input_ids`. `None` entries are not fed to the graph.

        Returns:
            A list with one result per entry of `self._output_ids`, in that order.
        """
        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING  # pyright: ignore[reportAttributeAccessIssue]
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # pyright: ignore[reportAttributeAccessIssue]
        )

        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(  # pyright: ignore[reportAttributeAccessIssue]
                graph=graph
            ) as sess:  # pyright: ignore[reportUnknownVariableType]
                # load the model and look up its serving signature
                meta_graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
                    sess, [tag], self._network
                )
                signature = meta_graph_def.signature_def[  # pyright: ignore[reportUnknownVariableType]
                    signature_key
                ]

                # resolve the signature entries to tensors of the loaded graph
                in_tf_tensors = [
                    graph.get_tensor_by_name(
                        signature.inputs[key].name  # pyright: ignore[reportUnknownArgumentType]
                    )
                    for key in self._input_ids
                ]
                out_tf_tensors = [
                    graph.get_tensor_by_name(
                        signature.outputs[key].name  # pyright: ignore[reportUnknownArgumentType]
                    )
                    for key in self._output_ids
                ]

                # Only feed inputs that were actually provided. A `None` feed
                # value cannot be converted to a tensor. Omitting it lets a
                # placeholder with a default use that default; a required
                # placeholder still raises a "must feed a value" error.
                feed_dict = {
                    tensor: array
                    for tensor, array in zip(in_tf_tensors, input_arrays)
                    if array is not None
                }

                # fetching a list returns the results in the same order as
                # `out_tf_tensors`, i.e. in the order of `self._output_ids`
                res = sess.run(  # pyright: ignore[reportUnknownVariableType]
                    out_tf_tensors, feed_dict
                )

        return list(res)  # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]

    def unload(self) -> None:
        """Log that unloading is unsupported; nothing is released."""
        logger.warning(
            "Device management is not implemented for tensorflow 1, cannot unload model"
        )
class KerasModelAdapter(ModelAdapter):
    """Run `tensorflow_saved_model_bundle` weights with TensorFlow >= 2 via `keras.layers.TFSMLayer`.

    The layer is built from the call endpoint `serve`. If that fails, it is
    built from `serving_default` instead.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Unzip the saved model bundle and wrap it in a `TFSMLayer`.

        Args:
            model_description: model description that must provide
                `tensorflow_saved_model_bundle` weights.
            devices: not supported by this backend; a warning is logged and
                the value is ignored.

        Raises:
            ValueError: if no `tensorflow_saved_model_bundle` weights are found.
            Exception: the error from creating the layer with
                `call_endpoint="serve"` if the `serving_default` fallback also fails.
        """
        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        super().__init__(model_description=model_description)
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )
        # pass a plain string path, as TensorflowModelAdapter does
        weight_path = str(weight_file)

        try:
            self._network = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportAttributeAccessIssue]
                weight_path,
                call_endpoint="serve",
            )
        except Exception as e:
            try:
                self._network = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportAttributeAccessIssue]
                    weight_path, call_endpoint="serving_default"
                )
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                # surface the error of the primary endpoint
                raise e

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Call the wrapped layer and convert its outputs to numpy arrays.

        Args:
            input_arrays: positional inputs for the layer; `None` entries are
                passed through as `None`.

        Returns:
            A list of numpy arrays (or `None`), one per layer output. See
            `_outputs_as_list` for the ordering.
        """
        tf_tensor = [
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
        ]

        result = self._network(*tf_tensor)  # pyright: ignore[reportUnknownVariableType]
        outputs = self._outputs_as_list(result)  # pyright: ignore[reportUnknownArgumentType]

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in outputs  # pyright: ignore[reportUnknownVariableType]
        ]

    def _outputs_as_list(self, result: Any) -> list[Any]:
        """Normalize a `TFSMLayer` call result to a list of outputs.

        For a dict result, the outputs are taken in the order of
        `self._output_ids` if every output id is a key. Otherwise the dict's
        own order is used. A list/tuple result is used as is, and any other
        result is treated as a single output.
        """
        if isinstance(result, dict):
            output_keys = [str(oid) for oid in self._output_ids]
            if output_keys and all(k in result for k in output_keys):
                return [result[k] for k in output_keys]  # pyright: ignore[reportUnknownVariableType]

            # signature keys do not match the RDF output ids; keep signature order
            return list(result.values())  # pyright: ignore[reportUnknownArgumentType]

        if isinstance(result, (list, tuple)):
            return list(result)  # pyright: ignore[reportUnknownArgumentType]

        return [result]

    def unload(self) -> None:
        """Log that unloading is unsupported; nothing is released."""
        logger.warning(
            "Device management is not implemented for tensorflow>=2 models"
            + f" using `{self.__class__.__name__}`, cannot unload model"
        )
def create_tf_model_adapter(
    model_description: AnyModelDescr, devices: Optional[Sequence[str]]
) -> ModelAdapter:
    """Create the TensorFlow model adapter matching the installed TensorFlow.

    A warning is logged if the model does not declare a TensorFlow version,
    declares a newer one than installed, or declares a different
    major.minor version.

    Args:
        model_description: model description with `tensorflow_saved_model_bundle` weights.
        devices: forwarded to the adapter, which currently ignores it.

    Returns:
        `TensorflowModelAdapter` for TensorFlow 1.x or older, otherwise
        `KerasModelAdapter`.

    Raises:
        ValueError: if no `tensorflow_saved_model_bundle` weights are found.
    """
    tf_version = v0_5.Version(tf.__version__)
    weights = model_description.weights.tensorflow_saved_model_bundle
    if weights is None:
        raise ValueError("No `tensorflow_saved_model_bundle` weights found")

    model_tf_version = weights.tensorflow_version
    if model_tf_version is None:
        logger.warning(
            "The model does not specify the tensorflow version. "
            + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
        )
    elif model_tf_version > tf_version:
        logger.warning(
            f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
        )
    elif (model_tf_version.major, model_tf_version.minor) != (
        tf_version.major,
        tf_version.minor,
    ):
        logger.warning(
            "The tensorflow version specified by the model does not match the installed: "
            + f"{model_tf_version} != {tf_version}."
        )

    # TF1 exposes the graph/session API; TF2+ loads saved models via Keras
    if tf_version.major <= 1:
        return TensorflowModelAdapter(
            model_description=model_description, devices=devices
        )

    return KerasModelAdapter(model_description=model_description, devices=devices)