Coverage for bioimageio/core/model_adapters/_tensorflow_model_adapter.py: 29%
105 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import zipfile
2from typing import List, Literal, Optional, Sequence, Union
4import numpy as np
5from loguru import logger
7from bioimageio.spec.common import FileSource
8from bioimageio.spec.model import v0_4, v0_5
9from bioimageio.spec.utils import download
11from ..digest_spec import get_axes_infos
12from ..tensor import Tensor
13from ._model_adapter import ModelAdapter
# tensorflow is an optional dependency: remember why the import failed (if it
# did) so that the adapters can report the reason when they are instantiated.
tf_error = None
try:
    import tensorflow as tf  # pyright: ignore[reportMissingImports]
except Exception as e:
    tf = None
    tf_error = str(e)
class TensorflowModelAdapterBase(ModelAdapter):
    """Shared implementation of the tensorflow based model adapters.

    Subclasses set `weight_format` and hand the matching weights entry of the
    model description to `__init__`.
    """

    weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"]

    def __init__(
        self,
        *,
        devices: Optional[Sequence[str]] = None,
        weights: Union[
            v0_4.KerasHdf5WeightsDescr,
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.KerasHdf5WeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ],
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    ):
        """Load the network described by `weights`.

        Args:
            devices: Currently ignored; a warning is logged if it is not None.
            weights: Weights entry providing `source` and `tensorflow_version`.
            model_description: Model description the weights belong to; its
                inputs/outputs are used to look up tensor names and output axes.

        Raises:
            ImportError: If tensorflow could not be imported.
        """
        if tf is None:
            raise ImportError(f"failed to import tensorflow: {tf_error}")

        super().__init__()
        self.model_description = model_description
        tf_version = v0_5.Version(
            tf.__version__  # pyright: ignore[reportUnknownArgumentType]
        )
        model_tf_version = weights.tensorflow_version
        # Version mismatches are only reported, never treated as errors.
        if model_tf_version is None:
            logger.warning(
                "The model does not specify the tensorflow version. "
                + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
            )
        elif model_tf_version > tf_version:
            logger.warning(
                f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "The tensorflow version specified by the model does not match the installed: "
                + f"{model_tf_version} != {tf_version}."
            )

        # The keras API is used with tensorflow >= 2 and always for keras hdf5
        # weights; otherwise the tf1 session based code path is taken.
        self.use_keras_api = (
            tf_version.major > 1
            or self.weight_format == KerasModelAdapter.weight_format
        )

        # TODO tf device management
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # NOTE(review): `_get_network` calls `require_unzipped` again on the
        # path resolved here; this looks redundant -- confirm before removing.
        weight_file = self.require_unzipped(weights.source)
        self._network = self._get_network(weight_file)
        self._internal_output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def require_unzipped(self, weight_file: FileSource):
        """Download `weight_file` and extract it if it is a zip archive.

        Returns:
            The downloaded path itself if it is not a zip file; otherwise the
            extraction target, i.e. the downloaded path with its suffix
            replaced by ".unzipped".
        """
        local_weights_file = download(weight_file).path
        if zipfile.is_zipfile(local_weights_file):
            out_path = local_weights_file.with_suffix(".unzipped")
            with zipfile.ZipFile(local_weights_file, "r") as f:
                f.extractall(out_path)

            return out_path
        else:
            return local_weights_file

    def _get_network(  # pyright: ignore[reportUnknownParameterType]
        self, weight_file: FileSource
    ):
        """Return the object stored as `self._network`.

        With the keras API this is a `TFSMLayer`; the call endpoint "serve" is
        tried first and "serving_default" second. If both fail, the second
        error is logged and the first one is raised. Without the keras API
        only the weights path (as `str`) is returned.
        """
        weight_file = self.require_unzipped(weight_file)
        assert tf is not None
        if self.use_keras_api:
            try:
                return tf.keras.layers.TFSMLayer(
                    weight_file, call_endpoint="serve"
                )  # pyright: ignore[reportUnknownVariableType]
            except Exception as e:
                try:
                    return tf.keras.layers.TFSMLayer(
                        weight_file, call_endpoint="serving_default"
                    )  # pyright: ignore[reportUnknownVariableType]
                except Exception as ee:
                    logger.opt(exception=ee).info(
                        "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                    )
                    raise e
        else:
            # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
            return str(weight_file)

    # TODO currently we reload the model every time. it would be better to keep the graph and session
    #  alive in between of forward passes (but then the sessions need to be properly opened / closed)
    def _forward_tf(  # pyright: ignore[reportUnknownParameterType]
        self, *input_tensors: Optional[Tensor]
    ):
        """Run a forward pass through a freshly created graph and session.

        The saved model at `self._network` is loaded on every call. Input and
        output tensors are looked up in the default serving signature by the
        names/ids of the model description's inputs and outputs.

        Returns:
            The session results as a list ordered like the model outputs.
        """
        assert tf is not None
        input_keys = [
            ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id
            for ipt in self.model_description.inputs
        ]
        output_keys = [
            out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id
            for out in self.model_description.outputs
        ]
        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )

        graph = tf.Graph()  # pyright: ignore[reportUnknownVariableType]
        with graph.as_default():
            with tf.Session(
                graph=graph
            ) as sess:  # pyright: ignore[reportUnknownVariableType]
                # load the model and the signature
                graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType]
                    sess, [tag], self._network
                )
                signature = (  # pyright: ignore[reportUnknownVariableType]
                    graph_def.signature_def
                )

                # get the tensors into the graph
                in_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].inputs[key].name for key in input_keys
                ]
                out_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].outputs[key].name for key in output_keys
                ]
                in_tensors = [  # pyright: ignore[reportUnknownVariableType]
                    graph.get_tensor_by_name(name)
                    for name in in_names  # pyright: ignore[reportUnknownVariableType]
                ]
                out_tensors = [  # pyright: ignore[reportUnknownVariableType]
                    graph.get_tensor_by_name(name)
                    for name in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

                # run prediction
                res = sess.run(  # pyright: ignore[reportUnknownVariableType]
                    dict(
                        zip(
                            out_names,  # pyright: ignore[reportUnknownArgumentType]
                            out_tensors,  # pyright: ignore[reportUnknownArgumentType]
                        )
                    ),
                    dict(
                        zip(
                            in_tensors,  # pyright: ignore[reportUnknownArgumentType]
                            input_tensors,
                        )
                    ),
                )
                # from dict to list of tensors
                res = [  # pyright: ignore[reportUnknownVariableType]
                    res[out]
                    for out in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

        return res  # pyright: ignore[reportUnknownVariableType]

    def _forward_keras(  # pyright: ignore[reportUnknownParameterType]
        self, *input_tensors: Optional[Tensor]
    ):
        """Run a forward pass by calling the preloaded keras network.

        Inputs that are not None are converted with `tf.convert_to_tensor`.
        The network is expected to return a dict; its values are returned as
        a list in the dict's iteration order, converted via `.numpy()` unless
        they already are numpy arrays (None entries are passed through).
        """
        assert self.use_keras_api
        assert not isinstance(self._network, str)
        assert tf is not None
        tf_tensor = [  # pyright: ignore[reportUnknownVariableType]
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors
        ]

        result = self._network(*tf_tensor)  # pyright: ignore[reportUnknownVariableType]

        assert isinstance(result, dict)

        # TODO: Use RDF's `outputs[i].id` here
        result = list(result.values())

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in result  # pyright: ignore[reportUnknownVariableType]
        ]

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the network on `input_tensors` and wrap the results as `Tensor`s.

        The `.data` of every input that is not None is passed to the keras or
        tf1 code path (depending on `self.use_keras_api`). Each result that is
        not None is paired with the axis ids collected in `__init__`; pairing
        uses `zip` and therefore stops at the shorter of the two sequences.
        """
        data = [None if ipt is None else ipt.data for ipt in input_tensors]
        if self.use_keras_api:
            result = self._forward_keras(  # pyright: ignore[reportUnknownVariableType]
                *data
            )
        else:
            result = self._forward_tf(  # pyright: ignore[reportUnknownVariableType]
                *data
            )

        return [
            None if r is None else Tensor(r, dims=axes)
            for r, axes in zip(  # pyright: ignore[reportUnknownVariableType]
                result,  # pyright: ignore[reportUnknownArgumentType]
                self._internal_output_axes,
            )
        ]

    def unload(self) -> None:
        """Log a warning only; the loaded model is not released."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )
class TensorflowModelAdapter(TensorflowModelAdapterBase):
    """Model adapter for `tensorflow_saved_model_bundle` weights."""

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Set up the adapter from the model's saved model bundle weights.

        Raises:
            ValueError: If the model description has no
                `tensorflow_saved_model_bundle` weights entry.
        """
        bundle_weights = model_description.weights.tensorflow_saved_model_bundle
        if bundle_weights is None:
            raise ValueError("missing tensorflow_saved_model_bundle weights")

        super().__init__(
            model_description=model_description,
            weights=bundle_weights,
            devices=devices,
        )
class KerasModelAdapter(TensorflowModelAdapterBase):
    """Model adapter for `keras_hdf5` weights."""

    weight_format = "keras_hdf5"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Set up the adapter from the model's keras hdf5 weights.

        Raises:
            ValueError: If the model description has no `keras_hdf5` weights
                entry.
        """
        keras_weights = model_description.weights.keras_hdf5
        if keras_weights is None:
            raise ValueError("missing keras_hdf5 weights")

        super().__init__(
            devices=devices,
            weights=keras_weights,
            model_description=model_description,
        )