Coverage for src / bioimageio / core / backends / tensorflow_backend.py: 58%

77 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 22:06 +0000

1from pathlib import Path 

2from typing import Any, Optional, Sequence, Union 

3 

4import numpy as np 

5import tensorflow as tf 

6from loguru import logger 

7from numpy.typing import NDArray 

8 

9from bioimageio.core.io import ensure_unzipped 

10from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

11 

12from ._model_adapter import ModelAdapter 

13 

14 

class TensorflowModelAdapter(ModelAdapter):
    """Model adapter that runs a SavedModel bundle through the TensorFlow 1 session API.

    Only the path to the unzipped SavedModel is stored on the instance; the
    graph and session are rebuilt on every forward pass (see the TODO above
    `_forward_impl`).
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Locate and unzip the SavedModel bundle described by `model_description`.

        Args:
            model_description: Model description whose
                `weights.tensorflow_saved_model_bundle` entry is used.
            devices: Ignored; a warning is logged if it is not None.

        Raises:
            ValueError: If the description has no
                `tensorflow_saved_model_bundle` weights.
        """
        super().__init__(model_description=model_description)

        # Bind the weights entry once so the None check narrows the value
        # that is actually used below.
        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )
        # The loader in `_forward_impl` is given this string path.
        self._network = str(weight_file)

    # TODO currently we reload the model every time. it would be better to keep the graph and session
    # alive in between of forward passes (but then the sessions need to be properly opened / closed)
    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Load the SavedModel into a fresh graph/session and run one prediction.

        Args:
            input_arrays: Input arrays, paired positionally with
                `self._input_ids` (presumably provided by `ModelAdapter`;
                not set in this class).

        Returns:
            A list of outputs in the order of `self._output_ids`.
        """
        # TODO read from spec
        tag = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.tag_constants.SERVING
        )
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )

        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(graph=graph) as sess:  # pyright: ignore[reportUnknownVariableType]
                # load the model and the signature
                graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType]
                    sess, [tag], self._network
                )
                signature = (  # pyright: ignore[reportUnknownVariableType]
                    graph_def.signature_def
                )

                # map the tensor ids to the tensor names used inside the graph
                in_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].inputs[key].name for key in self._input_ids
                ]
                out_names = [  # pyright: ignore[reportUnknownVariableType]
                    signature[signature_key].outputs[key].name
                    for key in self._output_ids
                ]
                # get the tensors into the graph
                in_tf_tensors = [
                    graph.get_tensor_by_name(
                        name  # pyright: ignore[reportUnknownArgumentType]
                    )
                    for name in in_names  # pyright: ignore[reportUnknownVariableType]
                ]
                out_tf_tensors = [
                    graph.get_tensor_by_name(
                        name  # pyright: ignore[reportUnknownArgumentType]
                    )
                    for name in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

                # run prediction: fetches are keyed by output name,
                # feeds are keyed by input tensor
                res = sess.run(  # pyright: ignore[reportUnknownVariableType]
                    dict(
                        zip(
                            out_names,  # pyright: ignore[reportUnknownArgumentType]
                            out_tf_tensors,
                        )
                    ),
                    dict(zip(in_tf_tensors, input_arrays)),
                )
                # from dict to list of tensors
                res = [  # pyright: ignore[reportUnknownVariableType]
                    res[out]
                    for out in out_names  # pyright: ignore[reportUnknownVariableType]
                ]

        return res  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Log a warning; this adapter does not release any resources."""
        logger.warning(
            "Device management is not implemented for tensorflow 1, cannot unload model"
        )

109 

110 

class KerasModelAdapter(ModelAdapter):
    """Model adapter that wraps a SavedModel bundle in a `tf.keras.layers.TFSMLayer`."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Unzip the SavedModel bundle and load it as a `TFSMLayer`.

        The layer is first created with `call_endpoint="serve"`; if that
        fails, `call_endpoint="serving_default"` is tried. If both fail, the
        second error is logged and the first one is raised.

        Args:
            model_description: Model description whose
                `weights.tensorflow_saved_model_bundle` entry is used.
            devices: Ignored; a warning is logged if it is not None.

        Raises:
            ValueError: If the description has no
                `tensorflow_saved_model_bundle` weights.
        """
        weights = model_description.weights.tensorflow_saved_model_bundle
        if weights is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        super().__init__(model_description=model_description)
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )

        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            weights.source,
            Path("bioimageio_unzipped_tf_weights"),
        )

        try:
            self._network = tf.keras.layers.TFSMLayer(
                weight_file,
                call_endpoint="serve",
            )
        except Exception as e:
            try:
                self._network = tf.keras.layers.TFSMLayer(
                    weight_file, call_endpoint="serving_default"
                )
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                # Surface the error from the primary endpoint; the fallback
                # error has been logged above.
                raise e

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Call the loaded layer on `input_arrays` and return its outputs.

        Args:
            input_arrays: Input arrays; each non-None entry is converted with
                `tf.convert_to_tensor`, None entries are passed through.

        Returns:
            A list with one entry per layer output: None stays None, numpy
            arrays are returned as is, anything else is converted via
            `.numpy()`.
        """
        tf_tensor = [
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
        ]

        result = self._network(*tf_tensor)  # pyright: ignore[reportUnknownVariableType]

        # Normalize the layer output to a flat list instead of asserting a
        # dict: asserts are stripped under `-O`, and an endpoint is not
        # guaranteed to return a mapping (NOTE(review): the "serve" endpoint
        # may return a bare tensor or a sequence -- confirm against Keras docs).
        # TODO: Use RDF's `outputs[i].id` here
        if isinstance(result, dict):
            outputs = list(  # pyright: ignore[reportUnknownVariableType]
                result.values()  # pyright: ignore[reportUnknownArgumentType]
            )
        elif isinstance(result, (list, tuple)):
            outputs = list(result)  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
        else:
            outputs = [result]

        return [  # pyright: ignore[reportUnknownVariableType]
            (None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
            for r in outputs  # pyright: ignore[reportUnknownVariableType]
        ]

    def unload(self) -> None:
        """Log a warning; this adapter does not release any resources."""
        logger.warning(
            "Device management is not implemented for tensorflow>=2 models"
            + f" using `{self.__class__.__name__}`, cannot unload model"
        )

176 

177 

def create_tf_model_adapter(
    model_description: AnyModelDescr, devices: Optional[Sequence[str]]
):
    """Create the model adapter matching the installed tensorflow major version.

    The tensorflow version recorded with the weights (if any) is compared to
    the installed version; mismatches only produce log warnings.

    Args:
        model_description: Model description providing
            `weights.tensorflow_saved_model_bundle`.
        devices: Forwarded unchanged to the selected adapter.

    Returns:
        A `TensorflowModelAdapter` if the installed tensorflow major version
        is 1 or lower, otherwise a `KerasModelAdapter`.

    Raises:
        ValueError: If the description has no
            `tensorflow_saved_model_bundle` weights.
    """
    tf_version = v0_5.Version(tf.__version__)  # type: ignore[reportUnknownVariableType]
    weights = model_description.weights.tensorflow_saved_model_bundle
    if weights is None:
        raise ValueError("No `tensorflow_saved_model_bundle` weights found")

    model_tf_version = weights.tensorflow_version
    if model_tf_version is None:
        # The two literals are concatenated, so the second one must start
        # with a space to keep the sentences apart.
        logger.warning(
            "The model does not specify the tensorflow version."
            + f" Cannot check if it is compatible with installed tensorflow {tf_version}."
        )
    elif model_tf_version > tf_version:
        logger.warning(
            f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
        )
    elif (model_tf_version.major, model_tf_version.minor) != (
        tf_version.major,
        tf_version.minor,
    ):
        logger.warning(
            "The tensorflow version specified by the model does not match the installed: "
            + f"{model_tf_version} != {tf_version}."
        )

    if tf_version.major <= 1:
        return TensorflowModelAdapter(
            model_description=model_description, devices=devices
        )
    else:
        return KerasModelAdapter(model_description=model_description, devices=devices)