Coverage for bioimageio/core/_prediction_pipeline.py: 89%
114 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import warnings
2from types import MappingProxyType
3from typing import (
4 Any,
5 Iterable,
6 List,
7 Mapping,
8 Optional,
9 Sequence,
10 Tuple,
11 TypeVar,
12 Union,
13)
15from tqdm import tqdm
17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
18from bioimageio.spec.model.v0_5 import WeightsFormat
20from ._op_base import BlockedOperator
21from .axis import AxisId, PerAxis
22from .common import Halo, MemberId, PerMember, SampleId
23from .digest_spec import (
24 get_block_transform,
25 get_input_halo,
26 get_member_ids,
27)
28from .model_adapters import ModelAdapter, create_model_adapter
29from .model_adapters import get_weight_formats as get_weight_formats
30from .proc_ops import Processing
31from .proc_setup import setup_pre_and_postprocessing
32from .sample import Sample, SampleBlock, SampleBlockWithOrigin
33from .stat_measures import DatasetMeasure, MeasureValue, Stat
34from .tensor import Tensor
# Constrained type variable: prediction helpers may operate on a single
# Sample or on an iterable of Samples and return the matching kind.
Predict_IO = TypeVar("Predict_IO", Sample, Iterable[Sample])
class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing
    Note: Ideally use the PredictionPipeline as a context manager
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ] = 10,
        default_batch_size: int = 1,
    ) -> None:
        """
        Args:
            name: pipeline name (typically the model name).
            model_description: bioimageio model description (v0.4 or v0.5).
            preprocessing: operators applied in order to inputs before the
                forward pass; they also update sample statistics in-place.
            postprocessing: operators applied in order to outputs after the
                forward pass.
            model_adapter: adapter wrapping one of the model's weight formats.
            default_ns: default parameter `n` (global, or per (tensor, axis))
                used to derive block shapes for parameterized input axes.
            default_batch_size: batch size used when none is given explicitly.
        """
        super().__init__()
        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # v0.4 model descriptions do not specify halos or block
            # transforms, so block-wise prediction is unavailable for them.
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_ns = default_ns
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False

    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """predict a single sample block.

        Only available for v0.5 models (v0.4 models lack block transforms).
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        output = output_meta.with_data(
            {
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample_block.members.get(t) for t in self._input_ids)
                    ),
                )
                # the adapter may return None for outputs it cannot produce
                if out is not None
            },
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = Sample(
            members={
                out_id: out
                for out_id, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample.members.get(in_id) for in_id in self._input_ids)
                    ),
                )
                # the adapter may return None for outputs it cannot produce
                if out is not None
            },
            stat=sample.stat,
            id=sample.id,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """Deprecated identity mapping from input to output sample id."""
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample by splitting it into blocks of the given shape.

        Args:
            sample: input sample; padded with "reflect" mode where needed.
            input_block_shape: exact per-tensor, per-axis block sizes.
            skip_preprocessing: do not run the preprocessing operators.
            skip_postprocessing: do not run the postprocessing operators.
        """
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            predicted_blocks.append(
                # pre-/postprocessing is handled for the whole sample,
                # not per block
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        # explicit None checks: `ns or ...`/`batch_size or ...` would silently
        # discard a caller-supplied `ns=0` or empty mapping
        if ns is None:
            ns = self._default_ns

        if batch_size is None:
            batch_size = self._default_batch_size

        if isinstance(ns, int):
            # expand a global n to all parameterized input axes
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """apply preprocessing in-place, also updates sample stats"""
        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """apply postprocessing in-place, also updates samples stats"""
        for op in self._postprocessing:
            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
                op(sample)
            elif not isinstance(op, BlockedOperator):
                raise NotImplementedError(
                    "block wise update of output statistics not yet implemented"
                )
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[WeightsFormat] = None,
    weights_format: Optional[WeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Union[
        v0_5.ParameterizedSize_N,
        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
    ] = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: model description to create the pipeline for.
        devices: devices to load the model weights onto (adapter-specific).
        weight_format: deprecated alias for `weights_format`.
        weights_format: weight format to restrict the adapter choice to.
        dataset_for_initial_statistics: samples (or raw tensor tuples) used to
            compute initial dataset statistics for pre-/postprocessing.
        keep_updating_initial_dataset_statistics: keep updating dataset
            statistics with each processed sample.
        fixed_dataset_statistics: known dataset measures to use as-is.
        model_adapter: use this adapter instead of creating one.
        ns: default parameter `n` for parameterized input axis sizes.
        deprecated_kwargs: ignored; a warning is emitted for each.
    """
    if weight_format is not None:
        if weights_format is not None:
            # previously `weight_format` silently took precedence; keep that
            # precedence for backward compatibility, but say so
            warnings.warn(
                "got both 'weight_format' and 'weights_format';"
                + " using deprecated 'weight_format'"
            )
        weights_format = weight_format
    del weight_format
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # shared mutable stat dict: all yielded samples accumulate into the
        # same statistics object
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_ns=ns,
    )