Coverage for bioimageio/core/weight_converters/keras_to_tensorflow.py: 0%
68 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import os
2import shutil
3from pathlib import Path
4from typing import Union, no_type_check
5from zipfile import ZipFile
7import tensorflow # pyright: ignore[reportMissingTypeStubs]
9from bioimageio.spec._internal.io import download
10from bioimageio.spec._internal.version_type import Version
11from bioimageio.spec.common import ZipPath
12from bioimageio.spec.model.v0_5 import (
13 InputTensorDescr,
14 ModelDescr,
15 OutputTensorDescr,
16 TensorflowSavedModelBundleWeightsDescr,
17)
19from .. import __version__
20from ..io import ensure_unzipped
22try:
23 # try to build the tf model with the keras import from tensorflow
24 from tensorflow import keras # type: ignore
25except Exception:
26 # if the above fails try to export with the standalone keras
27 import keras # pyright: ignore[reportMissingTypeStubs]
def convert(
    model_descr: ModelDescr, output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """
    Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format.

    This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle,
    which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x
    and 2.x versions, with appropriate checks to ensure compatibility.

    The SavedModel is first exported to a directory (``output_path`` without its
    suffix). That directory is then zipped to the same path with a ``.zip`` suffix.
    The path of the created archive is available as ``source`` of the returned
    descriptor.

    Adapted from:
    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py

    Args:
        model_descr:
            The bioimage.io model description containing the model's metadata and weights.
        output_path:
            Path with .zip suffix to which a zip archive with the TensorFlow
            SavedModel bundle will be saved. A missing suffix is appended;
            any other suffix is replaced by .zip.
    Raises:
        ValueError:
            - If the intermediate export directory or the target zip archive already exists.
            - If the Keras HDF5 weights are missing in the model description.
        RuntimeError:
            If there is a mismatch between the TensorFlow major version used by the model and the version installed.
        NotImplementedError:
            If the model has multiple inputs or outputs and TensorFlow 1.x is being used.

    Returns:
        A descriptor object containing information about the converted TensorFlow SavedModel bundle.
    """
    tf_major_ver = int(tensorflow.__version__.split(".")[0])

    # Export into a directory without the archive suffix. The bundle is zipped
    # to `<directory>.with_suffix(".zip")` afterwards, so keeping ".zip" on the
    # directory would make the export folder and the archive the same path.
    bundle_dir = output_path.with_suffix("")
    zip_path = bundle_dir.with_suffix(".zip")

    if bundle_dir.exists():
        raise ValueError(f"The output directory at {bundle_dir} must not exist.")

    if zip_path.exists():
        # refuse to silently overwrite an existing archive after a costly conversion
        raise ValueError(f"The output archive at {zip_path} must not exist.")

    if model_descr.weights.keras_hdf5 is None:
        raise ValueError("Missing Keras Hdf5 weights to convert from.")

    weight_spec = model_descr.weights.keras_hdf5
    weight_path = download(weight_spec.source).path

    if weight_spec.tensorflow_version:
        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
        if model_tf_major_ver != tf_major_ver:
            raise RuntimeError(
                f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
            )

    if tf_major_ver == 1:
        # the TF1 export builds a single serving signature with exactly one input and one output
        if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
            raise NotImplementedError(
                "Weight conversion for models with multiple inputs or outputs is not yet implemented."
            )

        input_name = str(
            d.id
            if isinstance((d := model_descr.inputs[0]), InputTensorDescr)
            else d.name
        )
        output_name = str(
            d.id
            if isinstance((d := model_descr.outputs[0]), OutputTensorDescr)
            else d.name
        )
        return _convert_tf1(
            ensure_unzipped(weight_path, Path("bioimageio_unzipped_tf_weights")),
            bundle_dir,
            input_name,
            output_name,
        )
    else:
        # NOTE(review): unlike the TF1 branch, `weight_path` is passed on as-is;
        # `_convert_tf2` is annotated to also accept a ZipPath. Confirm that
        # `keras.models.load_model` can read weights from such a path.
        return _convert_tf2(weight_path, bundle_dir)
def _convert_tf2(
    keras_weight_path: Union[Path, ZipPath], output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export Keras weights as a zipped TensorFlow 2.x SavedModel bundle.

    Args:
        keras_weight_path: Keras weights file, loaded with ``keras.models.load_model``.
        output_path: Directory the SavedModel is exported to. It is zipped to
            ``output_path.with_suffix(".zip")`` afterwards.

    Returns:
        Descriptor whose ``source`` is the created zip archive. It records
        ``keras_hdf5`` as parent and the installed TensorFlow version.
    """
    keras_model = keras.models.load_model(keras_weight_path)  # type: ignore
    keras_model.export(output_path)  # type: ignore

    archive = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", archive)

    return TensorflowSavedModelBundleWeightsDescr(
        source=archive,
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
# adapted from
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
def _convert_tf1(
    keras_weight_path: Path,
    output_path: Path,
    input_name: str,
    output_name: str,
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a single-input/single-output Keras model as a TF 1.x SavedModel bundle.

    Args:
        keras_weight_path: Local path to the Keras weights file.
        output_path: Directory the SavedModel is written to. It is zipped to
            ``output_path.with_suffix(".zip")`` afterwards.
        input_name: Key of the model input in the default serving signature.
        output_name: Key of the model output in the default serving signature.

    Returns:
        Descriptor whose ``source`` is the created zip archive. It records
        ``keras_hdf5`` as parent and the installed TensorFlow version.
    """

    @no_type_check
    def build_tf_model():
        # TF 1.x file APIs (used by SavedModelBuilder) and older standalone
        # keras/h5py releases may not accept pathlib.Path, so pass plain strings.
        keras_model = keras.models.load_model(str(keras_weight_path))
        builder = tensorflow.saved_model.builder.SavedModelBuilder(str(output_path))
        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
            inputs={input_name: keras_model.input},
            outputs={output_name: keras_model.output},
        )

        signature_def_map = {
            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
                signature
            )
        }

        # the graph/variables live in the keras backend session in TF 1.x
        builder.add_meta_graph_and_variables(
            keras.backend.get_session(),
            [tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
        )
        builder.save()

    build_tf_model()

    output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return TensorflowSavedModelBundleWeightsDescr(
        source=output_path,
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
def _zip_model_bundle(model_bundle_folder: Path) -> Path:
    """Zip a SavedModel bundle folder and try to delete the folder afterwards.

    The archive is written to ``model_bundle_folder.with_suffix(".zip")``.
    Member names are relative to ``model_bundle_folder``. Sub-directories are
    stored as directory entries, so empty ones survive a round trip.

    Args:
        model_bundle_folder: Directory containing the exported SavedModel.

    Returns:
        Path of the created zip archive.
    """
    zipped_model_bundle = model_bundle_folder.with_suffix(".zip")

    with ZipFile(zipped_model_bundle, "w") as zip_obj:
        for root, dirs, files in os.walk(model_bundle_folder):
            # Record directories too; writing only files would drop empty
            # folders of the exported bundle (e.g. an empty `assets/`).
            for name in (*dirs, *files):
                src = os.path.join(root, name)
                zip_obj.write(src, os.path.relpath(src, model_bundle_folder))

    try:
        shutil.rmtree(model_bundle_folder)
    except OSError as e:
        # best effort: the archive is complete, only the temporary folder remains
        print(f"TensorFlow bundled model was not removed after compression: {e}")

    return zipped_model_bundle