Coverage for bioimageio/core/model_adapters/_tensorflow_model_adapter.py: 29%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import zipfile 

2from typing import List, Literal, Optional, Sequence, Union 

3 

4import numpy as np 

5from loguru import logger 

6 

7from bioimageio.spec.common import FileSource 

8from bioimageio.spec.model import v0_4, v0_5 

9from bioimageio.spec.utils import download 

10 

11from ..digest_spec import get_axes_infos 

12from ..tensor import Tensor 

13from ._model_adapter import ModelAdapter 

14 

15try: 

16 import tensorflow as tf # pyright: ignore[reportMissingImports] 

17except Exception as e: 

18 tf = None 

19 tf_error = str(e) 

20else: 

21 tf_error = None 

22 

23 

class TensorflowModelAdapterBase(ModelAdapter):
    """Common model adapter for tensorflow weights (keras HDF5 and SavedModel bundles).

    Subclasses set `weight_format` and pass the matching weights description.
    If tensorflow >= 2 is installed, or the weights are keras HDF5, the network
    is loaded once in `__init__` and called through the keras API
    (`_forward_keras`). Otherwise (tensorflow 1 SavedModel bundle) only the
    local model path is stored. The model is then loaded into a fresh
    graph/session on every forward pass (`_forward_tf`).
    """

    weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"]

    def __init__(
        self,
        *,
        devices: Optional[Sequence[str]] = None,
        weights: Union[
            v0_4.KerasHdf5WeightsDescr,
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.KerasHdf5WeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ],
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    ):
        """Check tensorflow version compatibility and load the network.

        Args:
            devices: Requested devices. Device management is not implemented
                for tensorflow; a non-None value only triggers a warning.
            weights: Weights description providing `source` and
                `tensorflow_version`.
            model_description: Model description; its inputs/outputs provide
                the tensor keys and the output axes.

        Raises:
            ImportError: If tensorflow could not be imported.
        """
        if tf is None:
            raise ImportError(f"failed to import tensorflow: {tf_error}")

        super().__init__()
        self.model_description = model_description
        tf_version = v0_5.Version(
            tf.__version__  # pyright: ignore[reportUnknownArgumentType]
        )
        model_tf_version = weights.tensorflow_version
        if model_tf_version is None:
            logger.warning(
                "The model does not specify the tensorflow version. "
                + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
            )
        elif model_tf_version > tf_version:
            logger.warning(
                f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "The tensorflow version specified by the model does not match the installed: "
                + f"{model_tf_version} != {tf_version}."
            )

        # Only tf1 SavedModel bundles use the graph/session path (`_forward_tf`);
        # everything else is called through the keras API (`_forward_keras`).
        self.use_keras_api = (
            tf_version.major > 1
            or self.weight_format == KerasModelAdapter.weight_format
        )

        # TODO tf device management
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # `_get_network` downloads/unzips the weights itself. Unzipping here as
        # well would call `download` a second time on the already local
        # (possibly extracted directory) path.
        self._network = self._get_network(weights.source)
        self._internal_output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _input_keys(self):  # pyright: ignore[reportUnknownParameterType]
        """Input tensor keys in description order: `name` (v0.4) or `id` (v0.5)."""
        return [
            ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id
            for ipt in self.model_description.inputs
        ]

    def _output_keys(self):  # pyright: ignore[reportUnknownParameterType]
        """Output tensor keys in description order: `name` (v0.4) or `id` (v0.5)."""
        return [
            out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id
            for out in self.model_description.outputs
        ]

    def require_unzipped(self, weight_file: FileSource):
        """Download `weight_file` and extract it if it is a zip archive.

        Returns:
            The local path of the downloaded file. For a zip archive, returns
            the extraction directory instead: the local path with its suffix
            replaced by `.unzipped`.
        """
        local_weights_file = download(weight_file).path
        if zipfile.is_zipfile(local_weights_file):
            out_path = local_weights_file.with_suffix(".unzipped")
            with zipfile.ZipFile(local_weights_file, "r") as f:
                f.extractall(out_path)

            return out_path
        else:
            return local_weights_file

    def _get_network(  # pyright: ignore[reportUnknownParameterType]
        self, weight_file: FileSource
    ):
        """Load the network for the keras API, or return the local model path (tf1).

        - Keras HDF5 weights are loaded with `tf.keras.models.load_model`.
        - SavedModel bundles are wrapped in a `tf.keras.layers.TFSMLayer`. The
          "serve" endpoint is tried first, then "serving_default".
        - If the installed keras has no `TFSMLayer`, SavedModel bundles fall
          back to `tf.keras.models.load_model` as well.
        """
        weight_file = self.require_unzipped(weight_file)
        assert tf is not None
        if not self.use_keras_api:
            # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
            return str(weight_file)

        if self.weight_format == KerasModelAdapter.weight_format or not hasattr(
            tf.keras.layers, "TFSMLayer"
        ):
            # `TFSMLayer` can only reload SavedModel directories, not HDF5 files.
            # It is also absent in older (keras 2 based) tensorflow versions.
            return tf.keras.models.load_model(  # pyright: ignore[reportUnknownVariableType]
                str(weight_file), compile=False
            )

        try:
            return tf.keras.layers.TFSMLayer(
                weight_file, call_endpoint="serve"
            )  # pyright: ignore[reportUnknownVariableType]
        except Exception as e:
            try:
                return tf.keras.layers.TFSMLayer(
                    weight_file, call_endpoint="serving_default"
                )  # pyright: ignore[reportUnknownVariableType]
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                raise e

    # TODO currently we relaod the model every time. it would be better to keep the graph and session
    # alive in between of forward passes (but then the sessions need to be properly opened / closed)
    def _forward_tf(  # pyright: ignore[reportUnknownParameterType]
        self, *input_tensors: Optional[Tensor]
    ):
        """Run a tf1 SavedModel in a fresh graph/session.

        `self._network` holds the SavedModel path here. Inputs are fed in the
        order of the model description's inputs. Outputs are returned as a list
        in the order of its outputs.
        """
        assert tf is not None
        input_keys = self._input_keys()
        output_keys = self._output_keys()
        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )

        graph = tf.Graph()  # pyright: ignore[reportUnknownVariableType]
        with graph.as_default():
            with tf.Session(
                graph=graph
            ) as sess:  # pyright: ignore[reportUnknownVariableType]
                # load the model and the signature
                graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType]
                    sess, [tag], self._network
                )
                signature = (  # pyright: ignore[reportUnknownVariableType]
                    graph_def.signature_def
                )

                # get the tensors into the graph
                in_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].inputs[key].name for key in input_keys
                ]
                out_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].outputs[key].name for key in output_keys
                ]
                in_tensors = [  # pyright: ignore[reportUnknownVariableType]
                    graph.get_tensor_by_name(name)
                    for name in in_names  # pyright: ignore[reportUnknownVariableType]
                ]
                out_tensors = [  # pyright: ignore[reportUnknownVariableType]
                    graph.get_tensor_by_name(name)
                    for name in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

                # run prediction
                res = sess.run(  # pyright: ignore[reportUnknownVariableType]
                    dict(
                        zip(
                            out_names,  # pyright: ignore[reportUnknownArgumentType]
                            out_tensors,  # pyright: ignore[reportUnknownArgumentType]
                        )
                    ),
                    dict(
                        zip(
                            in_tensors,  # pyright: ignore[reportUnknownArgumentType]
                            input_tensors,
                        )
                    ),
                )
                # from dict to list of tensors
                res = [  # pyright: ignore[reportUnknownVariableType]
                    res[out]
                    for out in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

        return res  # pyright: ignore[reportUnknownVariableType]

    def _forward_keras(  # pyright: ignore[reportUnknownParameterType]
        self, *input_tensors: Optional[Tensor]
    ):
        """Call the keras network and return its outputs as numpy arrays (or None).

        The network may be a `TFSMLayer` or a keras model. It may return:
        - a dict of outputs,
        - a list/tuple of outputs, or
        - a single output.
        The result is normalised to a list in output order.
        """
        assert self.use_keras_api
        assert not isinstance(self._network, str)
        assert tf is not None
        tf_tensors = [  # pyright: ignore[reportUnknownVariableType]
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors
        ]

        if len(tf_tensors) == 1:
            result = self._network(  # pyright: ignore[reportUnknownVariableType]
                tf_tensors[0]
            )
        else:
            # Keras layers/models take all inputs as one `inputs` argument.
            # Extra positional arguments would be read as `training`/`mask`.
            result = self._network(  # pyright: ignore[reportUnknownVariableType]
                tf_tensors
            )

        if isinstance(result, dict):
            # A TFSMLayer returns a dict keyed by the signature's output names.
            # Follow the model description's output order when all its keys are
            # present; otherwise keep the dict's own order.
            output_keys = self._output_keys()
            if output_keys and all(key in result for key in output_keys):
                outputs = [result[key] for key in output_keys]
            else:
                outputs = list(result.values())
        elif isinstance(result, (list, tuple)):
            # keras model with several outputs
            outputs = list(result)
        else:
            # keras model with a single output
            outputs = [result]

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in outputs  # pyright: ignore[reportUnknownVariableType]
        ]

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the model on `input_tensors` (in description order).

        Returns:
            One `Tensor` per output, using the output axes from the model
            description. Entries are `None` where the network returned None.
        """
        data = [None if ipt is None else ipt.data for ipt in input_tensors]
        if self.use_keras_api:
            result = self._forward_keras(  # pyright: ignore[reportUnknownVariableType]
                *data
            )
        else:
            result = self._forward_tf(  # pyright: ignore[reportUnknownVariableType]
                *data
            )

        return [
            None if r is None else Tensor(r, dims=axes)
            for r, axes in zip(  # pyright: ignore[reportUnknownVariableType]
                result,  # pyright: ignore[reportUnknownArgumentType]
                self._internal_output_axes,
            )
        ]

    def unload(self) -> None:
        """Not supported: only logs a warning; the network stays loaded."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )

238 

239 

class TensorflowModelAdapter(TensorflowModelAdapterBase):
    """Model adapter for `tensorflow_saved_model_bundle` weights."""

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter from the model's SavedModel bundle weights.

        Raises:
            ValueError: If the model description has no
                `tensorflow_saved_model_bundle` weights entry.
        """
        bundle = model_description.weights.tensorflow_saved_model_bundle
        if bundle is None:
            raise ValueError("missing tensorflow_saved_model_bundle weights")

        super().__init__(
            model_description=model_description,
            weights=bundle,
            devices=devices,
        )

257 

258 

class KerasModelAdapter(TensorflowModelAdapterBase):
    """Model adapter for `keras_hdf5` weights."""

    weight_format = "keras_hdf5"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter from the model's keras HDF5 weights.

        Raises:
            ValueError: If the model description has no `keras_hdf5` weights
                entry.
        """
        hdf5_weights = model_description.weights.keras_hdf5
        if hdf5_weights is None:
            raise ValueError("missing keras_hdf5 weights")

        super().__init__(
            weights=hdf5_weights,
            devices=devices,
            model_description=model_description,
        )