Coverage for src/bioimageio/core/remote_backends/gradio/server.py: 0%
85 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from itertools import chain
2from typing import (
3 Any,
4 Dict,
5 Iterable,
6 Literal,
7 Optional,
8 Union,
9)
11import gradio as gr
12from loguru import logger
14import bioimageio.core
15from bioimageio.core import AxisId, Stat
16from bioimageio.core.axis import PerAxis
17from bioimageio.core.backends import create_model_adapter
18from bioimageio.core.common import PerMember
19from bioimageio.core.remote_backends.gradio.serializer import (
20 DescriptionSerializer,
21 GradioSampleSerializer,
22 SerializedSampleBlock,
23)
24from bioimageio.spec import load_model_description
25from bioimageio.spec.common import Sha256
26from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
try:
    import spaces  # pyright: ignore
except ImportError:
    logger.warning("Failed to import 'spaces' package")

    class spaces:
        """No-op stand-in used when the `spaces` package cannot be imported.

        NOTE(review): presumably the Hugging Face Spaces (ZeroGPU) helper package;
        without it, GPU-decorated functions simply run undecorated.
        """

        @staticmethod
        def GPU(func: Any = None, **_kwargs: Any):
            """Return the decorated function unchanged.

            Accepts both the bare form (`@spaces.GPU`) and the parameterized form
            (`@spaces.GPU()` / `@spaces.GPU(duration=...)`), so switching the decorator
            call style does not break when `spaces` is unavailable.
            """
            if func is None:
                # Parameterized use: return the actual (identity) decorator.
                return lambda decorated: decorated

            return func


# Make log records emitted from the `bioimageio` package namespace visible (loguru).
logger.enable("bioimageio")

app = gr.Server()
@app.api(name="predict")  # pyright: ignore[reportUntypedFunctionDecorator]
@spaces.GPU
def predict(
    model: str,
    sha256: str,
    input_sample: Iterable[SerializedSampleBlock],
    blocksize: Optional[
        Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    skip_input_padding: bool = False,
    skip_output_cropping: bool = False,
    batch_size: Optional[int] = None,
) -> Iterable[SerializedSampleBlock]:
    """Run prediction on a sample and stream the result as serialized blocks.

    Args:
        model: A model source: URL, nickname or base64 encoded model package (if len(model) > 2083).
        sha256: Sha256 hash of the model's bioimageio.yaml file at the model source or of the encoded model package.
        input_sample: Input sample as a sequence of serialized sample blocks.
            Use bioimageio.core.backends.gradio_backend.GradioModelAdapter.serialize_sample to create this from a Sample object.
        blocksize:
            - None (default): run non-blockwise, full-sample prediction.
            - integer: run blockwise prediction with a block size derived from the model and this blocksize parameter.
            - "blockwise_as_serialized": run blockwise prediction with the same blocking as the serialized input sample.
              (Non-blockwise pre- and postprocessing steps will be ignored.)
            - PerMember[PerAxis[int]]: run blockwise prediction with a fixed block shape given for each sample member.
            For an integer or per-member blocksize, if blockwise prediction fails before any
            output block was streamed, a warning is logged and full-sample prediction is used instead.
        skip_preprocessing: If True, skip preprocessing steps defined in the model.
        skip_postprocessing: If True, skip postprocessing steps defined in the model.
        skip_input_padding: If True, skip input padding for non-blockwise prediction.
            Set this flag when predicting an (overlapping) sample block rather than a full sample.
        skip_output_cropping: If True, skip output cropping for non-blockwise prediction.
            Set this flag when predicting an (overlapping) sample block rather than a full sample.
        batch_size: Optional batch size only applicable to predicting input samples with batch dimension.

    Yields:
        Serialized blocks of the output sample.

    Raises:
        ValueError: If `blocksize == "blockwise_as_serialized"` and `input_sample` contains no blocks.
        Exception: An error raised by streaming blockwise prediction *after* output blocks were
            already yielded is re-raised instead of falling back to full-sample prediction,
            because a fallback would stream duplicate/inconsistent output.
    """
    from collections.abc import Mapping

    def setup(stat: Stat):
        # Hand the cached adapter to the pipeline; otherwise the pipeline creates
        # (and loads) a second adapter for every request and the cache only saves
        # loading the model description.
        model_adapter = _get_model_adapter(model, sha256=sha256)
        return bioimageio.core.create_prediction_pipeline(
            model_adapter.model_descr,
            fixed_dataset_statistics=stat,
            model_adapter=model_adapter,
        )

    if blocksize == "blockwise_as_serialized":
        sample_block_iterator = iter(input_sample)
        first_serialized_block = next(sample_block_iterator, None)
        if first_serialized_block is None:
            # A bare `next()` would raise StopIteration, which Python converts into an
            # uninformative RuntimeError inside a generator (PEP 479).
            raise ValueError("input_sample must contain at least one sample block")

        deserialized_input_block = GradioSampleSerializer.deserialize_sample_block(
            first_serialized_block
        )
        # The pipeline is configured with the statistics carried by the first block.
        pp = setup(deserialized_input_block.stat)
        for block in chain(
            [deserialized_input_block],
            (
                GradioSampleSerializer.deserialize_sample_block(b)
                for b in sample_block_iterator
            ),
        ):
            output_block = pp.predict_sample_block(
                block,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
            yield GradioSampleSerializer.serialize_sample_block(output_block)

        return

    deserialized_input_sample = GradioSampleSerializer.deserialize_sample(input_sample)
    pp = setup(deserialized_input_sample.stat)

    output_sample = None
    if isinstance(blocksize, (int, Mapping)):
        # Becomes True once a block has been handed to the client. From then on a
        # fallback to full-sample prediction would stream the result a second time.
        streamed_any = False
        try:
            if isinstance(blocksize, Mapping):
                # Fixed block shape per sample member, as documented for
                # PerMember[PerAxis[int]]. Failures fall back to full-sample prediction below.
                output_sample = pp.predict_sample_with_fixed_blocking(
                    deserialized_input_sample,
                    blocksize,
                    skip_preprocessing=skip_preprocessing,
                    skip_postprocessing=skip_postprocessing,
                )
            elif pp.has_non_blockwise_postprocessing and not skip_postprocessing:
                # Non-blockwise postprocessing needs the complete prediction, so
                # intermediate blocks are not final and cannot be streamed.
                output_sample = pp.predict_sample_with_blocking(
                    deserialized_input_sample,
                    skip_preprocessing=skip_preprocessing,
                    skip_postprocessing=skip_postprocessing,
                    ns=blocksize,
                    batch_size=batch_size,
                )
            else:
                for output in pp.predict_sample_with_blocking_yield_intermediates(
                    deserialized_input_sample,
                    skip_preprocessing=skip_preprocessing,
                    skip_postprocessing=skip_postprocessing,
                    ns=blocksize,
                    batch_size=batch_size,
                )[1]:
                    # with purely blockwise postprocessing or with postprocessing skipped,
                    # predicted blocks are part of the final result, so we yield them immediately.
                    serialized_block = GradioSampleSerializer.serialize_sample_block(
                        output.last_block
                    )
                    streamed_any = True
                    yield serialized_block

                return

        except Exception as e:
            if streamed_any:
                # Part of the result already reached the client; do not mix in a
                # second, full-sample result.
                raise

            logger.warning(
                "Falling back to full-sample prediction for model {}: {}",
                pp.model_descr.id or pp.model_descr.name,
                e,
            )

    if output_sample is None:
        output_sample = pp.predict_sample_without_blocking(
            deserialized_input_sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            skip_input_padding=skip_input_padding,
            skip_output_cropping=skip_output_cropping,
        )

    batch_axis = AxisId("batch")
    if all(axes.get(batch_axis, 1) > 1 for axes in output_sample.shape.values()):
        # Every output member has a batch axis longer than 1: stream the output in
        # chunks of `batch_size` (default 1) along the batch axis.
        yield from GradioSampleSerializer.serialize_sample_with_fixed_blocking(
            output_sample,
            block_shapes={m: {batch_axis: batch_size or 1} for m in output_sample.shape},
            halo={},
        )
    else:
        yield from GradioSampleSerializer.serialize_sample(output_sample)
@app.api(name="load_model")  # pyright: ignore[reportUntypedFunctionDecorator]
def load_model(
    model: str,
    sha256: str,
) -> dict[Literal["message"], str]:
    """Load `model` and keep its adapter in the server-wide model cache.

    Useful to front-load the model loading cost before prediction requests arrive.

    Args:
        model: A model source: URL, nickname or base64 encoded model package.
        sha256: Sha256 hash of the model source or of the encoded model package.

    Returns:
        A confirmation message once the model adapter is available.
    """
    # Only the caching side effect matters here; the adapter itself is discarded.
    _ = _get_model_adapter(model, sha256=sha256)
    confirmation = "Model loaded successfully"
    return {"message": confirmation}
@app.api(name="test_model")  # pyright: ignore[reportUntypedFunctionDecorator]
def test_model(
    model: str,
    sha256: str,
) -> str:
    """Run the bioimageio model test and return the validation summary as JSON.

    Args:
        model: A model source: URL, nickname or base64 encoded model package (if len(model) > 2083).
        sha256: Sha256 hash of the model source or of the encoded model package.

    Returns:
        The summary returned by `bioimageio.core.test_model`, serialized via
        `model_dump_json()`. A JSON string is always returned (never None).
        NOTE(review): presumably the pass/fail outcome is recorded inside the summary
        rather than raised; confirm against `bioimageio.core.test_model`.

    Raises:
        Errors from loading the model (e.g. ValueError for an empty model source)
        propagate to the caller unchanged.
    """
    model_adapter = _get_model_adapter(model, sha256=sha256)
    summary = bioimageio.core.test_model(model_adapter.model_descr)
    return summary.model_dump_json()
190def _cache_key(kwargs: Dict[str, Any]) -> str:
191 return kwargs["sha256"]
@gr.cache(  # pyright: ignore[reportUntypedFunctionDecorator]
    key=_cache_key,
    max_size=bioimageio.core.settings.gradio_server_model_cache_max_size,
    max_memory=bioimageio.core.settings.gradio_server_model_cache_max_memory,
    per_session=False,
)
def _get_model_adapter(
    model: str,
    *,
    sha256: str,
):
    """Create a model adapter for `model`, memoized across sessions by `gr.cache`.

    Args:
        model: Model source: a URL (len(model) <= 2083) or base64 encoded model
            package bytes (len(model) > 2083).
        sha256: Sha256 hash of the model source at the URL or of the encoded
            model package bytes.

    Raises:
        ValueError: If `model` is empty.
    """
    if model:
        descr = _get_model(model, sha256=sha256)
        return create_model_adapter(model_description=descr)

    raise ValueError("Model source cannot be empty")
def _get_model(
    model: str,
    *,
    sha256: str,
) -> AnyModelDescr:
    """Resolve `model` into a model description.

    Strings of at most 2083 characters are treated as a model source and loaded
    with `load_model_description`, verifying `sha256` if it is non-empty. Longer
    strings are treated as an encoded model package and deserialized.

    Raises:
        ValueError: If a deserialized package does not describe a (v0.4/v0.5) model.
    """
    if len(model) <= 2083:
        expected_hash = Sha256(sha256) if sha256 else None
        return load_model_description(model, sha256=expected_hash)

    # NOTE(review): for encoded packages `sha256` is not checked here; it is only
    # used as the cache key by the caller.
    descr = DescriptionSerializer.deserialize_from_string(model)
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        return descr

    raise ValueError(
        f"Deserialized model description is not a valid model description: got {descr.type} {descr.format_version}"
    )
@app.get("/")
def root():
    """Report that the server is running and which bioimageio.core version it uses."""
    version = bioimageio.core.__version__
    return {"message": f"Running bioimageio.core {version} gradio server."}
def main(port: Optional[int] = None) -> str:
    """Launch the gradio app with the MCP server and error display enabled.

    Args:
        port: Passed through to `app.launch` as `server_port`.

    Returns:
        The local URL, i.e. the second element of the tuple returned by `app.launch`.
    """
    launch_options = dict(mcp_server=True, show_error=True, server_port=port)
    _app, local_url, _share_url = app.launch(**launch_options)
    return local_url


if __name__ == "__main__":
    # Script entry point: start the server.
    _ = main()