Coverage for src / bioimageio / core / _prediction_pipeline.py: 84%
171 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import warnings
2from types import MappingProxyType
3from typing import (
4 Any,
5 Iterable,
6 List,
7 Literal,
8 Mapping,
9 Optional,
10 Sequence,
11 Tuple,
12 TypeVar,
13 Union,
14)
16from loguru import logger
17from tqdm import tqdm
19from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
21from ._op_base import BlockwiseOperator, SamplewiseOperator
22from .axis import AxisId, PerAxis
23from .common import (
24 BlocksizeParameter,
25 Halo,
26 MemberId,
27 PerMember,
28 SampleId,
29 SupportedWeightsFormat,
30)
31from .digest_spec import (
32 get_block_transform,
33 get_input_halo,
34 get_member_ids,
35)
36from .model_adapters import ModelAdapter, create_model_adapter
37from .model_adapters import get_weight_formats as get_weight_formats
38from .proc_ops import Processing
39from .proc_setup import setup_pre_and_postprocessing
40from .sample import Sample, SampleBlock
41from .stat_measures import Measure, MeasureValue, Stat
42from .tensor import Tensor
# Constrained type variable for prediction input/output: either a single
# `Sample` or an iterable of `Sample`s.
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])
class PredictionPipeline:
    """Model inference together with its preprocessing and postprocessing.

    Preprocessing operators are split into a samplewise part (everything up to
    and including the last operator that is not a `BlockwiseOperator`) and the
    remaining trailing blockwise part. Postprocessing is split analogously into
    a leading blockwise part and a samplewise remainder that starts at the
    first operator that is not a `BlockwiseOperator`.

    Note:
        Ideally use the `PredictionPipeline` in a `with` statement (as a
        context manager): entering calls `load()`, exiting calls `unload()`.
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Optional[BlocksizeParameter] = None,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults.

        Args:
            name: name of this pipeline.
            model_description: the bioimageio model description to run.
            preprocessing: preprocessing operators in order of application.
            postprocessing: postprocessing operators in order of application.
            model_adapter: adapter that runs the model's forward pass.
            default_ns: deprecated alias of **default_blocksize_parameter**;
                takes precedence over it when given (not `None`).
            default_blocksize_parameter: block size parameter used by
                `predict_sample_with_blocking` when none is passed.
            default_batch_size: batch size used by
                `predict_sample_with_blocking` when none is passed.
        """
        super().__init__()
        if default_ns is not None:
            warnings.warn(
                "Argument `default_ns` is deprecated in favor of"
                + " `default_blocksize_parameter` and will be removed soon.",
                stacklevel=2,
            )
            # Honor the deprecated alias for any explicitly given value,
            # including falsy ones such as `0`.
            default_blocksize_parameter = default_ns

        del default_ns

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name

        # Split preprocessing into samplewise and blockwise parts. Samplewise
        # preprocessing is all preprocessing up to and including the last
        # operator that is not a `BlockwiseOperator`; blockwise preprocessing
        # are the remaining (trailing) blockwise operators. I.e. some samplewise
        # preprocessing may be a blockwise op (at some point followed by a
        # non-blockwise op).
        self._samplewise_preprocessing: List[
            Union[SamplewiseOperator, BlockwiseOperator]
        ] = []
        self._blockwise_preprocessing: List[BlockwiseOperator] = []
        for op in preprocessing[::-1]:
            if isinstance(op, BlockwiseOperator) and not self._samplewise_preprocessing:
                self._blockwise_preprocessing.insert(0, op)
            else:
                self._samplewise_preprocessing.insert(0, op)

        # Split postprocessing analogously, but here we start blockwise and
        # switch to samplewise at the first non-blockwise operator.
        self._blockwise_postprocessing: List[BlockwiseOperator] = []
        self._samplewise_postprocessing: List[
            Union[BlockwiseOperator, SamplewiseOperator]
        ] = []
        for op in postprocessing:
            if (
                isinstance(op, BlockwiseOperator)
                and not self._samplewise_postprocessing
            ):
                self._blockwise_postprocessing.append(op)
            else:
                self._samplewise_postprocessing.append(op)

        # Per-input pad mode; v0_4 descriptions specify none, v0_5 inputs
        # default to symmetric padding.
        self.pad_mode = (
            {}
            if isinstance(model_description, v0_4.ModelDescr)
            else {
                descr.id: descr.pad or v0_5.SymmetricPadding()
                for descr in model_description.inputs
            }
        )
        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # halos and block transforms are only supported for v0_5 descriptions
            self._default_output_halo: PerMember[PerAxis[Halo]] = {}
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            self._default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, self._default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False  # do not suppress exceptions

    @property
    def has_blockwise_preprocessing(self) -> bool:
        """`True` if there are trailing blockwise preprocessing operators.

        These are the `BlockwiseOperator`s after the last preprocessing
        operator that is not blockwise. This does *not* imply that all
        preprocessing is blockwise (see `raise_for_non_blockwise_preprocessing`).
        """
        return bool(self._blockwise_preprocessing)

    @property
    def has_blockwise_postprocessing(self) -> bool:
        """`True` if there are leading blockwise postprocessing operators.

        These are the `BlockwiseOperator`s before the first postprocessing
        operator that is not blockwise. This does *not* imply that all
        postprocessing is blockwise (see `raise_for_non_blockwise_postprocessing`).
        """
        return bool(self._blockwise_postprocessing)

    def _raise_for_non_blockwise_processing(
        self, proc_type: Literal["preprocessing", "postprocessing"]
    ):
        """Raise `NotImplementedError` if the samplewise part of `proc_type`
        contains any operator that is not a `BlockwiseOperator`."""
        ops = (
            self._samplewise_preprocessing
            if proc_type == "preprocessing"
            else self._samplewise_postprocessing
        )
        non_blockwise = [
            op.__class__.__name__ for op in ops if not isinstance(op, BlockwiseOperator)
        ]
        if non_blockwise:
            raise NotImplementedError(
                f"Blockwise {proc_type} for {non_blockwise} not implemented."
            )

    def raise_for_non_blockwise_preprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise preprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("preprocessing")

    def raise_for_non_blockwise_postprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise postprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("postprocessing")

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Args:
            sample_block: input block; preprocessing is applied to it in-place.
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.

        Returns:
            The predicted output block.

        Raises:
            NotImplementedError: for v0_4 model descriptions, or if the
                (non-skipped) pre-/postprocessing contains non-blockwise operators.
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        # check both pre- and postprocessing before doing any work
        if not skip_preprocessing:
            self.raise_for_non_blockwise_preprocessing()

        if not skip_postprocessing:
            self.raise_for_non_blockwise_postprocessing()

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block)

        output = output_meta.with_data(local_output.members, stat=local_output.stat)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        skip_input_padding: bool = False,
        skip_output_cropping: bool = False,
    ) -> Sample:
        """predict a whole sample

        Args:
            sample: input sample
            skip_preprocessing: if `True`, skip all preprocessing steps.
            skip_postprocessing: if `True`, skip all postprocessing steps.
            skip_input_padding: if `True`, skip padding the input sample by the
                default input halo (derived from the model's optional output halos).
            skip_output_cropping: if `True`, skip cropping any output halos from the model output.
        Note:
            The sample's tensor shapes have to match the model's input tensor description.
            If that is not the case, consider `predict_sample_with_blocking`
        """

        if not skip_input_padding:
            sample = sample.pad(pad_width=self._default_input_halo, mode=self.pad_mode)

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = self._adapter.forward(sample)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        if not skip_output_cropping:
            output.members = {
                m: t
                if m not in self._default_output_halo
                else t[
                    {
                        # a stop of `-0` would yield an empty slice, hence `None`
                        a: slice(h.left, None if h.right == 0 else -h.right)
                        for a, h in self._default_output_halo[m].items()
                    }
                ]
                for m, t in output.members.items()
            }

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated: returns `input_sample_id` unchanged."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed.",
            stacklevel=2,
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` with given `input_block_shape`.

        Samplewise preprocessing is applied to `sample` itself, then the sample
        is split into blocks (with the default input halo), each block gets the
        blockwise preprocessing, is predicted, and gets the blockwise
        postprocessing. The predicted blocks are reassembled into one sample,
        to which the samplewise postprocessing is applied.

        Note:
            `input_block_shape` is expected to be a valid input shape for the model.
        """
        if not skip_preprocessing:
            for op in self._samplewise_preprocessing:
                op(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode=self.pad_mode,
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            if not skip_preprocessing:
                for op in self._blockwise_preprocessing:
                    op(b)

            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )
            if not skip_postprocessing:
                for op in self._blockwise_postprocessing:
                    op(predicted_blocks[-1])

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            for op in self._samplewise_postprocessing:
                op(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """Predict a sample by splitting it into blocks according to the model.

        The `ns` parameter allows scaling the model's default input block size.
        A single integer applies to all parameterized input axes; a mapping
        specifies it per (input id, axis id). If `ns` is `None` (or an empty
        mapping), the pipeline's default block size parameter is used; an
        explicit `0` is honored.

        Raises:
            NotImplementedError: for v0_4 model descriptions.
        """

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # `ns or default` would silently replace an explicit `0`
        # (the smallest valid block size parameter) with the default.
        if ns is None or (not isinstance(ns, int) and not ns):
            ns = self._default_blocksize_parameter

        if isinstance(ns, int):
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply preprocessing in-place, also may update sample stats

        Raises:
            NotImplementedError: if `sample` is a `SampleBlock` and there are
                non-blockwise preprocessing operators.
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_preprocessing()

        for op in self._samplewise_preprocessing + self._blockwise_preprocessing:
            if isinstance(sample, SampleBlock):
                assert isinstance(op, BlockwiseOperator)
                op(sample)
            else:
                op(sample)

    def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply postprocessing in-place, also may update sample stats

        Raises:
            NotImplementedError: if `sample` is a `SampleBlock` and there are
                non-blockwise postprocessing operators.
        """
        if isinstance(sample, SampleBlock):
            self.raise_for_non_blockwise_postprocessing()

        for op in self._blockwise_postprocessing + self._samplewise_postprocessing:
            if isinstance(sample, SampleBlock):
                assert isinstance(op, BlockwiseOperator)
                op(sample)
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager

        Currently a no-op.
        """
        pass

    def unload(self):
        """
        free any device memory in use (delegates to the model adapter's `unload`)
        """
        self._adapter.unload()
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[Measure, MeasureValue] = MappingProxyType({}),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional)
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
            Items may be `Sample`s or sequences of tensors given in the order of
            the model's inputs.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Precomputed dataset (and optionally sample) statistics.
            Any included sample statistics will not be calculated on the fly and it is the caller's
            responsibility to use samples with the corresponding statistics available in `sample.stat`.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**; takes
            precedence over it when given (not `None`, an explicit `0` is honored).
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.

    Returns:
        A `PredictionPipeline` for **bioimageio_model**.
    """
    weights_format = weight_format or weights_format
    del weight_format
    # `ns or default` would silently discard an explicit `ns=0`
    default_blocksize_parameter = (
        ns if ns is not None else default_blocksize_parameter
    )
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        """Yield `Sample`s, wrapping tensor sequences by zipping them with the input ids."""
        # NOTE(review): this one `stat` dict is shared by all samples built
        # from tensor sequences — presumably intentional; confirm against
        # `setup_pre_and_postprocessing`.
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )