Coverage for src/bioimageio/core/_prediction_pipeline.py: 82%
146 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1import warnings
2from types import MappingProxyType
3from typing import (
4 Any,
5 Iterable,
6 List,
7 Literal,
8 Mapping,
9 Optional,
10 Sequence,
11 Tuple,
12 TypeVar,
13 Union,
14)
16from loguru import logger
17from tqdm import tqdm
19from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
21from ._op_base import BlockwiseOperator
22from .axis import AxisId, PerAxis
23from .common import (
24 BlocksizeParameter,
25 Halo,
26 MemberId,
27 PerMember,
28 SampleId,
29 SupportedWeightsFormat,
30)
31from .digest_spec import (
32 get_block_transform,
33 get_input_halo,
34 get_member_ids,
35)
36from .model_adapters import ModelAdapter, create_model_adapter
37from .model_adapters import get_weight_formats as get_weight_formats
38from .proc_setup import Processing, setup_pre_and_postprocessing
39from .sample import Sample, SampleBlock
40from .stat_measures import DatasetMeasure, MeasureValue, Stat
41from .tensor import Tensor
# Constrained TypeVar for prediction entry points: the output kind mirrors the
# input kind (a single `Sample` in -> a single `Sample` out, an iterable in ->
# an iterable out).
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])
class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing.

    Note: Ideally use the `PredictionPipeline` in a with statement
    (as a context manager), so device resources are released via `unload()`.
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Optional[BlocksizeParameter] = None,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Consider using `create_prediction_pipeline` to create a `PredictionPipeline` with sensible defaults.

        Args:
            name: Human readable name of this pipeline.
            model_description: A bioimageio model description (v0.4 or v0.5).
            preprocessing: Preprocessing operators applied before inference.
            postprocessing: Postprocessing operators applied after inference.
            model_adapter: Adapter performing the actual model inference.
            default_ns: Deprecated in favor of `default_blocksize_parameter`.
            default_blocksize_parameter: Default scaling of parameterized input
                block sizes for blockwise prediction.
            default_batch_size: Default batch size for blockwise prediction.
        """
        super().__init__()
        if default_ns is not None:
            warnings.warn(
                "Argument `default_ns` is deprecated in favor of"
                + " `default_blocksize_parameter` and will be removed soon."
            )
            # honor any explicitly passed value (the old `default_ns or ...`
            # silently dropped falsy, non-None values while still warning)
            default_blocksize_parameter = default_ns
        del default_ns

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 descriptions do not specify halos/block transforms
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        """Load the model (see `load`) and return this pipeline."""
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Release device resources; never suppresses exceptions."""
        self.unload()
        return False

    @property
    def has_blockwise_preprocessing(self) -> bool:
        """`True` if all preprocessing operators in the pipeline are blockwise."""
        return all(isinstance(op, BlockwiseOperator) for op in self._preprocessing)

    @property
    def has_blockwise_postprocessing(self) -> bool:
        """`True` if all postprocessing operators in the pipeline are blockwise."""
        return all(isinstance(op, BlockwiseOperator) for op in self._postprocessing)

    def _raise_for_non_blockwise_processing(
        self, proc_type: Literal["preprocessing", "postprocessing"]
    ):
        """Raise `NotImplementedError` naming any non-blockwise operators of **proc_type**."""
        ops = (
            self._preprocessing
            if proc_type == "preprocessing"
            else self._postprocessing
        )
        non_blockwise = [
            op.__class__.__name__ for op in ops if not isinstance(op, BlockwiseOperator)
        ]
        if non_blockwise:
            raise NotImplementedError(
                f"Blockwise {proc_type} for non-blockwise operators {non_blockwise} not implemented."
            )

    def raise_for_non_blockwise_preprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise preprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("preprocessing")

    def raise_for_non_blockwise_postprocessing(self):
        """
        Raises:
            NotImplementedError: if there are any non-blockwise postprocessing operators in the pipeline
        """
        self._raise_for_non_blockwise_processing("postprocessing")

    def predict_sample_block(
        self,
        sample_block: SampleBlock,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block.

        Raises:
            NotImplementedError: for v0.4 models (no block transform available),
                or if a requested processing step contains non-blockwise operators.
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        # validate both processing stages up front, before spending time on inference
        if not skip_preprocessing:
            self.raise_for_non_blockwise_preprocessing()

        if not skip_postprocessing:
            self.raise_for_non_blockwise_postprocessing()

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block)

        output = output_meta.with_data(local_output.members, stat=local_output.stat)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a whole sample

        Note:
            The sample's tensor shapes have to match the model's input tensor description.
            If that is not the case, consider `predict_sample_with_blocking`
        """

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = self._adapter.forward(sample)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated identity function; output sample id equals input sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict `sample` with given `input_block_shape`.

        Note:
            `input_block_shape` is expected to be a valid input shape for the model.
        """
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            # pre-/postprocessing was/will be applied on the whole sample
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """Predict a sample by splitting it into blocks derived from the model's
        (parameterized) input sizes.

        The `ns` parameter allows scaling the model's default input block size.
        """

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        ns = ns or self._default_blocksize_parameter
        if isinstance(ns, int):
            # broadcast a scalar n to every parameterized input axis
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply preprocessing in-place; may also update sample stats"""
        if isinstance(sample, SampleBlock):
            # after this check all ops are `BlockwiseOperator`s and accept blocks
            self.raise_for_non_blockwise_preprocessing()

        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(self, sample: Union[Sample, SampleBlock]) -> None:
        """apply postprocessing in-place; may also update sample stats"""
        if isinstance(sample, SampleBlock):
            # after this check all ops are `BlockwiseOperator`s and accept blocks
            self.raise_for_non_blockwise_postprocessing()

        for op in self._postprocessing:
            op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,  # TODO: default to None and find smart blocksize params per axis to reduce overlap of blocks with large halo
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional) Devices to load the model onto.
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Allows you to specify a mapping of
            `DatasetMeasure`s to precomputed `MeasureValue`s.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.
    """
    # resolve deprecated aliases, warning only when they were actually passed
    # (consistent with how `PredictionPipeline.__init__` handles `default_ns`)
    if weight_format is not None:
        warnings.warn(
            "Argument `weight_format` is deprecated in favor of `weights_format`"
            + " and will be removed soon."
        )
        weights_format = weight_format
    del weight_format
    if ns is not None:
        warnings.warn(
            "Argument `ns` is deprecated in favor of `default_blocksize_parameter`"
            + " and will be removed soon."
        )
        default_blocksize_parameter = ns
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # all generated samples share one `Stat` dict, so statistics
        # accumulate across the initial dataset
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )