Coverage for bioimageio/core/backends/tensorflow_backend.py: 58%

77 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from pathlib import Path 

2from typing import Any, Optional, Sequence, Union 

3 

4import numpy as np 

5import tensorflow as tf # pyright: ignore[reportMissingTypeStubs] 

6from loguru import logger 

7from numpy.typing import NDArray 

8 

9from bioimageio.core.io import ensure_unzipped 

10from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

11 

12from ._model_adapter import ModelAdapter 

13 

14 

class TensorflowModelAdapter(ModelAdapter):
    """Model adapter for `tensorflow_saved_model_bundle` weights using the TensorFlow 1 API.

    The weights source is passed through `ensure_unzipped` once at construction.
    The saved model is then (re)loaded into a fresh `tf.Graph`/`tf.Session` on
    every forward pass.
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Prepare the saved model bundle for loading.

        Args:
            model_description: model description providing
                `weights.tensorflow_saved_model_bundle`.
            devices: ignored (a warning is logged); device management is not
                implemented for this backend.

        Raises:
            ValueError: if the description has no `tensorflow_saved_model_bundle` weights.
        """
        super().__init__(model_description=model_description)

        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )
        # Only the path is stored; the model itself is loaded in `_forward_impl`.
        self._network = str(weight_file)

    # TODO currently we reload the model every time. it would be better to keep the graph and session
    # alive in between of forward passes (but then the sessions need to be properly opened / closed)
    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run one inference pass through the default serving signature.

        Args:
            input_arrays: one entry per model input, in the order of
                `self._input_ids`. `None` entries are not fed to the graph.

        Returns:
            A list with one result per model output, in the order of
            `self._output_ids`.
        """
        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING  # pyright: ignore[reportAttributeAccessIssue]
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # pyright: ignore[reportAttributeAccessIssue]
        )

        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(  # pyright: ignore[reportAttributeAccessIssue]
                graph=graph
            ) as sess:  # pyright: ignore[reportUnknownVariableType]
                # load the model (returns a MetaGraphDef) and select the serving signature
                meta_graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
                    sess, [tag], self._network
                )
                signature = meta_graph_def.signature_def[  # pyright: ignore[reportUnknownVariableType]
                    signature_key
                ]

                # map signature input/output keys to tensor names in the graph
                in_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature.inputs[key].name for key in self._input_ids
                ]
                out_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature.outputs[key].name for key in self._output_ids
                ]
                in_tf_tensors = [
                    graph.get_tensor_by_name(name)  # pyright: ignore[reportUnknownArgumentType]
                    for name in in_names  # pyright: ignore[reportUnknownVariableType]
                ]
                out_tf_tensors = [
                    graph.get_tensor_by_name(name)  # pyright: ignore[reportUnknownArgumentType]
                    for name in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

                # TF cannot feed `None`, so absent inputs are left out of the feed dict
                feed_dict = {
                    tensor: array
                    for tensor, array in zip(in_tf_tensors, input_arrays)
                    if array is not None
                }

                # run prediction; fetching a dict returns a dict keyed by tensor name
                res = sess.run(  # pyright: ignore[reportUnknownVariableType]
                    dict(
                        zip(
                            out_names,  # pyright: ignore[reportUnknownArgumentType]
                            out_tf_tensors,
                        )
                    ),
                    feed_dict,
                )

        # from dict to list of results, ordered like the model outputs
        return [  # pyright: ignore[reportUnknownVariableType]
            res[name] for name in out_names  # pyright: ignore[reportUnknownVariableType]
        ]

    def unload(self) -> None:
        """Log that unloading is not supported; nothing is released here."""
        logger.warning(
            "Device management is not implemented for tensorflow 1, cannot unload model"
        )

111 

112 

class KerasModelAdapter(ModelAdapter):
    """Model adapter for `tensorflow_saved_model_bundle` weights using TensorFlow >= 2.

    The saved model is wrapped in a `tf.keras.layers.TFSMLayer`. The `serve`
    call endpoint is tried first, with `serving_default` as the fallback.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the saved model bundle as a Keras `TFSMLayer`.

        Args:
            model_description: model description providing
                `weights.tensorflow_saved_model_bundle`.
            devices: ignored (a warning is logged); device management is not
                implemented for this backend.

        Raises:
            ValueError: if the description has no `tensorflow_saved_model_bundle` weights.
            Exception: the error from the `serve` endpoint attempt, if both
                endpoints fail to load.
        """
        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        super().__init__(model_description=model_description)
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )
        # pass a plain string path (as the TF1 adapter does) rather than a path object
        weight_path = str(weight_file)

        try:
            self._network = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportAttributeAccessIssue]
                weight_path,
                call_endpoint="serve",
            )
        except Exception as e:
            # fall back to the default TF serving signature name
            try:
                self._network = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportAttributeAccessIssue]
                    weight_path, call_endpoint="serving_default"
                )
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                # surface the original ('serve') error; `ee` is logged above
                raise e

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run one inference pass through the wrapped `TFSMLayer`.

        Args:
            input_arrays: model inputs; `None` entries are passed on as `None`.

        Returns:
            A list of numpy arrays (or `None`), one per value of the dict the
            layer returns, in that dict's iteration order.

        Raises:
            TypeError: if the layer does not return a dict.
        """
        tf_tensors = [
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
        ]

        # NOTE(review): inputs are passed positionally; confirm the loaded endpoint
        # accepts this form for models with more than one input.
        result = self._network(*tf_tensors)  # pyright: ignore[reportUnknownVariableType]

        if not isinstance(result, dict):
            raise TypeError(
                f"Expected the TFSMLayer call to return a dict of outputs, got {type(result)}"  # pyright: ignore[reportUnknownArgumentType]
            )

        # TODO: Use RDF's `outputs[i].id` here
        outputs = list(  # pyright: ignore[reportUnknownVariableType]
            result.values()  # pyright: ignore[reportUnknownArgumentType]
        )

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in outputs  # pyright: ignore[reportUnknownVariableType]
        ]

    def unload(self) -> None:
        """Log that unloading is not supported; nothing is released here."""
        logger.warning(
            "Device management is not implemented for tensorflow>=2 models"
            + f" using `{self.__class__.__name__}`, cannot unload model"
        )

178 

179 

def create_tf_model_adapter(
    model_description: AnyModelDescr, devices: Optional[Sequence[str]]
) -> "ModelAdapter":
    """Create a TensorFlow model adapter matching the installed TensorFlow version.

    Logs a warning (but does not fail) if the model does not declare a
    TensorFlow version, declares a newer one than installed, or declares one
    whose major/minor version differs from the installed one.

    Args:
        model_description: model description with
            `weights.tensorflow_saved_model_bundle`.
        devices: forwarded to the adapter (currently ignored there with a warning).

    Returns:
        A `TensorflowModelAdapter` if the installed TensorFlow major version
        is <= 1, otherwise a `KerasModelAdapter`.

    Raises:
        ValueError: if the description has no `tensorflow_saved_model_bundle` weights.
    """
    tf_version = v0_5.Version(tf.__version__)
    weights = model_description.weights.tensorflow_saved_model_bundle
    if weights is None:
        raise ValueError("No `tensorflow_saved_model_bundle` weights found")

    model_tf_version = weights.tensorflow_version
    if model_tf_version is None:
        logger.warning(
            "The model does not specify the tensorflow version. "
            + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
        )
    elif model_tf_version > tf_version:
        logger.warning(
            f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
        )
    elif (model_tf_version.major, model_tf_version.minor) != (
        tf_version.major,
        tf_version.minor,
    ):
        logger.warning(
            "The tensorflow version specified by the model does not match the installed: "
            + f"{model_tf_version} != {tf_version}."
        )

    # the adapter is chosen by the *installed* TF version, not the model's declared one
    if tf_version.major <= 1:
        return TensorflowModelAdapter(
            model_description=model_description, devices=devices
        )

    return KerasModelAdapter(model_description=model_description, devices=devices)