Coverage for bioimageio/core/prediction.py: 59%
63 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import collections.abc
2from pathlib import Path
3from typing import (
4 Any,
5 Hashable,
6 Iterable,
7 Iterator,
8 Mapping,
9 Optional,
10 Tuple,
11 Union,
12)
14import xarray as xr
15from loguru import logger
16from numpy.typing import NDArray
17from tqdm import tqdm
19from bioimageio.spec import load_description
20from bioimageio.spec.common import PermissiveFileSource
21from bioimageio.spec.model import v0_4, v0_5
23from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
24from .axis import AxisId
25from .common import MemberId, PerMember
26from .digest_spec import create_sample_for_model
27from .io import save_sample
28from .sample import Sample
29from .tensor import Tensor
def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: The input sample or the named input(s) for this model as a dictionary.
        sample_id: The sample id.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
            Note: For a predetermined, fixed block shape use `input_block_shape`.
        input_block_shape: (optional) Tile the input sample tensors into blocks.
            Note: For a parameterized block shape, not dealing with the exact block
            shape, use `blocksize_parameter`.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path with `{member_id}` `{sample_id}` in it
            to save the output to.

    Returns:
        The output sample.

    Raises:
        ValueError: If `save_output_path` is given without a `{member_id}`
            placeholder, or if `model` does not resolve to a model description.
    """
    if save_output_path is not None:
        if "{member_id}" not in str(save_output_path):
            # Double braces escape the placeholder so the message shows the
            # literal `{member_id}`. (A single-brace `{member_id}` here would
            # reference an undefined name and raise NameError instead of the
            # intended ValueError.)
            raise ValueError(
                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
            )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        # Resolve an RDF source to a model description if needed, then build
        # a prediction pipeline from it.
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if isinstance(inputs, Sample):
        sample = inputs
    else:
        sample = create_sample_for_model(
            pp.model_description, inputs=inputs, sample_id=sample_id
        )

    if input_block_shape is not None:
        # A fixed block shape takes precedence over the parameterized one.
        if blocksize_parameter is not None:
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pp.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    elif blocksize_parameter is not None:
        output = pp.predict_sample_with_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=blocksize_parameter,
        )
    else:
        output = pp.predict_sample_without_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    if save_output_path:
        save_sample(save_output_path, output)

    return output
def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Iterable[PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable of the named input(s) for this model as a dictionary.
        sample_id: The sample id.
            Note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is an iterable,
            `{i:03}` is appended.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the
            model RDF.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path with `{member_id}` `{sample_id}` in it
            to save the output to.

    Yields:
        The output sample for each input set.

    Raises:
        ValueError: If `save_output_path` is given without the required
            `{member_id}` (and, for non-mapping `inputs`, `{sample_id}`)
            placeholders, or if `model` does not resolve to a model description.
    """
    if save_output_path is not None:
        if "{member_id}" not in str(save_output_path):
            # Double braces keep the literal `{member_id}` placeholder in the
            # message. (A single-brace form would reference an undefined name
            # and raise NameError instead of the intended ValueError.)
            raise ValueError(
                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
            )

        if not isinstance(inputs, collections.abc.Mapping) and "{sample_id}" not in str(
            save_output_path
        ):
            # Escaped braces: show the literal `{sample_id}` placeholder
            # instead of interpolating the current `sample_id` value.
            raise ValueError(
                f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        # Resolve an RDF source to a model description if needed, then build
        # a prediction pipeline once and reuse it for every sample.
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if not isinstance(inputs, collections.abc.Mapping):
        # Ensure each sample gets a distinct id by appending an index
        # placeholder when none was provided.
        sample_id = str(sample_id)
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

        # tqdm can only show a total for sized iterables.
        total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

        for i, ipts in tqdm(enumerate(inputs), total=total):
            yield predict(
                model=pp,
                inputs=ipts,
                sample_id=sample_id.format(i=i),
                blocksize_parameter=blocksize_parameter,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                save_output_path=save_output_path,
            )