Coverage for src/bioimageio/core/backends/tensorflow_backend.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from pathlib import Path 

2from typing import Any, List, Optional, Sequence, Tuple, Union 

3 

4import numpy as np 

5import tensorflow as tf 

6from loguru import logger 

7from numpy.typing import NDArray 

8 

9from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

10 

11from .._model_adapter import LocalModelAdapter 

12from ..io import ensure_unzipped 

13 

14 

class TensorflowModelAdapter(LocalModelAdapter[None, Any]):
    """Adapter for TensorFlow 1 models.

    Loads a `tensorflow_saved_model_bundle` into a dedicated `tf.Graph` and runs
    inference through a `tf.Session`, feeding and fetching the tensors named by
    the saved model's default serving signature.
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter.

        Args:
            model_description: model description with
                `tensorflow_saved_model_bundle` weights.
            devices: requested devices; ignored (see `_parse_devices`).

        Raises:
            ValueError: if no `tensorflow_saved_model_bundle` weights are present.
        """
        if model_description.weights.tensorflow_saved_model_bundle is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        # v0_4 weight entries are unpacked via their `.source`; for v0_5 the
        # weight entry itself is handed to `ensure_unzipped`.
        if isinstance(model_description, v0_4.ModelDescr):
            self._weight_src = (
                model_description.weights.tensorflow_saved_model_bundle.source
            )
        else:
            self._weight_src = model_description.weights.tensorflow_saved_model_bundle

        # Populated by `_init_model_on_device`. Initialized *before*
        # `super().__init__` so that a model load triggered during base-class
        # construction (if any — TODO confirm) is not overwritten afterwards.
        self._graph = None
        # (input tensor names, output tensor names) of the serving signature
        self._io_names: Optional[Tuple[List[str], List[str]]] = None
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
        """Ignore `devices` (with a warning) and use the single device `None`."""
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )
        return (None,)

    @staticmethod
    def _lookup_tensor_names(
        tensor_infos: Any, keys: Sequence[Any], kind: str
    ) -> List[str]:
        """Return the graph tensor names for `keys` in a signature's input/output map.

        The membership check guards against protobuf message maps silently
        creating empty entries on lookup of a missing key.
        """
        missing = [key for key in keys if key not in tensor_infos]
        if missing:
            raise ValueError(
                f"Serving signature has no {kind}(s) named {missing}; "
                + f"available: {list(tensor_infos)}"
            )
        return [tensor_infos[key].name for key in keys]

    def _init_model_on_device(self, device: Optional[str]) -> Any:
        """Load the saved model bundle into a fresh graph and return its session.

        Also records the graph and the tensor names of the serving signature's
        inputs and outputs (ordered like `self._input_ids` / `self._output_ids`)
        for use by `_forward_impl`.
        """
        # TODO: check how to load tf weights without unzipping
        weight_file = ensure_unzipped(
            self._weight_src, Path("bioimageio_unzipped_tf_weights")
        )

        # TODO read from spec
        tag = tf.saved_model.tag_constants.SERVING  # pyright: ignore[reportUnknownVariableType]
        signature_key = (  # pyright: ignore[reportUnknownVariableType]
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )

        graph = tf.Graph()
        with graph.as_default():
            sess = tf.Session(graph=graph)  # pyright: ignore[reportUnknownVariableType]
            try:
                # load the model and the signature
                meta_graph_def = tf.saved_model.loader.load(  # pyright: ignore[reportUnknownVariableType]
                    sess, [tag], str(weight_file)
                )
                signatures = meta_graph_def.signature_def  # pyright: ignore[reportUnknownVariableType]
                if signature_key not in signatures:
                    raise ValueError(
                        f"Signature '{signature_key}' not found in {weight_file}"
                    )
                signature = signatures[signature_key]  # pyright: ignore[reportUnknownVariableType]

                # resolve the graph tensor names of the declared inputs/outputs
                in_names = self._lookup_tensor_names(
                    signature.inputs, self._input_ids, "input"
                )
                out_names = self._lookup_tensor_names(
                    signature.outputs, self._output_ids, "output"
                )
            except Exception:
                # do not leak the session when loading fails
                sess.close()
                raise

        self._graph = graph
        self._io_names = (in_names, out_names)
        return sess  # pyright: ignore[reportUnknownVariableType]

    def _forward_impl(
        self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run `model` (the session) on `input_arrays`; return one result per output.

        `None` entries in `input_arrays` are not fed.
        """
        assert self._io_names is not None
        assert self._graph is not None

        in_names, out_names = self._io_names
        in_tf_tensors = [self._graph.get_tensor_by_name(name) for name in in_names]
        out_tf_tensors = [self._graph.get_tensor_by_name(name) for name in out_names]

        # `None` is not a feedable value; leave absent (optional) inputs unfed
        feed_dict = {
            tensor: array
            for tensor, array in zip(in_tf_tensors, input_arrays)
            if array is not None
        }
        # fetching a list returns the results as a list in the same order
        return list(model.run(out_tf_tensors, feed_dict=feed_dict))

    def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None:
        """Close the session created by `_init_model_on_device` to free its resources."""
        if model is not None:
            model.close()

    def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
        """Drop the graph and tensor names that belonged to the deleted session."""
        self._graph = None
        self._io_names = None

109 

110 

class KerasModelAdapter(LocalModelAdapter[None, Any]):
    """Adapter for TensorFlow saved model bundles loaded via `tf.keras.layers.TFSMLayer`.

    The layer is created for the `"serve"` endpoint, falling back to
    `"serving_default"`.
    """

    weight_format = "tensorflow_saved_model_bundle"

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter.

        Args:
            model_description: model description with
                `tensorflow_saved_model_bundle` weights.
            devices: requested devices; ignored (see `_parse_devices`).

        Raises:
            ValueError: if no `tensorflow_saved_model_bundle` weights are present.
        """
        if model_description.weights.tensorflow_saved_model_bundle is None:
            raise ValueError("No `tensorflow_saved_model_bundle` weights found")

        # v0_4 weight entries are unpacked via their `.source`; for v0_5 the
        # weight entry itself is handed to `ensure_unzipped`.
        if isinstance(model_description, v0_4.ModelDescr):
            self._weight_src = (
                model_description.weights.tensorflow_saved_model_bundle.source
            )
        else:
            self._weight_src = model_description.weights.tensorflow_saved_model_bundle

        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
        """Ignore `devices` (with a warning) and use the single device `None`."""
        if devices is not None:
            logger.warning(
                f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
            )
        return (None,)

    def _init_model_on_device(self, device: None) -> Any:
        """Unpack the weights and wrap them in a `TFSMLayer`.

        Tries the `"serve"` endpoint first and `"serving_default"` second; if both
        fail, the second error is logged and the first one is re-raised.
        """
        # TODO: check how to load tf weights without unzipping
        weight_file = str(
            ensure_unzipped(self._weight_src, Path("bioimageio_unzipped_tf_weights"))
        )

        try:
            tfsm_layer = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportUnknownVariableType]
                weight_file,
                call_endpoint="serve",
            )
        except Exception as e:
            try:
                tfsm_layer = tf.keras.layers.TFSMLayer(  # pyright: ignore[reportUnknownVariableType]
                    weight_file, call_endpoint="serving_default"
                )
            except Exception as ee:
                logger.opt(exception=ee).info(
                    "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
                )
                raise e

        return tfsm_layer  # pyright: ignore[reportUnknownVariableType]

    def _order_outputs(self, result: Any) -> List[Any]:
        """Normalize the layer's return value to a list of outputs.

        A dict is ordered by the model description's output ids when every id is
        a key, otherwise by insertion order; a list/tuple is kept as is; anything
        else is treated as a single output.
        """
        if isinstance(result, dict):
            if self._output_ids and all(oid in result for oid in self._output_ids):
                return [result[oid] for oid in self._output_ids]
            return list(result.values())  # pyright: ignore[reportUnknownArgumentType]

        if isinstance(result, (list, tuple)):
            return list(result)  # pyright: ignore[reportUnknownArgumentType]

        return [result]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, device: None, model: Any, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Call the layer with `input_arrays` as positional arguments.

        Returns one numpy array (or `None`) per output.
        """
        tf_inputs = [
            None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_arrays
        ]
        outputs = self._order_outputs(model(*tf_inputs))

        return [  # pyright: ignore[reportUnknownVariableType]
            None if r is None else r if isinstance(r, np.ndarray) else r.numpy()
            for r in outputs
        ]

    def _cleanup_pre_model_deletion(self, device: Optional[str], model: Any) -> None:
        """Nothing to release before deletion."""
        return

    def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
        """Nothing to release after deletion."""
        return

185 

186 

def create_tf_model_adapter(
    model_description: AnyModelDescr, devices: Optional[Sequence[str]] = None
) -> Union[TensorflowModelAdapter, KerasModelAdapter]:
    """Create the TensorFlow model adapter matching the installed TensorFlow.

    Logs a warning (without failing) in three cases: the model's
    `tensorflow_saved_model_bundle` weights declare no TensorFlow version, they
    declare a newer one than installed, or the major/minor versions differ.

    Args:
        model_description: model description with `tensorflow_saved_model_bundle`
            weights.
        devices: requested devices, forwarded to the adapter.

    Returns:
        A `TensorflowModelAdapter` if the installed TensorFlow major version is
        <= 1, otherwise a `KerasModelAdapter`.

    Raises:
        ValueError: if the model has no `tensorflow_saved_model_bundle` weights.
    """
    tf_version = v0_5.Version(tf.__version__)  # type: ignore[reportUnknownVariableType]
    weights = model_description.weights.tensorflow_saved_model_bundle
    if weights is None:
        raise ValueError("No `tensorflow_saved_model_bundle` weights found")

    model_tf_version = weights.tensorflow_version
    if model_tf_version is None:
        logger.warning(
            "The model does not specify the tensorflow version. "
            + f"Cannot check if it is compatible with installed tensorflow {tf_version}."
        )
    elif model_tf_version > tf_version:
        logger.warning(
            f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
        )
    elif (model_tf_version.major, model_tf_version.minor) != (
        tf_version.major,
        tf_version.minor,
    ):
        logger.warning(
            "The tensorflow version specified by the model does not match the installed: "
            + f"{model_tf_version} != {tf_version}."
        )

    # TF1 uses the graph/session API; newer versions go through keras' TFSMLayer
    if tf_version.major <= 1:
        return TensorflowModelAdapter(model_description, devices=devices)
    else:
        return KerasModelAdapter(model_description, devices=devices)