Coverage for src/bioimageio/core/remote_backends/gradio/server.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from itertools import chain 

2from typing import ( 

3 Any, 

4 Dict, 

5 Iterable, 

6 Literal, 

7 Optional, 

8 Union, 

9) 

10 

11import gradio as gr 

12from loguru import logger 

13 

14import bioimageio.core 

15from bioimageio.core import AxisId, Stat 

16from bioimageio.core.axis import PerAxis 

17from bioimageio.core.backends import create_model_adapter 

18from bioimageio.core.common import PerMember 

19from bioimageio.core.remote_backends.gradio.serializer import ( 

20 DescriptionSerializer, 

21 GradioSampleSerializer, 

22 SerializedSampleBlock, 

23) 

24from bioimageio.spec import load_model_description 

25from bioimageio.spec.common import Sha256 

26from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

27 

# Optional dependency: `spaces` provides the `GPU` decorator that is applied to
# `predict` below. NOTE(review): presumably Hugging Face's ZeroGPU helper
# package — confirm.
try:
    import spaces  # pyright: ignore
except ImportError:
    logger.warning("Failed to import 'spaces' package")

    class spaces:
        """Stand-in used when the `spaces` package cannot be imported.

        Only the `GPU` decorator is mirrored, as a no-op, so that decorated
        functions still work without the package.
        """

        @staticmethod
        def GPU(func: Any):
            # No-op decorator: hand back the decorated function unchanged.
            return func

37 

38 

# Enable loguru log messages emitted from the `bioimageio` namespace.
logger.enable("bioimageio")

# Module-level server instance: the endpoints below register on it via
# `@app.api(...)` / `@app.get(...)`, and `main()` launches it.
app = gr.Server()

42 

43 

@app.api(name="predict")  # pyright: ignore[reportUntypedFunctionDecorator]
@spaces.GPU
def predict(
    model: str,
    sha256: str,
    input_sample: Iterable[SerializedSampleBlock],
    blocksize: Optional[
        Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    skip_input_padding: bool = False,
    skip_output_cropping: bool = False,
    batch_size: Optional[int] = None,
) -> Iterable[SerializedSampleBlock]:
    """Run prediction on a sample

    Args:
        input_sample: Input sample as a sequence of serialized sample blocks.
            Use bioimageio.core.backends.gradio_backend.GradioModelAdapter.serialize_sample to create this from a Sample object.
        model: A model source: URL, nickname or base64 encoded model package (if len(model) > 2083).
        sha256: Sha256 hash of the model's bioimageio.yaml file at the model source or of the encoded model package.
        blocksize:
            - None (default): run non-blockwise, full-sample prediction.
            - integer: run blockwise prediction with a block size derived from the model and this blocksize parameter.
            - "blockwise_as_serialized": run blockwise prediction with the same blocking as the serialized input sample.
              (Non-blockwise pre- and postprocessing steps will be ignored.)
            - PerMember[PerAxis[int]]: run blockwise prediction with a fixed block shape given for each sample member.
            If integer or per-member blockwise prediction fails before any output block was sent,
            prediction falls back to full-sample prediction (a warning is logged).
            A failure after output blocks were already streamed is re-raised instead.
        skip_preprocessing: If True, skip preprocessing steps defined in the model.
        skip_postprocessing: If True, skip postprocessing steps defined in the model.
        skip_input_padding: If True, skip input padding for non-blockwise prediction.
            Set this flag when predicting an (overlapping) sample block rather than a full sample.
        skip_output_cropping: If True, skip output cropping for non-blockwise prediction.
            Set this flag when predicting an (overlapping) sample block rather than a full sample.
        batch_size: Optional batch size only applicable to predicting input samples with batch dimension.

    Yields:
        Serialized output sample blocks.

    Raises:
        ValueError: If `blocksize` is "blockwise_as_serialized" and `input_sample` is empty.
    """

    def setup(stat: Stat):
        # Hand the cached adapter to the pipeline explicitly. Passing only the
        # description would presumably make the pipeline create its own adapter
        # (reloading weights) on every request, bypassing the
        # `_get_model_adapter` cache and `load_model` pre-loading.
        model_adapter = _get_model_adapter(model, sha256=sha256)
        return bioimageio.core.create_prediction_pipeline(
            model_adapter.model_descr,
            model_adapter=model_adapter,
            fixed_dataset_statistics=stat,
        )

    if blocksize == "blockwise_as_serialized":
        serialized_blocks = iter(input_sample)
        try:
            first_serialized_block = next(serialized_blocks)
        except StopIteration:
            # An unguarded `next` would surface as an opaque RuntimeError
            # inside this generator (PEP 479).
            raise ValueError(
                "input_sample must contain at least one serialized sample block"
            ) from None

        first_block = GradioSampleSerializer.deserialize_sample_block(
            first_serialized_block
        )
        # The pipeline is configured with the statistics of the first block.
        pp = setup(first_block.stat)
        for block in chain(
            [first_block],
            (
                GradioSampleSerializer.deserialize_sample_block(b)
                for b in serialized_blocks
            ),
        ):
            output_block = pp.predict_sample_block(
                block,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
            yield GradioSampleSerializer.serialize_sample_block(output_block)
        return

    deserialized_input_sample = GradioSampleSerializer.deserialize_sample(input_sample)
    pp = setup(deserialized_input_sample.stat)

    def log_fallback(error: Exception) -> None:
        logger.warning(
            "Falling back to full-sample prediction for model {}: {}",
            pp.model_descr.id or pp.model_descr.name,
            error,
        )

    output_sample = None
    if isinstance(blocksize, int):
        streamed_any_block = False
        try:
            if pp.has_non_blockwise_postprocessing and not skip_postprocessing:
                # Blocks are not final before non-blockwise postprocessing ran,
                # so the merged output sample is produced first.
                output_sample = pp.predict_sample_with_blocking(
                    deserialized_input_sample,
                    skip_preprocessing=skip_preprocessing,
                    skip_postprocessing=skip_postprocessing,
                    ns=blocksize,
                    batch_size=batch_size,
                )
            else:
                for output in pp.predict_sample_with_blocking_yield_intermediates(
                    deserialized_input_sample,
                    skip_preprocessing=skip_preprocessing,
                    skip_postprocessing=skip_postprocessing,
                    ns=blocksize,
                    batch_size=batch_size,
                )[1]:
                    # With purely blockwise postprocessing or with postprocessing skipped,
                    # predicted blocks are part of the final result, so we yield them immediately.
                    serialized_block = GradioSampleSerializer.serialize_sample_block(
                        output.last_block
                    )
                    streamed_any_block = True
                    yield serialized_block

                return

        except Exception as e:
            if streamed_any_block:
                # Part of the output already reached the client; appending a
                # full-sample prediction would corrupt the output stream.
                raise
            log_fallback(e)
    elif isinstance(blocksize, dict):
        # Fixed block shape per sample member (`PerMember[PerAxis[int]]`).
        try:
            output_sample = pp.predict_sample_with_fixed_blocking(
                deserialized_input_sample,
                input_block_shape=blocksize,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
        except Exception as e:
            log_fallback(e)

    if output_sample is None:
        output_sample = pp.predict_sample_without_blocking(
            deserialized_input_sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            skip_input_padding=skip_input_padding,
            skip_output_cropping=skip_output_cropping,
        )

    if all(
        axes.get(AxisId("batch"), 1) > 1 for axes in output_sample.shape.values()
    ):
        # yield batches
        yield from GradioSampleSerializer.serialize_sample_with_fixed_blocking(
            output_sample,
            block_shapes={
                m: {AxisId("batch"): batch_size or 1} for m in output_sample.shape
            },
            halo={},
        )
    else:
        yield from GradioSampleSerializer.serialize_sample(output_sample)

167 

168 

@app.api(name="load_model")  # pyright: ignore[reportUntypedFunctionDecorator]
def load_model(
    model: str,
    sha256: str,
) -> dict[Literal["message"], str]:
    """Load a model into the server's model cache.

    Call this to pre-load a model before running predictions, so the first
    prediction request does not pay the model loading overhead.
    """
    # Only the caching side effect of `_get_model_adapter` is wanted here;
    # the returned adapter is discarded.
    _ = _get_model_adapter(model, sha256=sha256)
    reply: dict[Literal["message"], str] = {"message": "Model loaded successfully"}
    return reply

177 

178 

@app.api(name="test_model")  # pyright: ignore[reportUntypedFunctionDecorator]
def test_model(
    model: str,
    sha256: str,
) -> str:
    """Run the bioimageio model test and return the validation summary as JSON.

    Args:
        model: A model source: URL, nickname or base64 encoded model package (if len(model) > 2083).
        sha256: Sha256 hash of the model's bioimageio.yaml file at the model source or of the encoded model package.

    Returns:
        The validation summary serialized with `model_dump_json()`. This is always
        a string; the function never returns None.

    Raises:
        ValueError: If `model` is empty. Other errors raised while loading the
            model propagate to the caller.
    """
    model_adapter = _get_model_adapter(model, sha256=sha256)
    summary = bioimageio.core.test_model(model_adapter.model_descr)
    return summary.model_dump_json()

188 

189 

190def _cache_key(kwargs: Dict[str, Any]) -> str: 

191 return kwargs["sha256"] 

192 

193 

@gr.cache(  # pyright: ignore[reportUntypedFunctionDecorator]
    key=_cache_key,
    max_size=bioimageio.core.settings.gradio_server_model_cache_max_size,
    max_memory=bioimageio.core.settings.gradio_server_model_cache_max_memory,
    per_session=False,
)
def _get_model_adapter(
    model: str,
    *,
    sha256: str,
):
    """Create a model adapter for `model`.

    Results are cached by `gr.cache` (keyed via `_cache_key`). With
    `per_session=False`, the cache is not scoped to a single session.

    Args:
        model: A model source: a URL (len(model) <= 2083) or base64 encoded
            model package bytes (len(model) > 2083).
        sha256: Sha256 hash of the model source at the model URL or of the
            encoded model package bytes.

    Raises:
        ValueError: If `model` is empty.
    """
    if model:
        descr = _get_model(model, sha256=sha256)
        return create_model_adapter(model_description=descr)

    raise ValueError("Model source cannot be empty")

216 

217 

def _get_model(
    model: str,
    *,
    sha256: str,
) -> AnyModelDescr:
    """Resolve `model` into a model description.

    Strings longer than 2083 characters (presumably chosen as the classic
    maximum URL length) are treated as an inline serialized package. Anything
    shorter is loaded as a model source, verified against `sha256` when one
    is given.

    Raises:
        ValueError: If an inline package does not deserialize to a v0.4/v0.5
            model description.
    """
    if len(model) <= 2083:
        expected_hash = Sha256(sha256) if sha256 else None
        return load_model_description(model, sha256=expected_hash)

    # NOTE(review): `sha256` is not used on this path, so inline packages are
    # not checked against the hash the caller supplied.
    descr = DescriptionSerializer.deserialize_from_string(model)
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        return descr

    raise ValueError(
        f"Deserialized model description is not a valid model description: got {descr.type} {descr.format_version}"
    )

232 

233 

@app.get("/")
def root():
    """Report the bioimageio.core version this gradio server is running."""
    version = bioimageio.core.__version__
    return {"message": f"Running bioimageio.core {version} gradio server."}

239 

240 

def main(port: Optional[int] = None) -> str:
    """Launch `app` with the MCP server and error display enabled.

    Args:
        port: Forwarded to `app.launch` as `server_port`.

    Returns:
        The local URL, i.e. the second element of the 3-tuple returned by
        `app.launch`.
    """
    _server, url, _public_url = app.launch(
        server_port=port,
        show_error=True,
        mcp_server=True,
    )
    return url

246 

247 

if __name__ == "__main__":
    # Script entry point: no explicit port is given, so `port=None` is
    # forwarded to `app.launch` as `server_port`.
    _local_url = main()