Coverage for bioimageio/core/prediction.py: 61%
57 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import collections.abc
2from pathlib import Path
3from typing import (
4 Hashable,
5 Iterable,
6 Iterator,
7 Mapping,
8 Optional,
9 Tuple,
10 Union,
11)
13from loguru import logger
14from tqdm import tqdm
16from bioimageio.spec import load_description
17from bioimageio.spec.common import PermissiveFileSource
18from bioimageio.spec.model import v0_4, v0_5
20from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
21from .axis import AxisId
22from .common import BlocksizeParameter, MemberId, PerMember
23from .digest_spec import TensorSource, create_sample_for_model, get_member_id
24from .io import save_sample
25from .sample import Sample
def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[TensorSource], TensorSource],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: the input sample or the named input(s) for this model as a dictionary.
        sample_id: the sample id.
            The **sample_id** is used to format **save_output_path**
            and to distinguish sample specific log messages.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            **blocksize_parameter** according to any parametrized axis sizes defined
            by the **model**.
            See `bioimageio.spec.model.v0_5.ParameterizedSize` for details.
            Note: For a predetermined, fixed block shape use **input_block_shape**.
        input_block_shape: (optional) Tile the input sample tensors into blocks.
            Note: Use **blocksize_parameter** for a parameterized block shape to
            run prediction independent of the exact block shape.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
            May contain:
            - `{sample_id}` to avoid overwriting recurrent calls

    Returns:
        The (processed) output sample.

    Raises:
        ValueError: If **model** does not resolve to a model description, or if
            **save_output_path** lacks an `{output_id}`/`{member_id}` placeholder
            while the model has multiple outputs.
    """
    # Resolve `model` to a prediction pipeline and its model description.
    if isinstance(model, PredictionPipeline):
        pp = model
        model = pp.model_description
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    # Fail early: without an output placeholder a multi-output model would
    # overwrite all but one of its saved output tensors.
    if save_output_path is not None:
        if (
            "{output_id}" not in str(save_output_path)
            and "{member_id}" not in str(save_output_path)
            and len(model.outputs) > 1
        ):
            raise ValueError(
                f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
                + "distinguish model outputs "
                + str([get_member_id(d) for d in model.outputs])
            )

    # Accept a ready-made sample as-is; otherwise assemble one from tensor source(s).
    if isinstance(inputs, Sample):
        sample = inputs
    else:
        sample = create_sample_for_model(
            pp.model_description, inputs=inputs, sample_id=sample_id
        )

    if input_block_shape is not None:
        if blocksize_parameter is not None:
            # A fixed block shape takes precedence over the parameterized one.
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pp.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    elif blocksize_parameter is not None:
        output = pp.predict_sample_with_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=blocksize_parameter,
        )
    else:
        output = pp.predict_sample_without_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    # Use the same `is not None` check as the validation above so that an
    # explicitly given path is never silently ignored.
    if save_output_path is not None:
        save_sample(save_output_path, output)

    return output
def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[BlocksizeParameter] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model.

    Args:
        model: Model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable of the named input(s) for this model as a dictionary.
        sample_id: The sample id.
            note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is not a single
            mapping, `{i:03}` is appended.
        blocksize_parameter: (optional) Tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
        skip_preprocessing: Flag to skip the model's preprocessing.
        skip_postprocessing: Flag to skip the model's postprocessing.
        save_output_path: A path to save the output to.
            Must contain:
            - `{sample_id}` to differentiate predicted samples
            - `{output_id}` (or `{member_id}`) if the model has multiple outputs

    Yields:
        The (processed) output sample for each set of inputs.

    Raises:
        ValueError: If **save_output_path** lacks a `{sample_id}` placeholder, or
            if **model** does not resolve to a model description.
    """
    # Validate before touching the model: each predicted sample needs its own
    # output path, otherwise later samples would overwrite earlier ones.
    if save_output_path is not None and "{sample_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            + " to differentiate predicted samples."
        )

    # Resolve `model` to a prediction pipeline once; it is reused for every sample.
    # (Mirrors the resolution logic in `predict`.)
    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if not isinstance(inputs, collections.abc.Mapping):
        # Ensure each yielded sample gets a distinct, enumerated id.
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

        # A known length lets tqdm render a proper progress bar.
        total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

        for i, ipts in tqdm(enumerate(inputs), total=total):
            yield predict(
                model=pp,
                inputs=ipts,
                sample_id=sample_id.format(i=i),
                blocksize_parameter=blocksize_parameter,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                save_output_path=save_output_path,
            )