Coverage for bioimageio/core/weight_converter/keras/_tensorflow.py: 20%
75 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1# type: ignore # TODO: type
2import os
3import shutil
4from pathlib import Path
5from typing import no_type_check
6from zipfile import ZipFile
8try:
9 import tensorflow.saved_model
10except Exception:
11 tensorflow = None
13from bioimageio.spec._internal.io_utils import download
14from bioimageio.spec.model.v0_5 import ModelDescr
def _zip_model_bundle(model_bundle_folder: Path):
    """Pack an exported model folder into a zip archive, then delete the folder.

    The archive is written next to the folder as ``<folder name>.zip``. Entries
    are stored relative to the folder, so the archive root holds the folder's
    contents directly.

    Args:
        model_bundle_folder: directory holding the exported TensorFlow bundle.

    Returns:
        Path of the written zip archive.
    """
    from zipfile import ZIP_DEFLATED

    # Append ".zip" to the full name rather than using `with_suffix`, which
    # would replace an existing dotted part ("weights.v1" -> "weights.zip").
    zipped_model_bundle = model_bundle_folder.with_name(
        model_bundle_folder.name + ".zip"
    )

    # ZIP_DEFLATED: the default (ZIP_STORED) writes the files uncompressed.
    with ZipFile(zipped_model_bundle, "w", compression=ZIP_DEFLATED) as zip_obj:
        for root, _, files in os.walk(model_bundle_folder):
            for filename in files:
                src = os.path.join(root, filename)
                zip_obj.write(src, os.path.relpath(src, model_bundle_folder))

    try:
        shutil.rmtree(model_bundle_folder)
    except OSError as e:
        # Best effort: the archive is already complete, so a leftover folder
        # is reported rather than treated as fatal.
        print(f"TensorFlow bundled model was not removed after compression: {e}")

    return zipped_model_bundle
# adapted from
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
def _convert_tf1(
    keras_weight_path: Path,
    output_path: Path,
    input_name: str,
    output_name: str,
    zip_weights: bool,
):
    """Export Keras HDF5 weights as a TensorFlow 1 SavedModel bundle.

    Builds a single-input/single-output serving signature named with the given
    tensor names and saves it under the default serving signature key.

    Args:
        keras_weight_path: path to the Keras HDF5 weights file.
        output_path: directory to write the SavedModel bundle to.
        input_name: key for the model input in the serving signature.
        output_name: key for the model output in the serving signature.
        zip_weights: if True, zip the written bundle (the folder is then removed).

    Returns:
        0 on success.
    """
    try:
        # Prefer the keras implementation shipped with tensorflow. (This used to
        # import `keras` from this very module, which always failed and silently
        # forced the standalone-keras fallback.)
        from tensorflow import keras
    except Exception:
        # if the above fails try to export with the standalone keras
        import keras

    @no_type_check
    def build_tf_model():
        # TF1-era keras (h5 loading) and SavedModelBuilder file APIs expect
        # plain string paths, not pathlib.Path objects.
        keras_model = keras.models.load_model(str(keras_weight_path))
        assert tensorflow is not None
        builder = tensorflow.saved_model.builder.SavedModelBuilder(str(output_path))
        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
            inputs={input_name: keras_model.input},
            outputs={output_name: keras_model.output},
        )

        signature_def_map = {
            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
        }

        # The variables live in the keras backend session; export from it.
        builder.add_meta_graph_and_variables(
            keras.backend.get_session(),
            [tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
        )
        builder.save()

    build_tf_model()

    if zip_weights:
        output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return 0
def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool):
    """Export Keras HDF5 weights as a TensorFlow 2 SavedModel bundle.

    Args:
        keras_weight_path: path to the Keras HDF5 weights file.
        output_path: directory to write the SavedModel bundle to.
        zip_weights: if True, zip the written bundle (the folder is then removed).

    Returns:
        0 on success.
    """
    try:
        # Prefer the keras implementation shipped with tensorflow. (This used to
        # import `keras` from this very module, which always failed and silently
        # forced the standalone-keras fallback.)
        from tensorflow import keras
    except Exception:
        # if the above fails try to export with the standalone keras
        import keras

    model = keras.models.load_model(str(keras_weight_path))
    # NOTE(review): with Keras 3 (the `tf.keras` of TF >= 2.16) `save_model`
    # only accepts `.keras`/`.h5` targets; writing a SavedModel directory there
    # presumably needs `model.export(...)` instead -- confirm before relying on it.
    keras.models.save_model(model, str(output_path))

    if zip_weights:
        output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return 0
def convert_weights_to_tensorflow_saved_model_bundle(
    model: ModelDescr, output_path: Path
):
    """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'.

    Adapted from
    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py

    Args:
        model: The bioimageio model description
        output_path: where to save the tensorflow weights. This path must not
            exist yet. If it ends in ".zip", the bundle is written as a zip
            archive instead of a directory.

    Returns:
        0 on success (the return value of the TF1/TF2 converter).

    Raises:
        ImportError: if tensorflow could not be imported.
        ValueError: if the output path already exists or the model has no
            keras_hdf5 weights.
        RuntimeError: if the weights declare a different tensorflow major
            version than the installed one.
        NotImplementedError: for TF1 models with multiple inputs or outputs.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if tensorflow is None:
        raise ImportError(
            "tensorflow is required to convert weights to a tensorflow saved model bundle."
        )
    tf_major_ver = int(tensorflow.__version__.split(".")[0])

    if output_path.suffix == ".zip":
        # Also guard the requested archive itself, which would otherwise be
        # silently overwritten.
        if output_path.exists():
            raise ValueError(f"The output file at {output_path} must not exist.")
        # The bundle is first written to a directory, then zipped.
        output_path = output_path.with_suffix("")
        zip_weights = True
    else:
        zip_weights = False

    if output_path.exists():
        raise ValueError(f"The output directory at {output_path} must not exist.")

    if model.weights.keras_hdf5 is None:
        raise ValueError("Missing Keras Hdf5 weights to convert from.")

    weight_spec = model.weights.keras_hdf5
    weight_path = download(weight_spec.source).path

    if weight_spec.tensorflow_version:
        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
        if model_tf_major_ver != tf_major_ver:
            raise RuntimeError(
                f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
            )

    if tf_major_ver == 1:
        # The TF1 export builds a single named input/output serving signature.
        if len(model.inputs) != 1 or len(model.outputs) != 1:
            raise NotImplementedError(
                "Weight conversion for models with multiple inputs or outputs is not yet implemented."
            )
        return _convert_tf1(
            weight_path,
            output_path,
            model.inputs[0].id,
            model.outputs[0].id,
            zip_weights,
        )
    else:
        return _convert_tf2(weight_path, output_path, zip_weights)