Coverage for src/bioimageio/core/_prediction_pipeline.py: 83%
226 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import warnings
2from abc import ABC, abstractmethod
3from types import MappingProxyType
4from typing import (
5 Any,
6 Iterable,
7 List,
8 Literal,
9 Mapping,
10 NamedTuple,
11 Optional,
12 Sequence,
13 Tuple,
14 TypeVar,
15 Union,
16)
18from loguru import logger
19from tqdm import tqdm
20from typing_extensions import assert_never
22from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
24from ._model_adapter import ModelAdapter
25from ._op_base import BlockwiseOperator, SamplewiseOperator
26from .axis import AxisId, PerAxis
27from .backends import create_model_adapter
28from .common import (
29 BlocksizeParameter,
30 Halo,
31 MemberId,
32 PerMember,
33 SampleId,
34 SupportedWeightsFormat,
35)
36from .digest_spec import (
37 get_block_transform,
38 get_input_halo,
39 get_member_ids,
40)
41from .proc_ops import Processing
42from .proc_setup import setup_pre_and_postprocessing
43from .sample import Sample, SampleBlock
44from .stat_measures import Measure, MeasureValue, Stat
45from .tensor import Tensor
# Constrained type variable for prediction inputs/outputs: a single `Sample`
# or an iterable of `Sample`s.
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])
class IntermediatePrediction(NamedTuple):
    """Snapshot of a blockwise prediction in progress.

    Holds the sample predicted so far together with the block that was predicted last.
    The last `IntermediatePrediction` of a sequence holds the complete predicted
    (and, if applicable, postprocessed) sample.

    Note:
        Implementations may reuse one `Sample` instance for all yielded items and
        update it as further blocks are predicted. Copy `sample` if you need a snapshot.
    """

    # sample assembled from all blocks predicted so far
    sample: Sample
    # the block that was predicted most recently
    last_block: SampleBlock
class _PredictionPipelineBase(ABC):
    """Common interface and shared blocking logic of (local and remote) prediction pipelines.

    Subclasses implement the actual inference in `predict_sample_without_blocking`,
    `predict_sample_with_fixed_blocking_yield_intermediates` and `predict_sample_block`.
    """

    def __init__(
        self,
        model_descr: AnyModelDescr,
        *,
        default_blocksize_parameter: BlocksizeParameter,
        default_batch_size: int,
    ) -> None:
        """
        Args:
            model_descr: The bioimageio model description to run inference with.
            default_blocksize_parameter: Blocksize parameter used by
                `predict_sample_with_blocking` if no `ns` is given.
            default_batch_size: Batch size used by `predict_sample_with_blocking`
                if no `batch_size` is given.
        """
        super().__init__()
        self._model_descr = model_descr
        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        if isinstance(model_descr, v0_4.ModelDescr):
            # no halos or block transforms are derived for format 0.4 models
            self._default_output_halo: PerMember[PerAxis[Halo]] = {}
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # symmetric halo (left == right == `a.halo`) for every output axis declaring one
            self._default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_descr.outputs
            }
            self._default_input_halo = get_input_halo(
                model_descr, self._default_output_halo
            )
            self._block_transform = get_block_transform(model_descr)

        # padding mode per input member; inputs without (truthy) `pad` use symmetric padding
        self.pad_mode = (
            {}
            if isinstance(model_descr, v0_4.ModelDescr)
            else {
                descr.id: descr.pad or v0_5.SymmetricPadding()
                for descr in model_descr.inputs
            }
        )

    @property
    def model_descr(self) -> AnyModelDescr:
        """The model description this pipeline runs."""
        return self._model_descr

    @property
    def model_description(self) -> AnyModelDescr:
        """Alias of `model_descr`."""
        return self._model_descr

    @abstractmethod
    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """Predict a whole sample at once.

        Note:
            The sample's tensor shapes have to match the model's input tensor description.
            If that is not the case, consider `predict_sample_with_blocking`

        Args:
            sample: input sample
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            skip_input_padding: if `True`, skip padding the input sample according to the model's (optional) output halos.
            skip_output_cropping: if `True`, skip cropping any output halos from the model output.
        """

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """Predict a sample by predicting sample blocks.

        Note: For fixed/known blocksizes use `predict_sample_with_fixed_blocking`.

        Args:
            sample: The sample to predict on.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            ns: Block size parameter(s) allows scaling the model's default input block size.
                Blocksize parameters are only applied to parameterized input axes, all other axis sizes are fixed/derived or (for output axes) data dependent.
                Unapplicable blocksize parameters are ignored.
            batch_size: Batch size to use for prediction.

        Returns:
            The final predicted sample.
        """
        intermediate = None
        # exhaust the iterator; only the last intermediate holds the final sample
        for intermediate in self.predict_sample_with_blocking_yield_intermediates(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=ns,
            batch_size=batch_size,
        )[1]:
            pass

        assert intermediate is not None, (
            "No blocks were predicted, cannot return final sample."
        )
        return intermediate.sample

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: PerMember[PerAxis[int]],
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` with given `input_block_shape`.

        Note:
            - `input_block_shape` is expected to be a valid input shape for the model.
            - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters rather than fixed block shapes.

        Args:
            sample: The sample to predict on.
            input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.

        Returns:
            The final predicted sample.
        """
        intermediate = None
        # exhaust the iterator; only the last intermediate holds the final sample
        for intermediate in self.predict_sample_with_fixed_blocking_yield_intermediates(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )[1]:
            pass

        assert intermediate is not None, (
            "No blocks were predicted, cannot return final sample."
        )
        return intermediate.sample

    def predict_sample_with_blocking_yield_intermediates(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict `sample` by predicting sample blocks and yield intermediate predictions.

        The input block shape is derived from the blocksize parameter(s) `ns` and
        `batch_size`; prediction is then delegated to
        `predict_sample_with_fixed_blocking_yield_intermediates`.

        Args:
            sample: The sample to predict on.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            ns: Blocksize parameter(s), either a single value applied to all
                parameterized input axes or a mapping of (input member id, axis id)
                to value. Defaults to the pipeline's default blocksize parameter.
                Note that `0` is a valid value (smallest parameterized block size).
            batch_size: Batch size to use for prediction.
                Defaults to the pipeline's default batch size.

        Returns:
            Tuple of number of blocks and an iterable of intermediate predictions,
            each holding the sample predicted so far and the last predicted block.
            The last item is expected to hold the final sample. Intermediate samples
            may lack postprocessing steps that can only be applied samplewise; how
            intermediates are produced depends on the implementation.

        Raises:
            NotImplementedError: for format 0.4 model descriptions.
        """
        if isinstance(self._model_descr, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self._model_descr.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # Fall back to the default only if nothing (None or an empty mapping) was given;
        # a plain truthiness check would wrongly discard the valid value `0`.
        if not (isinstance(ns, int) or ns):
            ns = self._default_blocksize_parameter

        if isinstance(ns, int):
            # broadcast a single blocksize parameter to all parameterized input axes
            ns = {
                (ipt.id, a.id): ns
                for ipt in self._model_descr.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        input_block_shape = self._model_descr.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking_yield_intermediates(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    @abstractmethod
    def predict_sample_with_fixed_blocking_yield_intermediates(
        self,
        sample: Sample,
        input_block_shape: PerMember[PerAxis[int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        fill_value: float = float("nan"),
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict `sample` blockwise with the given `input_block_shape`.

        Args:
            sample: The sample to predict on.
            input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            fill_value: Value for regions of the output sample not yet covered by a predicted block.

        Returns:
            Tuple of number of blocks and an iterable of intermediate predictions.
        """

    @abstractmethod
    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Note that this does not apply samplewise preprocessing or postprocessing steps, but only blockwise ones.

        Args:
            sample_block: The sample block to predict on.
            skip_preprocessing: If `True`, skip blockwise preprocessing steps.
            skip_postprocessing: If `True`, skip blockwise postprocessing steps.
        """
class PredictionPipeline(_PredictionPipelineBase):
    """Local model computation including preprocessing and postprocessing.

    Note: Ideally use the `PredictionPipeline` in a with statement
    (as a context manager); entering loads the model adapter, exiting unloads it.
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults.

        Args:
            name: Name of this pipeline.
            model_description: The bioimageio model description.
            preprocessing: Preprocessing operators, applied in order.
            postprocessing: Postprocessing operators, applied in order.
            model_adapter: Adapter running the model's forward pass.
            default_blocksize_parameter: Default blocksize parameter for blockwise prediction.
            default_batch_size: Default batch size for blockwise prediction.
        """
        super().__init__(
            model_descr=model_description,
            default_blocksize_parameter=default_blocksize_parameter,
            default_batch_size=default_batch_size,
        )

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        # Split preprocessing into samplewise and blockwise parts: samplewise preprocessing is
        # everything up to and including the last non-blockwise operator; blockwise preprocessing
        # is the trailing run of blockwise operators. I.e. samplewise preprocessing may contain
        # blockwise ops (if followed by a non-blockwise op at some point).
        self._samplewise_preprocessing: List[
            Union[SamplewiseOperator, BlockwiseOperator]
        ] = []
        self._blockwise_preprocessing: List[BlockwiseOperator] = []
        for op in preprocessing[::-1]:
            if isinstance(op, BlockwiseOperator) and not self._samplewise_preprocessing:
                self._blockwise_preprocessing.insert(0, op)
            else:
                self._samplewise_preprocessing.insert(0, op)

        # Split postprocessing analogously, but starting blockwise and switching to samplewise
        # at the first non-blockwise operator.
        self._blockwise_postprocessing: List[BlockwiseOperator] = []
        self._samplewise_postprocessing: List[
            Union[BlockwiseOperator, SamplewiseOperator]
        ] = []
        for op in postprocessing:
            if (
                isinstance(op, BlockwiseOperator)
                and not self._samplewise_postprocessing
            ):
                self._blockwise_postprocessing.append(op)
            else:
                self._samplewise_postprocessing.append(op)

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter = model_adapter

    def __enter__(self):
        """Load the model adapter and return this pipeline."""
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Unload the model adapter; exceptions are not suppressed."""
        self.unload()
        return False

    @property
    def has_non_blockwise_preprocessing(self) -> bool:
        """`True` if any preprocessing operators in the pipeline are not applicable blockwise."""
        return bool(self._samplewise_preprocessing)

    @property
    def has_non_blockwise_postprocessing(self) -> bool:
        """`True` if any postprocessing operators in the pipeline are not applicable blockwise."""
        return bool(self._samplewise_postprocessing)

    def _raise_for_non_blockwise_processing(
        self, proc_type: Literal["preprocessing", "postprocessing"]
    ):
        """Raise `NotImplementedError` listing all non-blockwise operators of `proc_type`."""
        ops = (
            self._samplewise_preprocessing
            if proc_type == "preprocessing"
            else self._samplewise_postprocessing
        )
        non_blockwise = [
            op.__class__.__name__ for op in ops if not isinstance(op, BlockwiseOperator)
        ]
        if non_blockwise:
            raise NotImplementedError(
                f"Blockwise {proc_type} for {non_blockwise} not implemented."
            )

    def raise_for_non_blockwise_preprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise preprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("preprocessing")

    def raise_for_non_blockwise_postprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise postprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("postprocessing")

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block, applying only blockwise pre-/postprocessing.

        Args:
            sample_block: The sample block to predict on (preprocessed in-place).
            skip_preprocessing: If `True`, skip blockwise preprocessing steps.
            skip_postprocessing: If `True`, skip blockwise postprocessing steps.

        Raises:
            NotImplementedError: for format 0.4 model descriptions.
        """
        if isinstance(self._model_descr, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self._model_descr.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self._apply_blockwise_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block.members)

        # drop outputs the adapter did not produce (`None`)
        output = output_meta.with_data(
            {k: v for k, v in local_output.items() if v is not None},
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self._apply_blockwise_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """Predict a whole sample at once.

        Args:
            sample: input sample (preprocessed in-place unless padded first).
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            skip_input_padding: if `True`, skip padding the input by the model's input halo.
            skip_output_cropping: if `True`, skip cropping the model's output halos.
        """
        if not skip_input_padding:
            sample = sample.pad(pad_width=self._default_input_halo, mode=self.pad_mode)

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = Sample(
            members={
                k: v
                for k, v in self._adapter.forward(sample.members).items()
                if v is not None
            },
            stat=sample.stat,
            id=sample.id,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        if not skip_output_cropping:
            # remove `h.left` leading and `h.right` trailing elements per halo axis
            # (`-0` would be an empty slice, hence `None` for a zero right halo)
            output.members = {
                m: t
                if m not in self._default_output_halo
                else t[
                    {
                        a: slice(h.left, None if h.right == 0 else -h.right)
                        for a, h in self._default_output_halo[m].items()
                    }
                ]
                for m, t in output.members.items()
            }

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated: the output sample id equals the input sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed.",
            stacklevel=2,
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking_yield_intermediates(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        fill_value: float = float("nan"),
    ) -> Tuple[int, Iterable[IntermediatePrediction]]:
        """Predict `sample` with given `input_block_shape` and yield the full sample with intermediate results.

        Note:
            - `input_block_shape` is expected to be a valid input shape for the model.
            - Use `predict_sample_with_blocking` if you want to control block sizes via generic block size parameters
              rather than fixed block shapes.
            - Postprocessing may only be complete for the final sample (if samplewise postprocessing steps are included
              in the pipeline), intermediate samples may have some (blockwise applicable) postprocessing steps applied.
            - Samplewise preprocessing (in-place on `sample`) and splitting into blocks happen immediately;
              the blocks are only predicted while the returned iterable is consumed.
            - All yielded items share the same output `Sample` instance, which is updated with each predicted block.

        Args:
            sample: The sample to predict on.
            input_block_shape: Mapping of input member id to mapping of axis id to block size for that axis.
            skip_preprocessing: If `True`, skip all preprocessing steps.
            skip_postprocessing: If `True`, skip all postprocessing steps.
            fill_value: Value for output regions not yet covered by a predicted block.

        Returns:
            Tuple of number of blocks and an iterable of predicted intermediate samples with the last predicted block,
            All samples, but the last one, are intermediate samples with more and more blocks predicted.
        """

        if not skip_preprocessing:
            self._apply_samplewise_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode=self.pad_mode,
        )
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )

        def _predict_blocks():
            predicted_sample = None
            for i, b in enumerate(
                tqdm(
                    input_blocks,
                    desc=f"predict sample {sample.id or ''} with {self._model_descr.id or self._model_descr.name}",
                    unit="block",
                    unit_divisor=1,
                    total=n_blocks,
                )
            ):
                if not skip_preprocessing:
                    self._apply_blockwise_preprocessing(b)

                predicted_block = self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )

                if not skip_postprocessing:
                    self._apply_blockwise_postprocessing(predicted_block)

                if predicted_sample is None:
                    predicted_sample = Sample.from_blocks(
                        [predicted_block], fill_value=fill_value
                    )
                else:
                    predicted_sample.set_block(predicted_block)

                # samplewise postprocessing needs the complete sample: only after the last block
                if not skip_postprocessing and i == n_blocks - 1:
                    self._apply_samplewise_postprocessing(predicted_sample)

                yield IntermediatePrediction(predicted_sample, predicted_block)

        return n_blocks, _predict_blocks()

    def _apply_samplewise_preprocessing(self, sample: Sample, /) -> None:
        """Apply preprocessing operators up to and including the last samplewise operator in-place.

        Note: This skips all blockwise preprocessing steps after the last samplewise operator.
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_preprocessing()

        for op in self._samplewise_preprocessing:
            op(sample)

    def _apply_blockwise_preprocessing(
        self, sample_block: Union[Sample, SampleBlock], /
    ) -> None:
        """Apply blockwise preprocessing operators in-place.

        Note: This skips all preprocessing operators up to and including the last samplewise one.
        """
        for op in self._blockwise_preprocessing:
            op(sample_block)

    def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply preprocessing in-place; may also update the sample's stats.

        Raises:
            NotImplementedError: if `sample` is a `SampleBlock` and the pipeline contains
                non-blockwise preprocessing operators (the block is left unmodified).
        """
        if isinstance(sample, Sample):
            self._apply_samplewise_preprocessing(sample)
        else:
            self.raise_for_non_blockwise_preprocessing()

        self._apply_blockwise_preprocessing(sample)

    def _apply_blockwise_postprocessing(
        self, sample_block: Union[Sample, SampleBlock], /
    ) -> None:
        """Apply in-place blockwise postprocessing operators

        Note: This does not apply all postprocessing operators from the first samplewise one onwards.
        """
        for op in self._blockwise_postprocessing:
            op(sample_block)

    def _apply_samplewise_postprocessing(self, sample: Sample, /) -> None:
        """Apply in-place postprocessing operators starting from and including the first samplewise operator.

        Note: This skips all blockwise postprocessing steps before the first samplewise one.
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_postprocessing()

        for op in self._samplewise_postprocessing:
            op(sample)

    def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply postprocessing in-place; may also update the sample's stats.

        Raises:
            NotImplementedError: if `sample` is a `SampleBlock` and the pipeline contains
                non-blockwise postprocessing operators (the block is left unmodified).
        """
        is_sample = isinstance(sample, Sample)
        if not is_sample:
            # check before mutating, consistent with `apply_preprocessing`
            self.raise_for_non_blockwise_postprocessing()

        self._apply_blockwise_postprocessing(sample)
        if is_sample:
            self._apply_samplewise_postprocessing(sample)

    def load(self):
        """Prepare prediction pipeline for use.

        Reusable model adapters may be loaded and unloaded multiple times, but currently not all model adapters
        cleanly unload and reload.

        Note:
            For some model adapters loading is currently part of the constructor making them unusable after unloading.
        """
        self._adapter.load()

    def unload(self):
        """Free any device memory in use.

        Note:
            Currently prediction pipeline becomes unusable after unloading."""
        self._adapter.unload()

    def close(self):
        """Permanently close the prediction pipeline and free any device memory in use.
        This makes the prediction pipeline unusable afterwards."""
        self.unload()
class RemotePredictionPipeline(_PredictionPipelineBase):
    """Abstract base for prediction pipelines that run entirely on a remote server.

    Note: A ("local") `PredictionPipeline` may also delegate model inference to a
    `RemoteModelAdapter`, while still running preprocessing and postprocessing locally.
    A `RemotePredictionPipeline`, in contrast, is meant for setups where every step,
    including preprocessing and postprocessing, happens remotely.
    """

    def __init__(
        self,
        model_descr: AnyModelDescr,
        *,
        server: str,
        default_blocksize_parameter: BlocksizeParameter,
        default_batch_size: int,
    ) -> None:
        """
        Args:
            model_descr: The bioimageio model description.
            server: Identifier of the remote server to use.
            default_blocksize_parameter: Default blocksize parameter for blockwise prediction.
            default_batch_size: Default batch size for blockwise prediction.
        """
        base_kwargs = dict(
            default_blocksize_parameter=default_blocksize_parameter,
            default_batch_size=default_batch_size,
        )
        super().__init__(model_descr, **base_kwargs)
        self._server = server

    @property
    def server(self) -> str:
        """The remote server given at construction."""
        return self._server
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    default_batch_size: int = 1,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional)
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Precomputed dataset (and optionally sample) statistics.
            Any included sample statistics will not be calculated on the fly and it is the caller's
            responsibility to use samples with the corresponding statistics available in `sample.stat`.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.
        default_batch_size: Default batch size for blockwise predictions.

    """
    # the deprecated `weight_format` takes precedence if given
    weights_format = weight_format or weights_format
    del weight_format
    # Use the deprecated `ns` if it was given; `0` is a valid blocksize parameter,
    # so only None (or an empty mapping) falls back to `default_blocksize_parameter`.
    if isinstance(ns, int) or ns:
        default_blocksize_parameter = ns
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # samples built from plain tensor sequences all share this one stat dict
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
        default_batch_size=default_batch_size,
    )
def create_remote_prediction_pipeline(
    model_description: AnyModelDescr,
    *,
    server: Optional[str] = None,
    server_type: Optional[Literal["gradio"]] = "gradio",
    precomputed_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    default_batch_size: int = 1,
) -> RemotePredictionPipeline:
    """Build a `RemotePredictionPipeline` that runs `model_description` on a remote server.

    Args:
        model_description: The model to run inference with.
        server: URL or Hugging Face space name of a running bioimageio server instance.
        server_type: Kind of remote server; `None` is treated as "gradio", currently the only option.
        precomputed_statistics: Precomputed dataset (and optionally sample) statistics.
            Sample statistics included here are not computed on the fly; the caller must
            then provide samples whose `sample.stat` contains them.
        default_blocksize_parameter: Single parameter controlling the default block size
            for blockwise predictions (not supported by all models).
        default_batch_size: Default batch size to use.

    Raises:
        ImportError: if the client extra for `server_type` is not installed.
    """
    server_type = "gradio" if server_type is None else server_type

    if server_type != "gradio":
        assert_never(server_type)

    # keep the `try` minimal: only the optional client import may fail here
    try:
        from .remote_backends.gradio.client import (
            GradioPredictionPipeline as RemotePredictionPipelineImpl,
        )
    except ImportError as e:
        raise ImportError(
            f"Failed to import {server_type.capitalize()}PredictionPipeline. Make sure to install the '{server_type}-client' extra,"
            + f" e.g. with `pip install bioimageio.core[{server_type}-client]`."
        ) from e

    pipeline = RemotePredictionPipelineImpl(
        model_description,
        server=server,
        precomputed_statistics=precomputed_statistics,
        default_blocksize_parameter=default_blocksize_parameter,
        default_batch_size=default_batch_size,
    )
    return pipeline