Coverage for src / bioimageio / core / proc_setup.py: 95%
83 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
1from itertools import chain
2from typing import (
3 Callable,
4 Iterable,
5 List,
6 Mapping,
7 NamedTuple,
8 Optional,
9 Sequence,
10 Set,
11 Union,
12)
14from typing_extensions import assert_never
16from bioimageio.core.digest_spec import get_member_id
17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
# `CustomPostprocessingDescr` was added after spec 0.5.9.1. Look it up
# dynamically so that this module still imports with older installs;
# the name is bound to `None` when the attribute is not available.
_CustomPostprocessingDescr = getattr(v0_5, "CustomPostprocessingDescr", None)
22from .proc_ops import (
23 AddKnownDatasetStats,
24 CustomPostprocessing,
25 EnsureDtype,
26 Processing,
27 UpdateStats,
28 get_proc,
29)
30from .sample import Sample
31from .stat_calculators import StatsCalculator
32from .stat_measures import (
33 DatasetMeasure,
34 DatasetMeasureBase,
35 Measure,
36 MeasureValue,
37 SampleMeasure,
38 SampleMeasureBase,
39)
# Any input or output tensor description of the 0.4 or 0.5 model format.
TensorDescr = Union[
    # model format 0.4
    v0_4.InputTensorDescr,
    v0_4.OutputTensorDescr,
    # model format 0.5
    v0_5.InputTensorDescr,
    v0_5.OutputTensorDescr,
]
class PreAndPostprocessing(NamedTuple):
    """Pair of operator lists: preprocessing and postprocessing."""

    # preprocessing operators, in the order they are to be applied
    pre: List[Processing]
    # postprocessing operators, in the order they are to be applied
    post: List[Processing]
class _ProcessingCallables(NamedTuple):
    """Pair of callables applying pre-/postprocessing in-place to a sample."""

    # applies the preprocessing to the given sample; returns nothing
    pre: Callable[[Sample], None]
    # applies the postprocessing to the given sample; returns nothing
    post: Callable[[Sample], None]
class _ApplyProcs:
    """Callable that applies a sequence of processing operators to a sample."""

    def __init__(self, procs: Sequence[Processing]):
        """Remember `procs`; the sequence is stored as given, not copied."""
        super().__init__()
        self._procs = procs

    def __call__(self, sample: Sample) -> None:
        """Call every stored operator, in order, with `sample`."""
        for proc in self._procs:
            proc(sample)
def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Creates callables to apply pre- and postprocessing in-place to a sample

    All arguments are forwarded unchanged to `setup_pre_and_postprocessing`;
    the two operator lists it returns are each wrapped in an `_ApplyProcs`.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    apply_pre = _ApplyProcs(pre_ops)
    apply_post = _ApplyProcs(post_ops)
    return _ProcessingCallables(pre=apply_pre, post=apply_post)
def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: Model description whose `inputs` and `outputs` describe the
            processing steps.
        dataset_for_initial_statistics: Samples used to compute initial values
            for required measures that are not in `fixed_dataset_stats`.
            Only iterated if at least one such measure exists.
        keep_updating_initial_dataset_stats: Passed on to `UpdateStats`.
        fixed_dataset_stats: Known measure values; measures found here are not
            computed from `dataset_for_initial_statistics`.

    Returns:
        The preprocessing and postprocessing operator lists.
    """
    pre_ops = _get_described_procs(model.inputs)
    post_ops = _get_described_procs(model.outputs)

    # every measure any of the described operators asks for
    required = set(
        chain.from_iterable(op.required_measures for op in chain(pre_ops, post_ops))
    )

    # an absent mapping of fixed stats means that nothing is known up front
    known = {} if fixed_dataset_stats is None else fixed_dataset_stats
    missing_dataset_stats = {m for m in required if m not in known}

    if missing_dataset_stats:
        initial_calc = StatsCalculator(missing_dataset_stats)
        for sample in dataset_for_initial_statistics:
            initial_calc.update(sample)

        stats_updater = UpdateStats(
            StatsCalculator(required, initial_calc.finalize()),
            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        )
        # statistics have to be available before any described operator runs
        pre_ops.insert(0, stats_updater)

    if fixed_dataset_stats:
        # inserted last so that the known stats end up as the very first step
        pre_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    return PreAndPostprocessing(pre_ops, post_ops)
class RequiredMeasures(NamedTuple):
    """Measures required by preprocessing (`pre`) and postprocessing (`post`)."""

    # measures required by the preprocessing operators
    pre: Set[Measure]
    # measures required by the postprocessing operators
    post: Set[Measure]
class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required by pre- (`pre`) and postprocessing (`post`)."""

    # dataset measures required by the preprocessing operators
    pre: Set[DatasetMeasure]
    # dataset measures required by the postprocessing operators
    post: Set[DatasetMeasure]
class RequiredSampleMeasures(NamedTuple):
    """Sample measures required by pre- (`pre`) and postprocessing (`post`)."""

    # sample measures required by the preprocessing operators
    pre: Set[SampleMeasure]
    # sample measures required by the postprocessing operators
    post: Set[SampleMeasure]
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Collect the measures required by the processing described in `model`.

    The operators described for `model.inputs` contribute to `pre`, those
    described for `model.outputs` contribute to `post`.
    """
    input_procs = _get_described_procs(model.inputs)
    output_procs = _get_described_procs(model.outputs)

    pre_measures = set(
        chain.from_iterable(proc.required_measures for proc in input_procs)
    )
    post_measures = set(
        chain.from_iterable(proc.required_measures for proc in output_procs)
    )
    return RequiredMeasures(pre=pre_measures, post=post_measures)
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the required measures of `model` that are dataset measures.

    Keeps only instances of `DatasetMeasureBase`, separately for the `pre`
    and `post` sets returned by `get_requried_measures`.
    """
    all_pre, all_post = get_requried_measures(model)
    dataset_pre = {m for m in all_pre if isinstance(m, DatasetMeasureBase)}
    dataset_post = {m for m in all_post if isinstance(m, DatasetMeasureBase)}
    return RequiredDatasetMeasures(pre=dataset_pre, post=dataset_post)
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the required measures of `model` that are sample measures.

    Keeps only instances of `SampleMeasureBase`, separately for the `pre`
    and `post` sets returned by `get_requried_measures`.
    """
    all_pre, all_post = get_requried_measures(model)
    sample_pre = {m for m in all_pre if isinstance(m, SampleMeasureBase)}
    sample_post = {m for m in all_post if isinstance(m, SampleMeasureBase)}
    return RequiredSampleMeasures(pre=sample_pre, post=sample_post)
def _get_described_procs(
    tensor_descrs: Iterable[TensorDescr],
) -> List[Processing]:
    """Create the processing operators described by `tensor_descrs`.

    For input tensor descriptions the `preprocessing` steps are used, for
    output tensor descriptions the `postprocessing` steps.

    0.4 tensor descriptions additionally get `EnsureDtype` operators:
    a final cast to float32 is always appended, and a leading cast to the
    described `data_type` is added only if the tensor has other processing
    steps. This is decided per tensor, so every 0.4 tensor is treated alike
    regardless of its position in `tensor_descrs`.

    Args:
        tensor_descrs: Input and/or output tensor descriptions.

    Returns:
        All operators in a single list, ordered tensor by tensor.
    """
    # materialize, because the descriptions are iterated more than once
    tensor_descrs = list(tensor_descrs)
    all_output_ids = [
        get_member_id(d)
        for d in tensor_descrs
        if isinstance(d, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr))
    ]

    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        # operators of this tensor only; kept separate so that the
        # "no other processing steps" decision below cannot be influenced
        # by operators that belong to previously handled tensors
        t_procs: List[Processing] = []

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            t_procs.extend(get_proc(proc_d, t_descr) for proc_d in t_descr.preprocessing)
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            for proc_d in t_descr.postprocessing:
                if (
                    _CustomPostprocessingDescr is not None
                    and isinstance(proc_d, _CustomPostprocessingDescr)
                    and isinstance(t_descr, v0_5.OutputTensorDescr)
                ):
                    t_procs.append(
                        CustomPostprocessing.from_proc_descr(
                            proc_d, t_descr, all_output_ids
                        )
                    )
                else:
                    t_procs.append(get_proc(proc_d, t_descr))
        else:
            assert_never(t_descr)

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            member_id = get_member_id(t_descr)
            if t_procs:
                # cast to the described data type before the described steps;
                # skipped if there are no other processing steps
                t_procs.insert(
                    0,
                    EnsureDtype(
                        input=member_id, output=member_id, dtype=t_descr.data_type
                    ),
                )

            # ensure 0.4 models get float32 input
            # which has been the implicit assumption for 0.4
            # (applied to 0.4 input and output tensor descriptions alike)
            t_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype="float32")
            )

        procs.extend(t_procs)

    return procs