Coverage for bioimageio/core/proc_setup.py: 92%
91 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from typing import (
2 Callable,
3 Iterable,
4 List,
5 Mapping,
6 NamedTuple,
7 Optional,
8 Sequence,
9 Set,
10 Union,
11)
13from typing_extensions import assert_never
15from bioimageio.core.digest_spec import get_member_id
16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
18from .proc_ops import (
19 AddKnownDatasetStats,
20 EnsureDtype,
21 Processing,
22 UpdateStats,
23 get_proc,
24)
25from .sample import Sample
26from .stat_calculators import StatsCalculator
27from .stat_measures import (
28 DatasetMeasure,
29 DatasetMeasureBase,
30 Measure,
31 MeasureValue,
32 SampleMeasure,
33 SampleMeasureBase,
34)
# Any input or output tensor description of the supported model
# description format versions (0.4 and 0.5).
TensorDescr = Union[
    v0_4.InputTensorDescr,
    v0_4.OutputTensorDescr,
    v0_5.InputTensorDescr,
    v0_5.OutputTensorDescr,
]
class PreAndPostprocessing(NamedTuple):
    """Lists of pre- and postprocessing operators set up for a model."""

    # operators making up the preprocessing
    pre: List[Processing]
    # operators making up the postprocessing
    post: List[Processing]
class _ProcessingCallables(NamedTuple):
    """Pair of callables that apply pre-/postprocessing to a sample."""

    # callable applying the preprocessing to a sample; returns nothing
    pre: Callable[[Sample], None]
    # callable applying the postprocessing to a sample; returns nothing
    post: Callable[[Sample], None]
class _SetupProcessing(NamedTuple):
    """Processing operators together with the measures they require."""

    # preprocessing operators
    pre: List[Processing]
    # postprocessing operators
    post: List[Processing]
    # union of `required_measures` over all preprocessing operators
    pre_measures: Set[Measure]
    # union of `required_measures` over all postprocessing operators
    post_measures: Set[Measure]
class _ApplyProcs:
    """Callable that runs a fixed sequence of processing operators on a sample."""

    def __init__(self, procs: Sequence[Processing]):
        """Store the operators `procs`; they are not copied."""
        super().__init__()
        self._procs = procs

    def __call__(self, sample: Sample) -> None:
        """Call every stored operator with `sample`, in sequence order."""
        for apply_op in self._procs:
            apply_op(sample)
def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Create callables to apply pre- and postprocessing in-place to a sample.

    All arguments are forwarded unchanged to `setup_pre_and_postprocessing`;
    the resulting operator lists are wrapped so that each can be applied to a
    sample with a single call.

    Args:
        model: Model description to set up processing for.
        dataset_for_initial_statistics: Samples used to compute initial
            statistics for measures not covered by `fixed_dataset_stats`.
        keep_updating_initial_dataset_stats: Passed on to the statistics
            updating operators.
        fixed_dataset_stats: Optional known dataset measure values.

    Returns:
        A `(pre, post)` pair of callables taking a sample and returning `None`.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    apply_pre = _ApplyProcs(pre_ops)
    apply_post = _ApplyProcs(post_ops)
    return _ProcessingCallables(pre=apply_pre, post=apply_post)
def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: Model description to derive the processing operators from.
        dataset_for_initial_statistics: Samples used to compute initial values
            for all required measures that `fixed_dataset_stats` does not
            provide. Only iterated if at least one such measure exists.
        keep_updating_initial_dataset_stats: Forwarded to the inserted
            `UpdateStats` operators.
        fixed_dataset_stats: Optional known measure values. If given and
            non-empty, an `AddKnownDatasetStats` operator is put first in both
            operator lists.

    Returns:
        The preprocessing and postprocessing operator lists.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_ops, post_ops = setup.pre, setup.post
    pre_measures, post_measures = setup.pre_measures, setup.post_measures

    # required measures for which no fixed value was provided
    all_measures = pre_measures | post_measures
    if fixed_dataset_stats is None:
        missing_stats = set(all_measures)
    else:
        missing_stats = {m for m in all_measures if m not in fixed_dataset_stats}

    if not missing_stats:
        initial_stats = {}
    else:
        calculator = StatsCalculator(missing_stats)
        for sample in dataset_for_initial_statistics:
            calculator.update(sample)

        initial_stats = calculator.finalize()

    # the preprocessing statistics update is always inserted ...
    pre_stats_update = UpdateStats(
        StatsCalculator(pre_measures, initial_stats),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
    )
    pre_ops.insert(0, pre_stats_update)

    # ... the postprocessing one only if postprocessing requires any measure
    if post_measures:
        post_stats_update = UpdateStats(
            StatsCalculator(post_measures, initial_stats),
            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        )
        post_ops.insert(0, post_stats_update)

    if fixed_dataset_stats:
        # inserted at index 0, i.e. ahead of the statistics updates above
        pre_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
        post_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    return PreAndPostprocessing(pre=pre_ops, post=post_ops)
class RequiredMeasures(NamedTuple):
    """Measures required by a model's pre- and postprocessing."""

    # measures required by preprocessing operators
    pre: Set[Measure]
    # measures required by postprocessing operators
    post: Set[Measure]
class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required by a model's pre- and postprocessing."""

    # dataset measures required by preprocessing operators
    pre: Set[DatasetMeasure]
    # dataset measures required by postprocessing operators
    post: Set[DatasetMeasure]
class RequiredSampleMeasures(NamedTuple):
    """Sample measures required by a model's pre- and postprocessing."""

    # sample measures required by preprocessing operators
    pre: Set[SampleMeasure]
    # sample measures required by postprocessing operators
    post: Set[SampleMeasure]
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return all measures required by the pre- and postprocessing of `model`.

    The misspelled function name is kept for backward compatibility;
    `get_required_measures` is the correctly spelled alias.

    Args:
        model: Model description whose processing operators are inspected.

    Returns:
        The sets of measures required by pre- and postprocessing operators.
    """
    s = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(s.pre_measures, s.post_measures)


# Correctly spelled alias, consistent with `get_required_dataset_measures`.
get_required_measures = get_requried_measures
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by the processing of `model`.

    Of all measures required by the pre- and postprocessing operators only
    those that are instances of `DatasetMeasureBase` are kept.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_dataset_measures = {
        measure
        for measure in setup.pre_measures
        if isinstance(measure, DatasetMeasureBase)
    }
    post_dataset_measures = {
        measure
        for measure in setup.post_measures
        if isinstance(measure, DatasetMeasureBase)
    }
    return RequiredDatasetMeasures(pre_dataset_measures, post_dataset_measures)
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by the processing of `model`.

    Of all measures required by the pre- and postprocessing operators only
    those that are instances of `SampleMeasureBase` are kept.

    The misspelled function name is kept for backward compatibility;
    `get_required_sample_measures` is the correctly spelled alias.

    Args:
        model: Model description whose processing operators are inspected.

    Returns:
        The sets of sample measures required by pre- and postprocessing
        operators.
    """
    s = _prepare_setup_pre_and_postprocessing(model)
    return RequiredSampleMeasures(
        {m for m in s.pre_measures if isinstance(m, SampleMeasureBase)},
        {m for m in s.post_measures if isinstance(m, SampleMeasureBase)},
    )


# Correctly spelled alias, consistent with `get_required_dataset_measures`.
get_required_sample_measures = get_requried_sample_measures
def _prepare_procs(
    tensor_descrs: Union[
        Sequence[v0_4.InputTensorDescr],
        Sequence[v0_5.InputTensorDescr],
        Sequence[v0_4.OutputTensorDescr],
        Sequence[v0_5.OutputTensorDescr],
    ],
) -> List[Processing]:
    """Create the processing operators described by `tensor_descrs`.

    For every tensor description the operators for its `preprocessing`
    (input tensors) or `postprocessing` (output tensors) steps are created
    with `get_proc`.

    Tensors described in format 0.4 are additionally wrapped in dtype casts:
    an initial `EnsureDtype` to the tensor's declared `data_type` and a final
    `EnsureDtype` to float32. If such a tensor has no processing steps of its
    own, the initial cast is dropped and only the float32 cast remains.

    Args:
        tensor_descrs: Input or output tensor descriptions of a model.

    Returns:
        A flat list of operators, grouped by tensor in the order of
        `tensor_descrs`.
    """
    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        # Collect the operators per tensor, so that the "no other processing
        # steps" check below looks at this tensor only and not at operators
        # accumulated for previously handled tensors.
        tensor_procs: List[Processing] = []

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            member_id = get_member_id(t_descr)
            tensor_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype=t_descr.data_type)
            )

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            tensor_procs.extend(
                get_proc(proc_d, t_descr) for proc_d in t_descr.preprocessing
            )
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            tensor_procs.extend(
                get_proc(proc_d, t_descr) for proc_d in t_descr.postprocessing
            )
        else:
            assert_never(t_descr)

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            if len(tensor_procs) == 1:
                # remove initial ensure_dtype if there are no other processing steps
                assert isinstance(tensor_procs[0], EnsureDtype)
                tensor_procs = []

            # ensure 0.4 models get float32 input
            # which has been the implicit assumption for 0.4
            member_id = get_member_id(t_descr)
            tensor_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype="float32")
            )

        procs.extend(tensor_procs)

    return procs
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
    """Create the processing operators of `model` and collect their measures.

    Returns:
        The pre- and postprocessing operator lists (from `_prepare_procs`)
        together with the union of `required_measures` of the operators in
        each list.
    """
    # Both branches perform the same calls; they are kept separate per
    # model description version.
    if isinstance(model, v0_4.ModelDescr):
        pre_ops = _prepare_procs(model.inputs)
        post_ops = _prepare_procs(model.outputs)
    elif isinstance(model, v0_5.ModelDescr):
        pre_ops = _prepare_procs(model.inputs)
        post_ops = _prepare_procs(model.outputs)
    else:
        assert_never(model)

    required_pre: Set[Measure] = set()
    for op in pre_ops:
        required_pre.update(op.required_measures)

    required_post: Set[Measure] = set()
    for op in post_ops:
        required_post.update(op.required_measures)

    return _SetupProcessing(
        pre=pre_ops,
        post=post_ops,
        pre_measures=required_pre,
        post_measures=required_post,
    )