Coverage for src/bioimageio/core/prediction.py: 61%
57 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import collections.abc
2from pathlib import Path
3from typing import (
4 Hashable,
5 Iterable,
6 Iterator,
7 Mapping,
8 Optional,
9 Tuple,
10 Union,
11)
13from loguru import logger
14from tqdm import tqdm
16from bioimageio.spec import load_description
17from bioimageio.spec.common import PermissiveFileSource
18from bioimageio.spec.model import v0_4, v0_5
20from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
21from .axis import AxisId
22from .common import BlocksizeParameter, MemberId, PerMember
23from .digest_spec import TensorSource, create_sample_for_model, get_member_id
24from .io import save_sample
25from .sample import Sample
def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[TensorSource], TensorSource],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
            A given `PredictionPipeline` is used as is; otherwise a new pipeline
            is created for this call. If **inputs** is a `Sample`, its `stat`
            is passed to that new pipeline as fixed dataset statistics.
        inputs: the input sample or the named input(s) for this model as a dictionary
        sample_id: the sample id.
            The **sample_id** is used to format **save_output_path**
            and to distinguish sample specific log messages.
            It is only used to build the sample from **inputs**, i.e. it is
            ignored if **inputs** is already a `Sample`.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            **blocksize_parameter** according to any parametrized axis sizes defined
            by the **model**.
            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Tile the input sample tensors into blocks.
            Takes precedence over **blocksize_parameter** (a warning is logged
            if both are given).
            Note: Use **blocksize_parameter** for a parameterized block shape to
            run prediction independent of the exact block shape.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
            May contain:
            - `{sample_id}` to avoid overwriting recurrent calls

    Returns:
        The predicted output sample.

    Raises:
        ValueError: If **model** does not resolve to a model description, or if
            **save_output_path** cannot distinguish multiple model outputs.
    """
    if isinstance(model, PredictionPipeline):
        pp = model
        model = pp.model_description
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        # pass along any statistics a given `Sample` already carries
        pp = create_prediction_pipeline(
            model,
            fixed_dataset_statistics=inputs.stat if isinstance(inputs, Sample) else {},
        )

    if save_output_path is not None:
        path_str = str(save_output_path)
        # without an output placeholder, multiple outputs would overwrite each other
        if (
            "{output_id}" not in path_str
            and "{member_id}" not in path_str
            and len(model.outputs) > 1
        ):
            raise ValueError(
                f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
                + "distinguish model outputs "
                + str([get_member_id(d) for d in model.outputs])
            )

    if isinstance(inputs, Sample):
        sample = inputs
    else:
        sample = create_sample_for_model(
            pp.model_description, inputs=inputs, sample_id=sample_id
        )

    if input_block_shape is not None:
        if blocksize_parameter is not None:
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pp.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    elif blocksize_parameter is not None:
        output = pp.predict_sample_with_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=blocksize_parameter,
        )
    else:
        output = pp.predict_sample_without_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    # Same condition as the validation above: any path that was validated
    # as a save request is actually used for saving.
    if save_output_path is not None:
        save_sample(save_output_path, output)

    return output
def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    Argument validation and model loading happen when this function is called.
    The predictions themselves are computed lazily, one sample per step of the
    returned iterator. Per-sample errors, such as the `{output_id}` check done
    by `predict`, surface during iteration.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
            One prediction pipeline is created (or the given one reused) for all
            samples.
        inputs: An iterable of the named input(s) for this model as a dictionary.
            A single mapping (instead of an iterable of mappings) is treated as
            the named inputs of one sample.
        sample_id: The sample id.
            note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is not a mapping,
            `{i:03}` is appended.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
        input_block_shape: (optional) Tile the input sample tensors into blocks of
            this fixed shape. Takes precedence over **blocksize_parameter**
            (see `predict`).
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{sample_id}` to differentiate predicted samples
            - `{output_id}` (or `{member_id}`) if the model has multiple outputs

    Returns:
        An iterator yielding one predicted `Sample` per set of inputs.

    Raises:
        ValueError: At call time, if **save_output_path** lacks `{sample_id}` or
            if **model** does not resolve to a model description.
    """
    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            + " to differentiate predicted samples."
        )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    sample_inputs: Iterable[Union[PerMember[TensorSource], TensorSource]]
    if isinstance(inputs, collections.abc.Mapping):
        # Iterating a mapping would yield its keys, not sets of inputs;
        # treat it as the named inputs of a single sample instead.
        sample_inputs = [inputs]
    else:
        sample_inputs = inputs
        # ensure each sample gets a distinct id
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

    # The per-sample loop lives in a separate generator, so the checks above
    # run immediately instead of being deferred to the first `next()` call.
    return _predict_many_iter(
        pp=pp,
        inputs=sample_inputs,
        sample_id=sample_id,
        blocksize_parameter=blocksize_parameter,
        input_block_shape=input_block_shape,
        skip_preprocessing=skip_preprocessing,
        skip_postprocessing=skip_postprocessing,
        save_output_path=save_output_path,
    )


def _predict_many_iter(
    *,
    pp: PredictionPipeline,
    inputs: Iterable[Union[PerMember[TensorSource], TensorSource]],
    sample_id: str,
    blocksize_parameter: Optional[BlocksizeParameter],
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]],
    skip_preprocessing: bool,
    skip_postprocessing: bool,
    save_output_path: Optional[Union[Path, str]],
) -> Iterator[Sample]:
    """Lazily yield `predict` results for each entry of **inputs**.

    **sample_id** is formatted with the running index `i` for every sample.
    """
    # only give the progress bar a total if the number of inputs is known
    total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None
    for i, ipts in tqdm(enumerate(inputs), total=total):
        yield predict(
            model=pp,
            inputs=ipts,
            sample_id=sample_id.format(i=i),
            blocksize_parameter=blocksize_parameter,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            save_output_path=save_output_path,
        )