Coverage for bioimageio/core/weight_converter/keras/_tensorflow.py: 20%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1# type: ignore # TODO: type 

2import os 

3import shutil 

4from pathlib import Path 

5from typing import no_type_check 

6from zipfile import ZipFile 

7 

8try: 

9 import tensorflow.saved_model 

10except Exception: 

11 tensorflow = None 

12 

13from bioimageio.spec._internal.io_utils import download 

14from bioimageio.spec.model.v0_5 import ModelDescr 

15 

16 

def _zip_model_bundle(model_bundle_folder: Path):
    """Pack an exported model folder into a sibling zip archive and delete the folder.

    The archive is created next to the folder and named ``<folder name>.zip``.
    Every file found below the folder is stored under its path relative to the
    folder. Only files are written; directories without files produce no entry.

    Args:
        model_bundle_folder: directory containing the exported model bundle.

    Returns:
        Path of the zip archive that was written.
    """
    # Append ".zip" to the complete folder name. `with_suffix(".zip")` would
    # *replace* an existing suffix, e.g. turn "model.v1" into "model.zip"
    # instead of "model.v1.zip".
    zipped_model_bundle = model_bundle_folder.with_name(
        model_bundle_folder.name + ".zip"
    )

    with ZipFile(zipped_model_bundle, "w") as zip_obj:
        for root, _, files in os.walk(model_bundle_folder):
            for filename in files:
                src = os.path.join(root, filename)
                zip_obj.write(src, os.path.relpath(src, model_bundle_folder))

    try:
        shutil.rmtree(model_bundle_folder)
    except Exception:
        # Best effort: the archive is already complete at this point, so a
        # failed cleanup is only reported and does not abort the conversion.
        print("TensorFlow bundled model was not removed after compression")

    return zipped_model_bundle

32 

33 

34# adapted from 

35# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 

def _convert_tf1(
    keras_weight_path: Path,
    output_path: Path,
    input_name: str,
    output_name: str,
    zip_weights: bool,
):
    """Export a Keras model as a TensorFlow 1 saved model bundle.

    The model is loaded from ``keras_weight_path`` and written with a
    ``SavedModelBuilder`` using a single predict signature registered under the
    default serving key.

    Args:
        keras_weight_path: path of the Keras weights file to load.
        output_path: folder the saved model bundle is written to.
        input_name: key under which the model input appears in the signature.
        output_name: key under which the model output appears in the signature.
        zip_weights: if true, the exported folder is zipped afterwards (and the
            folder removed) via `_zip_model_bundle`.

    Returns:
        Always ``0``.
    """
    try:
        # try to build the tf model with the keras import from tensorflow
        # (the previous import targeted this very module, which defines no
        # `keras`, so it could never succeed and the fallback was always used)
        from tensorflow import keras  # type: ignore

    except Exception:
        # if the above fails try to export with the standalone keras
        import keras

    @no_type_check
    def build_tf_model():
        keras_model = keras.models.load_model(keras_weight_path)
        assert tensorflow is not None
        builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path)
        signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
            inputs={input_name: keras_model.input},
            outputs={output_name: keras_model.output},
        )

        signature_def_map = {
            tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
        }

        # the graph and variables are taken from the current keras backend session
        builder.add_meta_graph_and_variables(
            keras.backend.get_session(),
            [tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
        )
        builder.save()

    build_tf_model()

    if zip_weights:
        output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return 0

81 

82 

def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool):
    """Export a Keras model as a saved model bundle using the Keras save API.

    Args:
        keras_weight_path: path of the Keras weights file to load.
        output_path: location the model is saved to.
        zip_weights: if true, the exported folder is zipped afterwards (and the
            folder removed) via `_zip_model_bundle`.

    Returns:
        Always ``0``.
    """
    try:
        # try to build the tf model with the keras import from tensorflow
        # (the previous import targeted this very module, which defines no
        # `keras`, so it could never succeed and the fallback was always used)
        from tensorflow import keras  # type: ignore
    except Exception:
        # if the above fails try to export with the standalone keras
        import keras

    model = keras.models.load_model(keras_weight_path)
    keras.models.save_model(model, output_path)

    if zip_weights:
        output_path = _zip_model_bundle(output_path)
    print("TensorFlow model exported to", output_path)

    return 0

99 

100 

def convert_weights_to_tensorflow_saved_model_bundle(
    model: ModelDescr, output_path: Path
):
    """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'.

    Adapted from
    https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py

    Args:
        model: The bioimageio model description
        output_path: where to save the tensorflow weights. This path must not exist yet.
            A path ending in ``.zip`` requests a zipped bundle: the model is first
            exported to the same path without the suffix and that folder is zipped.

    Returns:
        The return value of the version specific converter (``0``).

    Raises:
        ValueError: if the output location already exists or the model has no
            'keras_hdf5' weights.
        RuntimeError: if the weights declare a tensorflow major version that
            differs from the installed one.
        NotImplementedError: under tensorflow 1, if the model does not have
            exactly one input and one output.
    """
    assert tensorflow is not None
    tf_major_ver = int(tensorflow.__version__.split(".")[0])

    requested_path = output_path
    if output_path.suffix == ".zip":
        output_path = output_path.with_suffix("")
        zip_weights = True
    else:
        zip_weights = False

    # For a zipped export the requested archive itself must not exist either;
    # checking only the intermediate folder would silently overwrite it.
    if zip_weights and requested_path.exists():
        raise ValueError(f"The output file at {requested_path} must not exist.")

    if output_path.exists():
        raise ValueError(f"The output directory at {output_path} must not exist.")

    if model.weights.keras_hdf5 is None:
        raise ValueError("Missing Keras Hdf5 weights to convert from.")

    weight_spec = model.weights.keras_hdf5
    weight_path = download(weight_spec.source).path

    # Only convert with a tensorflow whose major version matches the one
    # declared for the weights (if any is declared).
    if weight_spec.tensorflow_version:
        model_tf_major_ver = int(weight_spec.tensorflow_version.major)
        if model_tf_major_ver != tf_major_ver:
            raise RuntimeError(
                f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
            )

    if tf_major_ver == 1:
        # the tf1 exporter builds a single predict signature with one input
        # and one output tensor
        if len(model.inputs) != 1 or len(model.outputs) != 1:
            raise NotImplementedError(
                "Weight conversion for models with multiple inputs or outputs is not yet implemented."
            )
        return _convert_tf1(
            weight_path,
            output_path,
            model.inputs[0].id,
            model.outputs[0].id,
            zip_weights,
        )
    else:
        return _convert_tf2(weight_path, output_path, zip_weights)