Coverage for src/bioimageio/core/remote_backends/gradio/client.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from types import MappingProxyType 

2from typing import Dict, Iterable, Literal, Mapping, Optional, Tuple, Union 

3 

4from gradio_client import Client 

5from loguru import logger 

6 

7from bioimageio.spec import AnyModelDescr, ValidationSummary 

8from bioimageio.spec.model import v0_4 

9 

10from ..._description_serializer import DescriptionSerializer as DescriptionSerializer 

11from ..._model_adapter import RemoteModelAdapter 

12from ..._prediction_pipeline import IntermediatePrediction, RemotePredictionPipeline 

13from ..._settings import settings 

14from ...axis import PerAxis 

15from ...common import BlocksizeParameter, PerMember 

16from ...io import JsonValue 

17from ...sample import Sample, SampleBlock 

18from ...stat_measures import Measure, MeasureValue 

19from .serializer import GradioSampleSerializer 

20 

21SerializedSampleBlock = Dict[str, JsonValue] 

22 

23 

class GradioModelAdapter(RemoteModelAdapter[SerializedSampleBlock]):
    """Model adapter to use the bioimage-io-gradio-runner as a backend for model inference."""

    def __init__(
        self, model_description: AnyModelDescr, *, server: Optional[str] = None
    ):
        """Initialize the GradioModelAdapter.

        Note:
            - This adapter requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server.

        Args:
            model_description: The model to run inference with.
            server: The URL of a running bioimage-io-gradio-server instance (default server might not be available/compatible).
                Falls back to `settings.gradio_server` if not given.

        Raises:
            ValueError: If no server is given and `settings.gradio_server` is unset or empty.
        """
        server = server or settings.gradio_server
        # `not server` also rejects an empty string (e.g. an empty environment
        # variable), which would otherwise only fail later inside `Client`.
        if not server:
            raise ValueError(
                "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable."
            )

        self._client = Client(server, httpx_kwargs={"timeout": 60})
        self._serialized_model, self._sha256 = (
            DescriptionSerializer.serialize_to_string_and_hash(model_description)
        )
        super().__init__(
            model_description, server=server, sample_serializer=GradioSampleSerializer()
        )

    def _forward_impl(
        self, serialized_input_sample: Iterable[SerializedSampleBlock]
    ) -> Iterable[SerializedSampleBlock]:
        """Run the bare model forward pass on the server.

        Pre-/postprocessing, input padding and output cropping are all skipped
        in the request, and no blocksize or batch size is passed.
        """
        return _call_predict_api(
            self._client,
            self._serialized_model,
            self._sha256,
            serialized_input_sample,
            blocksize=None,
            skip_preprocessing=True,
            skip_postprocessing=True,
            skip_input_padding=True,
            skip_output_cropping=True,
            batch_size=None,
        )

    def unload(self):
        """Delegate unloading to the base class."""
        return super().unload()

    def _submit_with_model_fallback(self, api_name: str, action: str):
        """Call a server endpoint, first by hash only, then with the full model.

        The first request sends an empty `model` string together with the
        sha256 (presumably sufficient if the server already knows the model);
        if that raises or yields a falsy result, the request is repeated with
        the serialized model attached.

        Args:
            api_name: Name of the gradio API endpoint, e.g. "/load_model".
            action: Verb describing the request; only used in log messages.

        Returns:
            The first truthy result returned by the server, or None if both
            attempts failed or returned a falsy result.
        """
        for model_data in ("", self._serialized_model):
            try:
                result = self._client.submit(
                    api_name=api_name, model=model_data, sha256=self._sha256
                ).result()
            except Exception as e:
                if model_data:
                    # One placeholder per argument: previously the single `{}`
                    # was filled with the payload length and the actual
                    # exception was silently dropped from the message.
                    logger.warning(
                        "Failed to {} model on server with model_data (length {}), error was: {}",
                        action,
                        len(model_data),
                        e,
                    )
                else:
                    # A failing hash-only attempt is expected whenever the
                    # model has to be uploaded; keep it out of the warnings.
                    logger.debug(
                        "Failed to {} model on server without model upload, error was: {}",
                        action,
                        e,
                    )
            else:
                if result:
                    return result

        return None

    def load(self) -> None:
        """Ask the server to load the model (best effort; failures are only logged)."""
        self._submit_with_model_fallback("/load_model", "load")

    def test(self) -> Optional[ValidationSummary]:
        """Ask the server to test the model.

        Returns:
            The validation summary parsed from the server's JSON response, or
            None if the server did not return a (truthy) result.
        """
        result = self._submit_with_model_fallback("/test_model", "test")
        if not result:
            return None

        return ValidationSummary.model_validate_json(result)

107 

108 

class GradioPredictionPipeline(RemotePredictionPipeline):
    """Prediction pipeline to use the bioimage-io-gradio-runner as a fully remote prediction pipeline."""

    def __init__(
        self,
        model_description: AnyModelDescr,
        *,
        server: Optional[str] = None,
        precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ):
        """
        Note:
            - This pipeline requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server.

        Args:
            model_description: The model to run inference with.
            server: The URL or Hugging Face space name of a running bioimageio gradio server instance (Note: default server might not be available/compatible!).
                Falls back to `settings.gradio_server` if not given.
            precomputed_statistics: Statistics that are merged into the `stat` of every sample (block) before it is sent to the server.
            default_blocksize_parameter: Forwarded to the base class.
            default_batch_size: Forwarded to the base class; also sent as `batch_size` with every predict request.

        Raises:
            ValueError: If no server is given and `settings.gradio_server` is unset or empty.
        """
        server = server or settings.gradio_server
        # `not server` also rejects an empty string (e.g. an empty environment
        # variable), which would otherwise only fail later inside `Client`.
        if not server:
            raise ValueError(
                "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable."
            )

        super().__init__(
            model_description,
            server=server,
            default_blocksize_parameter=default_blocksize_parameter,
            default_batch_size=default_batch_size,
        )
        self._client = Client(self.server, httpx_kwargs={"timeout": 60})
        self._serialized_model, self._sha256 = (
            DescriptionSerializer.serialize_to_string_and_hash(model_description)
        )
        # NOTE(review): the serializer class itself (not an instance) is stored
        # here; presumably its methods are class/static methods -- confirm.
        self._serializer = GradioSampleSerializer
        # Copy so that later changes to the caller's mapping have no effect here.
        self._precomputed_statistics = dict(precomputed_statistics)

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block on the server.

        The block is sent as a whole sample with input padding and output
        cropping disabled; the returned data is attached to the block meta
        data transformed by `self._block_transform`.

        Raises:
            NotImplementedError: For model descriptions in the v0_4 format.
        """
        if isinstance(self._model_descr, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self._model_descr.format_version}"
            )

        assert self._block_transform is not None

        sample_block.stat.update(self._precomputed_statistics)
        serialized_input_sample = self._serializer.serialize_sample(
            sample_block.as_sample()
        )
        output_block = self._serializer.deserialize_sample(
            _call_predict_api(
                self._client,
                self._serialized_model,
                self._sha256,
                serialized_input_sample=serialized_input_sample,
                blocksize=None,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                skip_input_padding=True,
                skip_output_cropping=True,
                batch_size=self._default_batch_size,
            )
        )
        output_meta = sample_block.get_transformed_meta(self._block_transform)
        return output_meta.with_data(output_block.members, stat=sample_block.stat)

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """Predict a whole sample on the server in a single request (no blocksize is sent)."""
        sample.stat.update(self._precomputed_statistics)
        return self._serializer.deserialize_sample(
            _call_predict_api(
                self._client,
                self._serialized_model,
                self._sha256,
                serialized_input_sample=self._serializer.serialize_sample(sample),
                blocksize=None,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                skip_input_padding=skip_input_padding,
                skip_output_cropping=skip_output_cropping,
                batch_size=self._default_batch_size,
            )
        )

    def predict_sample_with_fixed_blocking_yield_intermediates(
        self,
        sample: Sample,
        input_block_shape: PerMember[PerAxis[int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        fill_value: float = float("nan"),
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict a sample blockwise on the server, yielding intermediate results.

        The first output block is requested eagerly, because the number of
        blocks is read from it; the remaining blocks are produced lazily while
        the returned iterable is consumed.

        Args:
            sample: The input sample.
            input_block_shape: Block shape used for serialization and sent to the server as `blocksize`.
            skip_preprocessing: Forwarded to the server.
            skip_postprocessing: Forwarded to the server.
            fill_value: Fill value of the output sample assembled from the blocks received so far.

        Returns:
            The number of blocks in the sample and an iterable of intermediate
            predictions (assembled output sample so far plus the latest block).

        Raises:
            RuntimeError: If the server does not return any output block.
        """
        sample.stat.update(self._precomputed_statistics)

        # blocking for serialization is not really important, but we might as well block
        # the same way we want the backend to block for blockwise prediction
        serialized_input_sample = self._serializer.serialize_sample_with_fixed_blocking(
            sample, block_shapes=input_block_shape, halo=self._default_input_halo
        )

        def _predict_blocks():
            output_sample = None
            for serialized_output_block in _call_predict_api(
                self._client,
                self._serialized_model,
                self._sha256,
                serialized_input_sample=serialized_input_sample,
                blocksize=input_block_shape,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                skip_input_padding=False,
                skip_output_cropping=False,
                batch_size=self._default_batch_size,
            ):
                output_block = self._serializer.deserialize_sample_block(
                    serialized_output_block
                )
                if output_sample is None:
                    output_sample = Sample.from_blocks(
                        [output_block], fill_value=fill_value
                    )
                else:
                    output_sample.set_block(output_block)

                yield IntermediatePrediction(output_sample, output_block)

        block_iterator = _predict_blocks()
        try:
            first_intermediate = next(block_iterator)
        except StopIteration:
            # Without this guard an empty server response would leak a bare
            # StopIteration to the caller (and turn into an opaque RuntimeError
            # if the caller is itself a generator).
            raise RuntimeError(
                "The gradio server did not return any output block for the submitted sample."
            ) from None

        def _intermediate_predictions() -> Iterable[IntermediatePrediction]:
            # Re-emit the eagerly fetched first prediction before the rest.
            yield first_intermediate
            yield from block_iterator

        return (
            first_intermediate.last_block.blocks_in_sample,
            _intermediate_predictions(),
        )

259 

260 

def _call_predict_api(
    client: Client,
    serialized_model: str,
    sha256: str,
    serialized_input_sample: Iterable[SerializedSampleBlock],
    blocksize: Optional[
        Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]]
    ],
    skip_preprocessing: bool,
    skip_postprocessing: bool,
    skip_input_padding: bool,
    skip_output_cropping: bool,
    batch_size: Optional[int],
) -> Iterable[SerializedSampleBlock]:
    """Call the server's "/predict" endpoint and yield the returned blocks.

    The request is first submitted with an empty model string, i.e. identified
    by `sha256` only. If that attempt yields no block at all (or raises before
    the first block), the request is repeated with `serialized_model` attached.

    Args:
        client: The gradio client connected to the server.
        serialized_model: The serialized model description to upload on retry.
        sha256: Hash sent along with every request to identify the model.
        serialized_input_sample: The serialized input blocks.
        blocksize: None, an int, a string, or a nested mapping (member -> axis -> size);
            the keys of a nested mapping are converted to strings before sending.
        skip_preprocessing: Forwarded to the server.
        skip_postprocessing: Forwarded to the server.
        skip_input_padding: Forwarded to the server.
        skip_output_cropping: Forwarded to the server.
        batch_size: Forwarded to the server.

    Yields:
        The serialized output blocks as returned by the server.

    Raises:
        Exception: Any error raised after at least one block was received from
            the first attempt, or any error raised by the retry with model upload.
    """
    # A one-shot iterator would be consumed by the first attempt, leaving the
    # retry with model upload without any input. Materialize it once;
    # re-iterable containers are passed through untouched.
    if iter(serialized_input_sample) is serialized_input_sample:
        serialized_input_sample = list(serialized_input_sample)

    # Convert the (member -> axis -> size) mapping to plain string keys once,
    # so both attempts send the same payload.
    if blocksize is None or isinstance(blocksize, (int, str)):
        serialized_blocksize = blocksize
    else:
        serialized_blocksize = {
            str(k): {str(kk): vv for kk, vv in v.items()}
            for k, v in blocksize.items()
        }

    def submit(model: str):
        return client.submit(
            api_name="/predict",
            model=model,
            sha256=sha256,
            input_sample=serialized_input_sample,
            blocksize=serialized_blocksize,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            skip_input_padding=skip_input_padding,
            skip_output_cropping=skip_output_cropping,
            batch_size=batch_size,
        )

    try_with_model_upload = True
    try:
        job = submit("")
        for block in job:  # pyright: ignore[reportUnknownVariableType]
            yield block  # pyright: ignore[reportReturnType]
            # we got one response, so the model cache was hit...
            try_with_model_upload = False
    except Exception as e:
        # A raised exception on the server seems to simply return an empty response sequence,
        # so this except is likely not triggered at all.
        # Below we retry on empty return value, too.
        if not try_with_model_upload:
            # Blocks were already delivered; a retry would duplicate them.
            raise

        logger.warning(
            "Failed to submit job without model upload, trying with model upload, error was: {}",
            e,
        )

    if try_with_model_upload:
        job = submit(serialized_model)
        for block in job:  # pyright: ignore[reportUnknownVariableType]
            yield block  # pyright: ignore[reportReturnType]