Coverage for bioimageio/core/weight_converters/keras_to_tensorflow.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import os 

2import shutil 

3from pathlib import Path 

4from typing import Union, no_type_check 

5from zipfile import ZipFile 

6 

7import tensorflow # pyright: ignore[reportMissingTypeStubs] 

8 

9from bioimageio.spec._internal.io import download 

10from bioimageio.spec._internal.version_type import Version 

11from bioimageio.spec.common import ZipPath 

12from bioimageio.spec.model.v0_5 import ( 

13 InputTensorDescr, 

14 ModelDescr, 

15 OutputTensorDescr, 

16 TensorflowSavedModelBundleWeightsDescr, 

17) 

18 

19from .. import __version__ 

20from ..io import ensure_unzipped 

21 

22try: 

23 # try to build the tf model with the keras import from tensorflow 

24 from tensorflow import keras # type: ignore 

25except Exception: 

26 # if the above fails try to export with the standalone keras 

27 import keras # pyright: ignore[reportMissingTypeStubs] 

28 

29 

def convert(
    model_descr: ModelDescr, output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """
    Convert model weights from the 'keras_hdf5' format to the 'tensorflow_saved_model_bundle' format.

    This method handles the conversion of Keras HDF5 model weights into a TensorFlow SavedModel bundle,
    which is the recommended format for deploying TensorFlow models. The method supports both TensorFlow 1.x
    and 2.x versions, with appropriate checks to ensure compatibility.

    Adapted from:
    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py

    Args:
        model_descr:
            The bioimage.io model description containing the model's metadata and weights.
        output_path:
            Path with .zip suffix (.zip is appended otherwise) to which a zip archive
            with the TensorFlow SavedModel bundle will be saved.
            The bundle is first exported to a sibling folder named like the archive
            but with a '.saved_model' suffix; that folder is removed (best effort)
            after zipping.

    Raises:
        ValueError:
            - If the target zip archive or the intermediate bundle folder already exists.
            - If the Keras HDF5 weights are missing in the model description.
        RuntimeError:
            If the model declares a TensorFlow major version that differs from the installed one.
        NotImplementedError:
            If the model has multiple inputs or outputs and TensorFlow 1.x is being used.

    Returns:
        A descriptor object containing information about the converted TensorFlow SavedModel bundle.
    """
    tf_major_ver = int(tensorflow.__version__.split(".")[0])

    # Resolve the final archive path as documented: keep a '.zip' suffix,
    # otherwise append one (instead of replacing an existing suffix).
    if output_path.suffix == ".zip":
        zip_path = output_path
    else:
        zip_path = output_path.with_name(output_path.name + ".zip")

    # The SavedModel is exported to a folder which `_zip_model_bundle` archives
    # to `folder.with_suffix(".zip")`. Giving the folder its own suffix makes
    # that equal to `zip_path` (also for dotted names like "my.model.zip") and
    # prevents the folder from having the same path as the archive.
    bundle_folder = zip_path.with_suffix(".saved_model")

    if bundle_folder.exists():
        raise ValueError(f"The output directory at {bundle_folder} must not exist.")
    if zip_path.exists():
        raise ValueError(f"The output file at {zip_path} must not exist.")

    weight_spec = model_descr.weights.keras_hdf5
    if weight_spec is None:
        raise ValueError("Missing Keras Hdf5 weights to convert from.")

    # Reject incompatible models before spending time on the download.
    if weight_spec.tensorflow_version:
        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
        if model_tf_major_ver != tf_major_ver:
            raise RuntimeError(
                f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
            )

    weight_path = download(weight_spec.source).path

    if tf_major_ver != 1:
        return _convert_tf2(weight_path, bundle_folder)

    # The TF1 export builds a signature with exactly one input and one output.
    if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
        raise NotImplementedError(
            "Weight conversion for models with multiple inputs or outputs is not yet implemented."
        )

    first_input = model_descr.inputs[0]
    input_name = str(
        first_input.id
        if isinstance(first_input, InputTensorDescr)
        else first_input.name
    )
    first_output = model_descr.outputs[0]
    output_name = str(
        first_output.id
        if isinstance(first_output, OutputTensorDescr)
        else first_output.name
    )
    return _convert_tf1(
        # NOTE(review): unzips into a folder relative to the current working
        # directory; confirm this location is intended.
        ensure_unzipped(weight_path, Path("bioimageio_unzipped_tf_weights")),
        bundle_folder,
        input_name,
        output_name,
    )

106 

107 

def _convert_tf2(
    keras_weight_path: Union[Path, ZipPath], output_path: Path
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a Keras model via `Model.export` and zip the resulting bundle.

    Args:
        keras_weight_path: Location of the Keras HDF5 weights to load.
        output_path: Folder that `export` writes the SavedModel bundle to;
            it is subsequently archived by `_zip_model_bundle`.

    Returns:
        Weights descriptor referencing the zip archive, marked as derived
        from the "keras_hdf5" weights and tagged with the installed
        TensorFlow version.
    """
    keras_model = keras.models.load_model(keras_weight_path)  # type: ignore
    keras_model.export(output_path)  # type: ignore

    bundle_zip = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", bundle_zip)

    descr = TensorflowSavedModelBundleWeightsDescr(
        source=bundle_zip,
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )
    return descr

123 

124 

# adapted from
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
def _convert_tf1(
    keras_weight_path: Path,
    output_path: Path,
    input_name: str,
    output_name: str,
) -> TensorflowSavedModelBundleWeightsDescr:
    """Export a Keras HDF5 model as a TensorFlow 1.x SavedModel bundle and zip it.

    The model is loaded with `keras.models.load_model` and written with
    `tensorflow.saved_model.builder.SavedModelBuilder`, using the default
    serving signature (one input, one output) and the SERVING tag. The
    resulting folder is then zipped via `_zip_model_bundle`.

    Args:
        keras_weight_path: Path of the Keras HDF5 weights file.
        output_path: Folder the SavedModel bundle is written to.
        input_name: Signature key for the model's single input tensor.
        output_name: Signature key for the model's single output tensor.

    Returns:
        Weights descriptor pointing at the zipped bundle, with `parent` set to
        "keras_hdf5" and the installed TensorFlow version recorded.
    """

    @no_type_check
    def build_tf_model():
        # TF1-era Keras/TensorFlow path handling presumably only accepts
        # `str`/`bytes`, not `pathlib.Path`; converting is harmless where
        # `Path` is accepted as well.
        keras_model = keras.models.load_model(str(keras_weight_path))
        builder = tensorflow.saved_model.builder.SavedModelBuilder(str(output_path))
        # Predict signature mapping the given names to the model's tensors.
        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
            inputs={input_name: keras_model.input},
            outputs={output_name: keras_model.output},
        )

        signature_def_map = {
            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
                signature
            )
        }

        # The Keras backend session supplies the loaded variables (TF1 graph mode).
        builder.add_meta_graph_and_variables(
            keras.backend.get_session(),
            [tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
        )
        builder.save()

    build_tf_model()

    output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return TensorflowSavedModelBundleWeightsDescr(
        source=output_path,
        parent="keras_hdf5",
        tensorflow_version=Version(tensorflow.__version__),
        comment=f"Converted with bioimageio.core {__version__}.",
    )

166 ) 

167 

168 

169def _zip_model_bundle(model_bundle_folder: Path): 

170 zipped_model_bundle = model_bundle_folder.with_suffix(".zip") 

171 

172 with ZipFile(zipped_model_bundle, "w") as zip_obj: 

173 for root, _, files in os.walk(model_bundle_folder): 

174 for filename in files: 

175 src = os.path.join(root, filename) 

176 zip_obj.write(src, os.path.relpath(src, model_bundle_folder)) 

177 

178 try: 

179 shutil.rmtree(model_bundle_folder) 

180 except Exception: 

181 print("TensorFlow bundled model was not removed after compression") 

182 

183 return zipped_model_bundle