Coverage for src/bioimageio/core/weight_converters/keras_to_tensorflow.py: 0%
72 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1import os
2import shutil
3import sys
4from pathlib import Path
5from tempfile import TemporaryDirectory
6from typing import Any, Union, no_type_check
7from zipfile import ZipFile
9import tensorflow # pyright: ignore[reportMissingTypeStubs]
11from bioimageio.spec._internal.version_type import Version
12from bioimageio.spec.common import ZipPath
13from bioimageio.spec.model.v0_5 import (
14 InputTensorDescr,
15 ModelDescr,
16 OutputTensorDescr,
17 TensorflowSavedModelBundleWeightsDescr,
18)
20from .. import __version__
21from ..io import ensure_unzipped
23try:
24 # try to build the tf model with the keras import from tensorflow
25 from tensorflow import keras # type: ignore
26except Exception:
27 # if the above fails try to export with the standalone keras
28 import keras # pyright: ignore[reportMissingTypeStubs]
def convert(
    model_descr: ModelDescr, output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """
    Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format.

    This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle,
    which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x
    and 2.x versions, with appropriate checks to ensure compatibility.

    Adapted from:
    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py

    Args:
        model_descr:
            The bioimage.io model description containing the model's metadata and weights.
        output_path:
            Where to put the zipped TensorFlow SavedModel bundle. The bundle is
            first written to a folder and then compressed:
            if `output_path` has a `.zip` suffix, that folder is `output_path`
            without the suffix; otherwise `output_path` itself is used as the
            folder and the archive path is derived from it. The final archive
            location is the `source` of the returned descriptor.
    Raises:
        ValueError:
            - If `output_path` or the intermediate bundle folder already exists.
            - If the Keras HDF5 weights are missing in the model description.
        RuntimeError:
            If there is a mismatch between the TensorFlow version used by the model and the version installed.
        NotImplementedError:
            If the model has multiple inputs or outputs and TensorFlow 1.x is being used.

    Returns:
        A descriptor object containing information about the converted TensorFlow SavedModel bundle.
    """
    tf_major_ver = int(tensorflow.__version__.split(".")[0])

    # The bundle is exported to a folder and zipped afterwards. A `.zip` output
    # path must therefore map to a folder *without* that suffix. Otherwise the
    # folder and the archive would share one path and zipping would fail.
    bundle_folder = (
        output_path.with_suffix("") if output_path.suffix == ".zip" else output_path
    )
    for path in (output_path, bundle_folder):
        if path.exists():
            raise ValueError(f"The output path {path} must not exist.")

    if model_descr.weights.keras_hdf5 is None:
        raise ValueError("Missing Keras Hdf5 weights to convert from.")

    weight_spec = model_descr.weights.keras_hdf5
    weight_reader = weight_spec.get_reader()

    if weight_spec.tensorflow_version:
        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
        if model_tf_major_ver != tf_major_ver:
            raise RuntimeError(
                f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
            )

    # `ignore_cleanup_errors` is only accepted by TemporaryDirectory on Python >= 3.10
    td_kwargs: dict[str, Any] = (
        dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
    )
    with TemporaryDirectory(**td_kwargs) as temp_dir:
        local_weights = ensure_unzipped(
            weight_reader, Path(temp_dir) / "bioimageio_unzipped_tf_weights"
        )
        if tf_major_ver == 1:
            if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
                raise NotImplementedError(
                    "Weight conversion for models with multiple inputs or outputs is not yet implemented."
                )

            input_name = str(
                d.id
                if isinstance((d := model_descr.inputs[0]), InputTensorDescr)
                else d.name
            )
            output_name = str(
                d.id
                if isinstance((d := model_descr.outputs[0]), OutputTensorDescr)
                else d.name
            )
            return _convert_tf1(
                # Keep every intermediate file inside the temporary directory
                # so nothing is written to the current working directory.
                ensure_unzipped(
                    local_weights, Path(temp_dir) / "bioimageio_unzipped_tf1_weights"
                ),
                bundle_folder,
                input_name,
                output_name,
            )
        else:
            return _convert_tf2(local_weights, bundle_folder)
def _convert_tf2(
    keras_weight_path: Union[Path, ZipPath], output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a Keras model as a zipped TensorFlow 2 SavedModel bundle.

    Args:
        keras_weight_path: Local Keras weights file to load.
        output_path: Folder the SavedModel is exported to before it is zipped.

    Returns:
        A weights descriptor whose `source` is the absolute path of the zip archive.
    """
    loaded_model = keras.models.load_model(keras_weight_path)  # type: ignore
    loaded_model.export(output_path)  # type: ignore

    # compress the exported folder (the helper also removes the folder)
    archive = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", archive)

    return TensorflowSavedModelBundleWeightsDescr(
        source=archive.absolute(),
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
# adapted from
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
def _convert_tf1(
    keras_weight_path: Path,
    output_path: Path,
    input_name: str,
    output_name: str,
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a single-input/single-output Keras model as a zipped TF1 SavedModel bundle.

    Args:
        keras_weight_path: Local Keras weights file to load.
        output_path: Folder the SavedModel is written to before it is zipped.
        input_name: Key of the model input in the serving signature.
        output_name: Key of the model output in the serving signature.

    Returns:
        A weights descriptor whose `source` is the absolute path of the zip archive.
    """

    @no_type_check
    def build_tf_model():
        # TF 1.x path utilities convert paths with `compat.as_bytes`, which does
        # not accept `pathlib.Path`. Pass plain strings.
        # NOTE(review): confirm against the oldest supported TF1 release.
        keras_model = keras.models.load_model(str(keras_weight_path))
        builder = tensorflow.saved_model.builder.SavedModelBuilder(str(output_path))
        # single-tensor signature: `.input`/`.output` (the caller enforces one input and one output)
        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
            inputs={input_name: keras_model.input},
            outputs={output_name: keras_model.output},
        )

        signature_def_map = {
            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
                signature
            )
        }

        builder.add_meta_graph_and_variables(
            keras.backend.get_session(),
            [tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
        )
        builder.save()

    build_tf_model()

    output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return TensorflowSavedModelBundleWeightsDescr(
        source=output_path.absolute(),
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
def _zip_model_bundle(model_bundle_folder: Path) -> Path:
    """Compress a SavedModel bundle folder into a zip archive and remove the folder.

    The archive is created next to the folder and named after it with `.zip`
    appended (e.g. `model_v1.2` -> `model_v1.2.zip`). Member names are paths
    relative to the folder. Removing the folder is best effort: a failure is
    reported but does not raise.

    Args:
        model_bundle_folder: Folder containing the exported SavedModel files.

    Returns:
        Path of the created zip archive.
    """
    # Append rather than `with_suffix(".zip")`. The latter replaces whatever
    # follows the last dot in the folder name (`model_v1.2` -> `model_v1.zip`).
    # It would also map a folder `x.zip` onto itself.
    zipped_model_bundle = model_bundle_folder.with_name(
        model_bundle_folder.name + ".zip"
    )

    with ZipFile(zipped_model_bundle, "w") as zip_obj:
        for root, dirs, files in os.walk(model_bundle_folder):
            dirs.sort()  # in-place sort -> deterministic traversal and member order
            for filename in sorted(files):
                src = os.path.join(root, filename)
                zip_obj.write(src, os.path.relpath(src, model_bundle_folder))

    try:
        shutil.rmtree(model_bundle_folder)
    except Exception as e:
        # the archive is already complete; a leftover folder is not fatal
        print(f"TensorFlow bundled model was not removed after compression: {e}")

    return zipped_model_bundle