Coverage for src/bioimageio/core/prediction.py: 63%
59 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import collections.abc
2from pathlib import Path
3from typing import (
4 Hashable,
5 Iterable,
6 Iterator,
7 Mapping,
8 Optional,
9 Tuple,
10 Union,
11)
13from loguru import logger
14from tqdm import tqdm
16from bioimageio.spec import load_description
17from bioimageio.spec.common import PermissiveFileSource
18from bioimageio.spec.model import v0_4, v0_5
20from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
21from .axis import AxisId
22from .common import BlocksizeParameter, MemberId, PerMember
23from .digest_spec import TensorSource, create_sample_for_model, get_member_id
24from .io import save_sample
25from .sample import Sample
def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[TensorSource], TensorSource],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: the input sample or the named input(s) for this model as a dictionary
        sample_id: the sample id.
            The **sample_id** is used to format **save_output_path**
            and to distinguish sample specific log messages.
            Only used to build the sample if **inputs** is not already a `Sample`.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            **blocksize_parameter** according to any parametrized axis sizes defined
            by the **model**.
            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Tile the input sample tensors into blocks.
            Takes precedence over **blocksize_parameter** if both are given.
            Note: Use **blocksize_parameter** for a parameterized block shape to
            run prediction independent of the exact block shape.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
            May contain:
            - `{sample_id}` to avoid overwriting recurrent calls

    Returns:
        The predicted output sample.

    Raises:
        ValueError: If **model** does not resolve to a model description, or if
            **save_output_path** cannot distinguish the outputs of a model with
            multiple output tensors.
    """
    # Resolve the model description first; only build a pipeline if the caller
    # did not hand us one.
    pp: Optional[PredictionPipeline] = None
    if isinstance(model, PredictionPipeline):
        pp = model
        model_descr = pp.model_descr
    elif isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
        model_descr = model
    else:
        loaded = load_description(model)
        if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise ValueError(f"expected model description, but got {loaded}")
        model_descr = loaded

    # Validate the output path pattern before the pipeline is created/entered,
    # so a bad pattern fails fast and does not trigger pipeline setup/teardown.
    if (
        save_output_path is not None
        and "{output_id}" not in str(save_output_path)
        and "{member_id}" not in str(save_output_path)
        and len(model_descr.outputs) > 1
    ):
        raise ValueError(
            "Missing `{output_id}` (or `{member_id}`) in"
            + f" save_output_path={save_output_path} to "
            + "distinguish model outputs "
            + str([get_member_id(d) for d in model_descr.outputs])
        )

    if pp is None:
        pp = create_prediction_pipeline(
            model_descr,
            fixed_dataset_statistics=inputs.stat if isinstance(inputs, Sample) else {},
        )

    # NOTE(review): `with pp:` is also entered for a caller-supplied pipeline
    # (e.g. once per sample by `predict_many`), so the pipeline's context exit
    # runs after every call -- TODO confirm it stays usable afterwards.
    with pp:
        if isinstance(inputs, Sample):
            sample = inputs
        else:
            sample = create_sample_for_model(
                pp.model_descr, inputs=inputs, sample_id=sample_id
            )

        # Fixed block shape wins over a parameterized one; without either,
        # the whole sample is processed at once.
        if input_block_shape is not None:
            if blocksize_parameter is not None:
                logger.warning(
                    "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                    blocksize_parameter,
                    input_block_shape,
                )

            output = pp.predict_sample_with_fixed_blocking(
                sample,
                input_block_shape=input_block_shape,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
        elif blocksize_parameter is not None:
            output = pp.predict_sample_with_blocking(
                sample,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                ns=blocksize_parameter,
            )
        else:
            output = pp.predict_sample_without_blocking(
                sample,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
            )
        if save_output_path:
            save_sample(save_output_path, output)

        return output
def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    This is a generator: argument validation, resolving the model description
    and creating the prediction pipeline only happen once iteration starts.
    One output sample is yielded per input set.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable of input sets. Each element is passed to `predict`
            as its **inputs** (the named input(s) for this model as a dictionary).
        sample_id: The sample id pattern.
            `{i}` is formatted with the index of the sample.
            If `{i}` (or `{i:`) is not present, `{i:03}` is appended so that
            every sample gets a distinct id.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
        input_block_shape: (optional) Tile the input sample tensors into blocks
            of this fixed shape. Takes precedence over **blocksize_parameter**.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the outputs to.
            Must contain:
            - `{sample_id}` to differentiate predicted samples
            - `{output_id}` (or `{member_id}`) if the model has multiple outputs

    Yields:
        The predicted output sample for each input set, in input order.

    Raises:
        ValueError: If **save_output_path** lacks `{sample_id}`, or if **model**
            does not resolve to a model description.
    """
    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            + " to differentiate predicted samples."
        )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        # NOTE(review): this pipeline is handed to `predict` for every sample and
        # is not released by this function itself -- TODO confirm cleanup.
        pp = create_prediction_pipeline(model)

    # Every yielded sample needs its own id (it is used to format
    # `{sample_id}` in save_output_path), whatever container `inputs` is.
    if "{i}" not in sample_id and "{i:" not in sample_id:
        sample_id += "{i:03}"

    total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

    for i, ipts in tqdm(enumerate(inputs), total=total):
        yield predict(
            model=pp,
            inputs=ipts,
            sample_id=sample_id.format(i=i),
            blocksize_parameter=blocksize_parameter,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            save_output_path=save_output_path,
        )