Coverage for bioimageio/core/weight_converters/keras_to_tensorflow.py: 0%
70 statements
coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1import os
2import shutil
3from pathlib import Path
4from tempfile import TemporaryDirectory
5from typing import Union, no_type_check
6from zipfile import ZipFile
8import tensorflow # pyright: ignore[reportMissingTypeStubs]
10from bioimageio.spec._internal.version_type import Version
11from bioimageio.spec.common import ZipPath
12from bioimageio.spec.model.v0_5 import (
13 InputTensorDescr,
14 ModelDescr,
15 OutputTensorDescr,
16 TensorflowSavedModelBundleWeightsDescr,
17)
19from .. import __version__
20from ..io import ensure_unzipped
22try:
23 # try to build the tf model with the keras import from tensorflow
24 from tensorflow import keras # type: ignore
25except Exception:
26 # if the above fails try to export with the standalone keras
27 import keras # pyright: ignore[reportMissingTypeStubs]
def convert(
    model_descr: ModelDescr, output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """
    Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format.

    This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle,
    which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x
    and 2.x versions, with appropriate checks to ensure compatibility.

    Adapted from:
    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py

    Args:
        model_descr:
            The bioimage.io model description containing the model's metadata and weights.
        output_path:
            Path with .zip suffix (.zip is appended otherwise) to which a zip archive
            with the TensorFlow SavedModel bundle will be saved.
            The bundle is first exported to a folder next to the archive
            (`output_path` without its `.zip` suffix), which is removed after zipping.
    Raises:
        ValueError:
            - If the specified `output_path` (or the intermediate bundle folder) already exists.
            - If the Keras HDF5 weights are missing in the model description.
        RuntimeError:
            If there is a mismatch between the TensorFlow version used by the model and the version installed.
        NotImplementedError:
            If the model has multiple inputs or outputs and TensorFlow 1.x is being used.

    Returns:
        A descriptor object containing information about the converted TensorFlow SavedModel bundle.
    """
    tf_major_ver = int(tensorflow.__version__.split(".")[0])

    # The SavedModel bundle is exported to a folder that is zipped afterwards
    # (by `_zip_model_bundle`, next to that folder). The folder therefore must not
    # carry the archive's name: strip a ".zip" suffix to obtain the folder path.
    if output_path.suffix == ".zip":
        bundle_dir = output_path.with_suffix("")
    else:
        bundle_dir = output_path

    for path in (bundle_dir, output_path):
        if path.exists():
            raise ValueError(f"The output path at {path} must not exist.")

    weight_spec = model_descr.weights.keras_hdf5
    if weight_spec is None:
        raise ValueError("Missing Keras Hdf5 weights to convert from.")

    if weight_spec.tensorflow_version:
        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
        if model_tf_major_ver != tf_major_ver:
            raise RuntimeError(
                f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
            )

    # fail fast, before fetching/unzipping any weights
    if tf_major_ver == 1 and (
        len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1
    ):
        raise NotImplementedError(
            "Weight conversion for models with multiple inputs or outputs is not yet implemented."
        )

    weight_reader = weight_spec.get_reader()
    with TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
        temp_path = Path(temp_dir)
        local_weights = ensure_unzipped(
            weight_reader, temp_path / "bioimageio_unzipped_tf_weights"
        )
        if tf_major_ver != 1:
            return _convert_tf2(local_weights, bundle_dir)

        # tensor descriptions expose `id` (InputTensorDescr/OutputTensorDescr)
        # or, presumably for older description formats, `name`
        ipt = model_descr.inputs[0]
        out = model_descr.outputs[0]
        input_name = str(ipt.id if isinstance(ipt, InputTensorDescr) else ipt.name)
        output_name = str(out.id if isinstance(out, OutputTensorDescr) else out.name)
        return _convert_tf1(
            # presumably re-materialized to narrow the weights to a plain `Path`
            # for `_convert_tf1`; kept inside the temporary directory so nothing
            # is left behind in the current working directory
            ensure_unzipped(
                local_weights, temp_path / "bioimageio_unzipped_tf1_weights"
            ),
            bundle_dir,
            input_name,
            output_name,
        )
def _convert_tf2(
    keras_weight_path: Union[Path, ZipPath], output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a Keras model as a TF2 SavedModel bundle and zip it.

    Args:
        keras_weight_path: Path to the (unzipped) Keras HDF5 weights.
        output_path: Folder the SavedModel bundle is exported to; it is then
            archived (and removed) by `_zip_model_bundle`.

    Returns:
        Weights description whose `source` is the zipped bundle.
    """
    model = keras.models.load_model(keras_weight_path)  # type: ignore

    export = getattr(model, "export", None)  # type: ignore
    if export is not None:
        export(output_path)  # type: ignore
    else:
        # older tf.keras versions (presumably TF < 2.13) lack `Model.export`;
        # the generic SavedModel writer expects a str export directory there
        tensorflow.saved_model.save(model, str(output_path))  # type: ignore

    output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return TensorflowSavedModelBundleWeightsDescr(
        source=output_path,
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
# adapted from
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
def _convert_tf1(
    keras_weight_path: Path,
    output_path: Path,
    input_name: str,
    output_name: str,
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a single-input/single-output Keras model as a TF1 SavedModel bundle and zip it.

    Args:
        keras_weight_path: Local path of the Keras HDF5 weights file.
        output_path: Folder the SavedModel bundle is written to; it is then
            archived (and removed) by `_zip_model_bundle`.
        input_name: Key of the model input in the serving signature.
        output_name: Key of the model output in the serving signature.

    Returns:
        Weights description whose `source` is the zipped bundle.
    """

    @no_type_check
    def build_tf_model():
        # legacy TF1/Keras file APIs expect str paths rather than pathlib.Path
        keras_model = keras.models.load_model(str(keras_weight_path))
        builder = tensorflow.saved_model.builder.SavedModelBuilder(str(output_path))
        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
            inputs={input_name: keras_model.input},
            outputs={output_name: keras_model.output},
        )

        signature_def_map = {
            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
                signature
            )
        }

        # the session from the Keras backend holds the variables loaded above
        builder.add_meta_graph_and_variables(
            keras.backend.get_session(),
            [tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
        )
        builder.save()

    build_tf_model()

    output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return TensorflowSavedModelBundleWeightsDescr(
        source=output_path,
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
def _zip_model_bundle(model_bundle_folder: Path) -> Path:
    """Compress a SavedModel bundle folder into a zip archive and remove the folder.

    The archive is written next to the folder, named after it with ".zip"
    appended (``bundle`` -> ``bundle.zip``, ``my.model`` -> ``my.model.zip``).
    Member names are relative to `model_bundle_folder` and use "/" separators.

    Args:
        model_bundle_folder: Folder containing the exported SavedModel bundle.

    Returns:
        Path of the created zip archive.
    """
    from zipfile import ZIP_DEFLATED

    # Append rather than replace the suffix: `with_suffix(".zip")` would map a
    # folder "my.model" to "my.zip", and a folder ending in ".zip" onto itself.
    zipped_model_bundle = model_bundle_folder.with_name(
        model_bundle_folder.name + ".zip"
    )

    with ZipFile(zipped_model_bundle, "w", compression=ZIP_DEFLATED) as zip_obj:
        # Sorted for a deterministic member order. Directories are written too,
        # so empty sub-folders of the bundle are preserved in the archive.
        for path in sorted(model_bundle_folder.rglob("*")):
            zip_obj.write(path, path.relative_to(model_bundle_folder).as_posix())

    try:
        shutil.rmtree(model_bundle_folder)
    except OSError as e:
        # best effort: the archive is complete, only the cleanup failed
        print(f"TensorFlow bundled model was not removed after compression: {e}")

    return zipped_model_bundle