Coverage for src/bioimageio/core/remote_backends/gradio/client.py: 0%
103 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from types import MappingProxyType
2from typing import Dict, Iterable, Literal, Mapping, Optional, Tuple, Union
4from gradio_client import Client
5from loguru import logger
7from bioimageio.spec import AnyModelDescr, ValidationSummary
8from bioimageio.spec.model import v0_4
10from ..._description_serializer import DescriptionSerializer as DescriptionSerializer
11from ..._model_adapter import RemoteModelAdapter
12from ..._prediction_pipeline import IntermediatePrediction, RemotePredictionPipeline
13from ..._settings import settings
14from ...axis import PerAxis
15from ...common import BlocksizeParameter, PerMember
16from ...io import JsonValue
17from ...sample import Sample, SampleBlock
18from ...stat_measures import Measure, MeasureValue
19from .serializer import GradioSampleSerializer
# JSON-compatible representation of one serialized sample block (string keys, JSON values);
# used in this module as the element type of the input/output streams of `_call_predict_api`.
21SerializedSampleBlock = Dict[str, JsonValue]
class GradioModelAdapter(RemoteModelAdapter[SerializedSampleBlock]):
    """Model adapter to use the bioimage-io-gradio-runner as a backend for model inference."""

    def __init__(
        self, model_description: AnyModelDescr, *, server: Optional[str] = None
    ):
        """Initialize the GradioModelAdapter.

        Note:
            - This adapter requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server.

        Args:
            model_description: The model to run inference with.
            server: The URL of a running bioimage-io-gradio-server instance
                (the default server might not be available/compatible).
                Falls back to `settings.gradio_server` if not given.

        Raises:
            ValueError: If neither `server` nor `settings.gradio_server` is set.
        """
        server = server or settings.gradio_server
        if server is None:
            raise ValueError(
                "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable."
            )

        self._client = Client(server, httpx_kwargs={"timeout": 60})
        self._serialized_model, self._sha256 = (
            DescriptionSerializer.serialize_to_string_and_hash(model_description)
        )
        super().__init__(
            model_description, server=server, sample_serializer=GradioSampleSerializer()
        )

    def _forward_impl(
        self, serialized_input_sample: Iterable[SerializedSampleBlock]
    ) -> Iterable[SerializedSampleBlock]:
        """Run the '/predict' endpoint on already serialized input blocks.

        Pre-/postprocessing, input padding and output cropping are all skipped
        in the request; no blocksize and no batch size are sent.
        """
        return _call_predict_api(
            self._client,
            self._serialized_model,
            self._sha256,
            serialized_input_sample,
            blocksize=None,
            skip_preprocessing=True,
            skip_postprocessing=True,
            skip_input_padding=True,
            skip_output_cropping=True,
            batch_size=None,
        )

    def unload(self):
        """Delegate to the base class implementation."""
        return super().unload()

    def _submit_with_upload_fallback(self, api_name: str, action: str):
        """Call `api_name` without the model first, then with the serialized model.

        The first attempt sends an empty `model` string together with the sha256
        (presumably enough if the server already knows the model - TODO confirm);
        the second attempt uploads the full serialized model.

        Args:
            api_name: Name of the gradio API endpoint, e.g. "/load_model".
            action: Verb used in the warning message ("load", "test").

        Returns:
            The first truthy endpoint result, or None if no attempt produced one.
        """
        for model_data in ("", self._serialized_model):
            try:
                result = self._client.submit(
                    api_name=api_name, model=model_data, sha256=self._sha256
                ).result()
            except Exception as e:
                # A failure of the upload-free attempt is expected to be recoverable
                # by the upload attempt, so only the latter is reported.
                if model_data:
                    # One placeholder per argument: previously the single `{}`
                    # consumed the length and the actual error was never logged.
                    logger.warning(
                        "Failed to {} model on server with model_data (length {}), error was: {}",
                        action,
                        len(model_data),
                        e,
                    )
            else:
                if result:
                    return result

        return None

    def load(self) -> None:
        """Request the server to load the model (best effort, failures are only logged)."""
        self._submit_with_upload_fallback("/load_model", "load")

    def test(self) -> Optional[ValidationSummary]:
        """Request the server to test the model.

        Returns:
            The validation summary parsed from the server's JSON response,
            or None if no attempt returned a (truthy) result.
        """
        result = self._submit_with_upload_fallback("/test_model", "test")
        if result:
            return ValidationSummary.model_validate_json(result)

        return None
class GradioPredictionPipeline(RemotePredictionPipeline):
    """Prediction pipeline to use the bioimage-io-gradio-runner as a fully remote prediction pipeline."""

    def __init__(
        self,
        model_description: AnyModelDescr,
        *,
        server: Optional[str] = None,
        precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ):
        """Initialize the GradioPredictionPipeline.

        Note:
            - This pipeline requires an environment with the same gradio version as the one used on the bioimage-io-gradio-runner server.

        Args:
            model_description: The model to run inference with.
            server: The URL or Hugging Face space name of a running bioimageio gradio server instance
                (Note: the default server might not be available/compatible!).
                Falls back to `settings.gradio_server` if not given.
            precomputed_statistics: Measures that are merged into the `stat` of every
                sample (block) before it is sent for prediction.
            default_blocksize_parameter: Forwarded to the base class.
            default_batch_size: Forwarded to the base class; sent as `batch_size` with every request.

        Raises:
            ValueError: If neither `server` nor `settings.gradio_server` is set.
        """
        server = server or settings.gradio_server
        if server is None:
            raise ValueError(
                "No gradio server specified. Please provide a server URL or set the 'BIOIMAGEIO_GRADIO_SERVER' environment variable."
            )

        super().__init__(
            model_description,
            server=server,
            default_blocksize_parameter=default_blocksize_parameter,
            default_batch_size=default_batch_size,
        )
        self._client = Client(self.server, httpx_kwargs={"timeout": 60})
        self._serialized_model, self._sha256 = (
            DescriptionSerializer.serialize_to_string_and_hash(model_description)
        )
        # Use an instance, consistent with `GradioModelAdapter`; this works for
        # instance methods as well as class/static methods of the serializer.
        self._serializer = GradioSampleSerializer()
        # Private copy, so later changes to the caller's mapping have no effect.
        self._precomputed_statistics = dict(precomputed_statistics)

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block remotely.

        The block is sent as a whole sample; input padding and output cropping
        are skipped in the request.

        Raises:
            NotImplementedError: For model descriptions of format version 0.4.
        """
        if isinstance(self._model_descr, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self._model_descr.format_version}"
            )

        assert self._block_transform is not None

        sample_block.stat.update(self._precomputed_statistics)
        serialized_output = _call_predict_api(
            self._client,
            self._serialized_model,
            self._sha256,
            serialized_input_sample=self._serializer.serialize_sample(
                sample_block.as_sample()
            ),
            blocksize=None,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            skip_input_padding=True,
            skip_output_cropping=True,
            batch_size=self._default_batch_size,
        )
        output_block = self._serializer.deserialize_sample(serialized_output)
        output_meta = sample_block.get_transformed_meta(self._block_transform)
        return output_meta.with_data(output_block.members, stat=sample_block.stat)

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """Predict a whole sample remotely in a single request (no blocksize is sent)."""
        sample.stat.update(self._precomputed_statistics)
        serialized_output = _call_predict_api(
            self._client,
            self._serialized_model,
            self._sha256,
            serialized_input_sample=self._serializer.serialize_sample(sample),
            blocksize=None,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            skip_input_padding=skip_input_padding,
            skip_output_cropping=skip_output_cropping,
            batch_size=self._default_batch_size,
        )
        return self._serializer.deserialize_sample(serialized_output)

    def predict_sample_with_fixed_blocking_yield_intermediates(
        self,
        sample: Sample,
        input_block_shape: PerMember[PerAxis[int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        fill_value: float = float("nan"),
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict a sample blockwise on the server and yield intermediate results.

        Args:
            sample: The input sample.
            input_block_shape: Block shape used for serialization and sent as `blocksize`.
            skip_preprocessing: Sent to the server as is.
            skip_postprocessing: Sent to the server as is.
            fill_value: Fill value of the output sample created from the first output block.

        Returns:
            The number of blocks in the sample (taken from the first output block)
            and an iterable of intermediate predictions, one per received output block.

        Raises:
            RuntimeError: If the server did not return any output block.
        """
        sample.stat.update(self._precomputed_statistics)

        # blocking for serialization is not really important, but we might as well block
        # the same way we want the backend to block for blockwise prediction
        serialized_input_sample = self._serializer.serialize_sample_with_fixed_blocking(
            sample, block_shapes=input_block_shape, halo=self._default_input_halo
        )

        def _predict_blocks():
            output_sample = None
            for serialized_output_block in _call_predict_api(
                self._client,
                self._serialized_model,
                self._sha256,
                serialized_input_sample=serialized_input_sample,
                blocksize=input_block_shape,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                skip_input_padding=False,
                skip_output_cropping=False,
                batch_size=self._default_batch_size,
            ):
                output_block = self._serializer.deserialize_sample_block(
                    serialized_output_block
                )
                if output_sample is None:
                    output_sample = Sample.from_blocks(
                        [output_block], fill_value=fill_value
                    )
                else:
                    output_sample.set_block(output_block)

                yield IntermediatePrediction(output_sample, output_block)

        # The first block is fetched eagerly, because the total number of blocks
        # is only known from the metadata of a received output block.
        block_iterator = _predict_blocks()
        try:
            first_intermediate = next(block_iterator)
        except StopIteration:
            # Do not leak a bare StopIteration out of a regular function: callers
            # iterating themselves could mistake it for the end of their own loop.
            raise RuntimeError(
                "The gradio server did not return any output block for the submitted sample."
            ) from None

        def _intermediate_predictions() -> Iterable[IntermediatePrediction]:
            yield first_intermediate
            yield from block_iterator

        return (
            first_intermediate.last_block.blocks_in_sample,
            _intermediate_predictions(),
        )
def _call_predict_api(
    client: Client,
    serialized_model: str,
    sha256: str,
    serialized_input_sample: Iterable[SerializedSampleBlock],
    blocksize: Optional[
        Union[int, Literal["blockwise_as_serialized"], PerMember[PerAxis[int]]]
    ],
    skip_preprocessing: bool,
    skip_postprocessing: bool,
    skip_input_padding: bool,
    skip_output_cropping: bool,
    batch_size: Optional[int],
) -> Iterable[SerializedSampleBlock]:
    """Yield the serialized output blocks of the '/predict' endpoint.

    The job is first submitted with an empty `model` string. If that attempt
    raises before any block was received, or yields no block at all, the job is
    submitted a second time with the full `serialized_model`.

    Args:
        client: Gradio client used to submit the job.
        serialized_model: Serialized model description, only sent on the second attempt.
        sha256: Sent with both attempts as `sha256`.
        serialized_input_sample: Serialized input blocks, sent as `input_sample`.
        blocksize: None, an int or a string are sent as is; a nested mapping is
            sent with all keys of both levels converted to `str`.
        skip_preprocessing: Sent as is.
        skip_postprocessing: Sent as is.
        skip_input_padding: Sent as is.
        skip_output_cropping: Sent as is.
        batch_size: Sent as is.

    Raises:
        Exception: Any error raised after at least one block was received from
            the first attempt, or any error of the second attempt.
    """
    # Build the request payload once; it is identical for both attempts.
    if blocksize is None or isinstance(blocksize, (int, str)):
        blocksize_payload = blocksize
    else:
        blocksize_payload = {
            str(member): {str(axis): size for axis, size in per_axis.items()}
            for member, per_axis in blocksize.items()
        }

    def submit(model: str):
        # NOTE(review): `serialized_input_sample` is submitted again on retry;
        # this assumes it can be consumed more than once - confirm for all callers.
        return client.submit(
            api_name="/predict",
            model=model,
            sha256=sha256,
            input_sample=serialized_input_sample,
            blocksize=blocksize_payload,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            skip_input_padding=skip_input_padding,
            skip_output_cropping=skip_output_cropping,
            batch_size=batch_size,
        )

    try_with_model_upload = True
    try:
        job = submit("")
        for block in job:  # pyright: ignore[reportUnknownVariableType]
            # we got one response, so the model cache was hit...
            # (recorded before yielding, so that an error occurring while the
            # generator is suspended cannot trigger a second, duplicate submission)
            try_with_model_upload = False
            yield block  # pyright: ignore[reportReturnType]
    except Exception as e:
        # A raised exception on the server seems to simply return an empty response sequence,
        # so this except is likely not triggered at all.
        # Below we retry on empty return value, too.
        if not try_with_model_upload:
            # Blocks were already delivered; a retry would duplicate output.
            raise

        logger.warning(
            "Failed to submit job without model upload, trying with model upload, error was: {}",
            e,
        )

    if try_with_model_upload:
        job = submit(serialized_model)
        for block in job:  # pyright: ignore[reportUnknownVariableType]
            yield block  # pyright: ignore[reportReturnType]