Coverage for src / bioimageio / core / proc_setup.py: 96%
78 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1from itertools import chain
2from typing import (
3 Callable,
4 Iterable,
5 List,
6 Mapping,
7 NamedTuple,
8 Optional,
9 Sequence,
10 Set,
11 Union,
12)
14from typing_extensions import assert_never
16from bioimageio.core.digest_spec import get_member_id
17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
19from .proc_ops import (
20 AddKnownDatasetStats,
21 EnsureDtype,
22 Processing,
23 UpdateStats,
24 get_proc,
25)
26from .sample import Sample
27from .stat_calculators import StatsCalculator
28from .stat_measures import (
29 DatasetMeasure,
30 DatasetMeasureBase,
31 Measure,
32 MeasureValue,
33 SampleMeasure,
34 SampleMeasureBase,
35)
# Every tensor description handled by this module: input and output tensors
# of both the 0.4 and the 0.5 model description format.
TensorDescr = Union[
    v0_4.InputTensorDescr, v0_4.OutputTensorDescr,
    v0_5.InputTensorDescr, v0_5.OutputTensorDescr,
]
class PreAndPostprocessing(NamedTuple):
    """Processing operators set up for one model description."""

    pre: List[Processing]  # operators derived from the model's input tensors
    post: List[Processing]  # operators derived from the model's output tensors
class _ProcessingCallables(NamedTuple):
    """Pair of callables that each process a `Sample` in place and return None."""

    pre: Callable[[Sample], None]  # applies the preprocessing operators
    post: Callable[[Sample], None]  # applies the postprocessing operators
class _ApplyProcs:
    """Callable that runs a fixed sequence of processing operators on a sample.

    Each operator is called with the sample, in sequence order; nothing is
    returned.
    """

    def __init__(self, procs: Sequence[Processing]):
        super().__init__()
        # kept by reference (not copied), exactly as passed in
        self._procs = procs

    def __call__(self, sample: Sample) -> None:
        for apply_op in self._procs:
            apply_op(sample)
def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Create callables that apply a model's pre- and postprocessing in-place.

    Thin wrapper around `setup_pre_and_postprocessing`: all keyword arguments
    are forwarded unchanged, and each resulting operator list is wrapped into a
    single callable that runs its operators on a `Sample` in order.

    Returns:
        `_ProcessingCallables` whose `pre`/`post` apply the respective operators.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    return _ProcessingCallables(pre=_ApplyProcs(pre_ops), post=_ApplyProcs(post_ops))
def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: Model description; operators are derived from `model.inputs`
            (preprocessing) and `model.outputs` (postprocessing).
        dataset_for_initial_statistics: Samples iterated once to compute
            initial values for every required measure that is not a key of
            `fixed_dataset_stats`. Not iterated at all if none is missing.
        keep_updating_initial_dataset_stats: Forwarded to `UpdateStats`.
        fixed_dataset_stats: Known dataset measure values. If non-empty, an
            `AddKnownDatasetStats` operator is placed first in `pre`.

    Returns:
        `PreAndPostprocessing` with the operator lists `pre` and `post`.
        When present, `AddKnownDatasetStats` comes first in `pre`, followed
        by `UpdateStats`, followed by the operators described for the inputs.
    """
    pre_ops = _get_described_procs(model.inputs)
    post_ops = _get_described_procs(model.outputs)

    # union of the measures needed by any pre- or postprocessing operator
    required_measures = set(
        chain.from_iterable(op.required_measures for op in chain(pre_ops, post_ops))
    )

    # Everything not supplied up front must be computed from the dataset.
    # (This also contains any sample measures, which by its type annotation
    # `fixed_dataset_stats` cannot provide.)
    if fixed_dataset_stats is None:
        missing = set(required_measures)
    else:
        missing = {m for m in required_measures if m not in fixed_dataset_stats}

    # operators to prepend to `pre_ops`, already in execution order
    leading_ops: List[Processing] = []
    if missing:
        initial_calc = StatsCalculator(missing)
        for sample in dataset_for_initial_statistics:
            initial_calc.update(sample)

        # NOTE(review): the running calculator tracks *all* required measures,
        # including those supplied via `fixed_dataset_stats`; confirm that
        # `UpdateStats` does not override the fixed values added before it.
        leading_ops.append(
            UpdateStats(
                StatsCalculator(required_measures, initial_calc.finalize()),
                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
            )
        )

    if fixed_dataset_stats:
        leading_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    pre_ops[:0] = leading_ops
    return PreAndPostprocessing(pre_ops, post_ops)
class RequiredMeasures(NamedTuple):
    """Measures required by a model's processing, split by processing stage."""

    pre: Set[Measure]  # needed by operators derived from the model inputs
    post: Set[Measure]  # needed by operators derived from the model outputs
class RequiredDatasetMeasures(NamedTuple):
    """Required dataset measures only, split by processing stage."""

    pre: Set[DatasetMeasure]  # dataset measures needed by input-side operators
    post: Set[DatasetMeasure]  # dataset measures needed by output-side operators
class RequiredSampleMeasures(NamedTuple):
    """Required sample measures only, split by processing stage."""

    pre: Set[SampleMeasure]  # sample measures needed by input-side operators
    post: Set[SampleMeasure]  # sample measures needed by output-side operators
def get_required_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Collect the measures required by a model's pre- and postprocessing.

    Args:
        model: Model description.

    Returns:
        `RequiredMeasures` holding the set of measures required by the
        operators derived from `model.inputs` (`pre`) and from
        `model.outputs` (`post`).
    """
    pre = _get_described_procs(model.inputs)
    post = _get_described_procs(model.outputs)
    return RequiredMeasures(
        {m for proc in pre for m in proc.required_measures},
        {m for proc in post for m in proc.required_measures},
    )


# Backward-compatible alias for the historical, misspelled public name.
get_requried_measures = get_required_measures
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Select the dataset measures required by a model's processing.

    Args:
        model: Model description.

    Returns:
        `RequiredDatasetMeasures` with the required measures that are
        `DatasetMeasureBase` instances, split into `pre` and `post`.
    """
    required = get_requried_measures(model)
    pre_dataset = {m for m in required.pre if isinstance(m, DatasetMeasureBase)}
    post_dataset = {m for m in required.post if isinstance(m, DatasetMeasureBase)}
    return RequiredDatasetMeasures(pre=pre_dataset, post=post_dataset)
def get_required_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Select the sample measures required by a model's processing.

    Args:
        model: Model description.

    Returns:
        `RequiredSampleMeasures` with the required measures that are
        `SampleMeasureBase` instances, split into `pre` and `post`.
    """
    req = get_requried_measures(model)
    return RequiredSampleMeasures(
        {m for m in req.pre if isinstance(m, SampleMeasureBase)},
        {m for m in req.post if isinstance(m, SampleMeasureBase)},
    )


# Backward-compatible alias for the historical, misspelled public name.
get_requried_sample_measures = get_required_sample_measures
def _get_described_procs(
    tensor_descrs: Iterable[TensorDescr],
) -> List[Processing]:
    """Build the processing operators described by `tensor_descrs`.

    Input tensor descriptions contribute their `preprocessing` steps, output
    tensor descriptions their `postprocessing` steps, in the order the
    descriptions are given.

    For 0.4 tensor descriptions each tensor's steps are additionally wrapped:
    an `EnsureDtype` to the tensor's declared `data_type` is placed before the
    steps (only if the tensor has any steps), and an `EnsureDtype` to
    "float32" is always appended after them.

    Args:
        tensor_descrs: Input or output tensor descriptions of one model.

    Returns:
        A new list of processing operators, to be applied in order.
    """
    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        # Collect this tensor's operators separately so that the 0.4 clean-up
        # below inspects only this tensor. Checking the list accumulated over
        # all tensors made the clean-up apply to the first tensor only.
        tensor_procs: List[Processing] = []
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            member_id = get_member_id(t_descr)
            tensor_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
            )

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            tensor_procs.extend(
                get_proc(proc_d, t_descr) for proc_d in t_descr.preprocessing
            )
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            tensor_procs.extend(
                get_proc(proc_d, t_descr) for proc_d in t_descr.postprocessing
            )
        else:
            assert_never(t_descr)

        # NOTE(review): as before, this also applies to 0.4 *output* tensors,
        # although the rationale below only mentions inputs; confirm intended.
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            if len(tensor_procs) == 1:
                # drop the leading ensure_dtype if this tensor has no other
                # processing steps
                assert isinstance(tensor_procs[0], EnsureDtype)
                tensor_procs = []

            # ensure 0.4 models get float32 tensors,
            # which has been the implicit assumption for 0.4
            member_id = get_member_id(t_descr)
            tensor_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype="float32")
            )

        procs.extend(tensor_procs)

    return procs