Coverage for bioimageio/core/_prediction_pipeline.py: 89%
122 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import warnings
2from types import MappingProxyType
3from typing import (
4 Any,
5 Iterable,
6 List,
7 Mapping,
8 Optional,
9 Sequence,
10 Tuple,
11 TypeVar,
12 Union,
13)
15from loguru import logger
16from tqdm import tqdm
18from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
20from ._op_base import BlockedOperator
21from .axis import AxisId, PerAxis
22from .common import (
23 BlocksizeParameter,
24 Halo,
25 MemberId,
26 PerMember,
27 SampleId,
28 SupportedWeightsFormat,
29)
30from .digest_spec import (
31 get_block_transform,
32 get_input_halo,
33 get_member_ids,
34)
35from .model_adapters import ModelAdapter, create_model_adapter
36from .model_adapters import get_weight_formats as get_weight_formats
37from .proc_ops import Processing
38from .proc_setup import setup_pre_and_postprocessing
39from .sample import Sample, SampleBlock, SampleBlockWithOrigin
40from .stat_measures import DatasetMeasure, MeasureValue, Stat
41from .tensor import Tensor
# Type variable constrained to a single `Sample` or an iterable of them.
# NOTE(review): currently only referenced by the commented-out
# `PredictionPipeline.predict` method below — kept for when that API lands.
Predict_IO = TypeVar(
    "Predict_IO",
    Sample,
    Iterable[Sample],
)
class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing

    Note: Ideally use the `PredictionPipeline` in a with statement
    (as a context manager).
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Optional[BlocksizeParameter] = None,
        default_blocksize_parameter: BlocksizeParameter = 10,
        default_batch_size: int = 1,
    ) -> None:
        """Use `create_prediction_pipeline` to create a `PredictionPipeline`"""
        super().__init__()
        # `default_ns` is the deprecated alias; it wins over the new argument
        # if given (preserves behavior for old callers).
        default_blocksize_parameter = default_ns or default_blocksize_parameter
        if default_ns is not None:
            # fix: the warning previously misspelled the replacement argument
            # as `default_blocksize_paramter`
            warnings.warn(
                "Argument `default_ns` is deprecated in favor of"
                + " `default_blocksize_parameter` and will be removed soon."
            )
        del default_ns

        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 models carry no halo/block metadata; blocking is unsupported
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # derive per-input halos from the declared per-output halos
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_blocksize_parameter = default_blocksize_parameter
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False  # do not suppress exceptions

    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """Predict a single sample block (v0.5 models only).

        Args:
            sample_block: input block; its meta is transformed to output space.
            skip_preprocessing: do not apply the preprocessing operations.
            skip_postprocessing: do not apply the postprocessing operations.

        Raises:
            NotImplementedError: for v0.4 model descriptions (no block transform).
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        local_output = self._adapter.forward(sample_block)

        output = output_meta.with_data(local_output.members, stat=local_output.stat)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = self._adapter.forward(sample)
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated identity function; output sample id equals input sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """Predict a sample by splitting it into blocks of the given fixed shape.

        Preprocessing is applied to the whole sample before splitting;
        postprocessing to the reassembled output sample.
        """
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        logger.info(
            "split sample shape {} into {} blocks of {}.",
            {k: dict(v) for k, v in sample.shape.items()},
            n_blocks,
            {k: dict(v) for k, v in input_block_shape.items()},
        )
        for b in tqdm(
            input_blocks,
            desc=f"predict {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            # pre/postprocessing already handled at the whole-sample level
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        ns = ns or self._default_blocksize_parameter
        if isinstance(ns, int):
            # expand a scalar n to every parameterized input axis
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    # def predict(
    #     self,
    #     inputs: Predict_IO,
    #     skip_preprocessing: bool = False,
    #     skip_postprocessing: bool = False,
    # ) -> Predict_IO:
    #     """Run model prediction **including** pre/postprocessing."""

    #     if isinstance(inputs, Sample):
    #         return self.predict_sample_with_blocking(
    #             inputs,
    #             skip_preprocessing=skip_preprocessing,
    #             skip_postprocessing=skip_postprocessing,
    #         )
    #     elif isinstance(inputs, collections.abc.Iterable):
    #         return (
    #             self.predict(
    #                 ipt,
    #                 skip_preprocessing=skip_preprocessing,
    #                 skip_postprocessing=skip_postprocessing,
    #             )
    #             for ipt in inputs
    #         )
    #     else:
    #         assert_never(inputs)

    def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """apply preprocessing in-place, also updates sample stats"""
        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """apply postprocessing in-place, also updates samples stats"""
        for op in self._postprocessing:
            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
                op(sample)
            elif not isinstance(op, BlockedOperator):
                # plain SampleBlock lacks the origin needed for stat updates
                raise NotImplementedError(
                    "block wise update of output statistics not yet implemented"
                )
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional)
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Allows you to specify a mapping of
            `DatasetMeasure`s to precomputed `MeasureValue`s.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.

    """
    # resolve deprecated aliases (deprecated argument wins, matching the
    # handling of `default_ns` in `PredictionPipeline.__init__`) and warn
    if weight_format is not None:
        warnings.warn(
            "Argument `weight_format` is deprecated in favor of"
            + " `weights_format` and will be removed soon."
        )
    weights_format = weight_format or weights_format
    del weight_format
    if ns is not None:
        warnings.warn(
            "Argument `ns` is deprecated in favor of"
            + " `default_blocksize_parameter` and will be removed soon."
        )
    default_blocksize_parameter = ns or default_blocksize_parameter
    del ns
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        # a selected weights format becomes a one-element priority order
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # wrap raw tensor sequences as `Sample`s; all generated samples share
        # one `Stat` dict so statistics accumulate across the dataset
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )