Coverage for bioimageio/core/proc_setup.py: 87%
82 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from typing import (
2 Iterable,
3 List,
4 Mapping,
5 NamedTuple,
6 Optional,
7 Sequence,
8 Set,
9 Union,
10)
12from typing_extensions import assert_never
14from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
15from bioimageio.spec.model.v0_5 import TensorId
17from .digest_spec import get_member_ids
18from .proc_ops import (
19 AddKnownDatasetStats,
20 Processing,
21 UpdateStats,
22 get_proc_class,
23)
24from .sample import Sample
25from .stat_calculators import StatsCalculator
26from .stat_measures import (
27 DatasetMeasure,
28 DatasetMeasureBase,
29 Measure,
30 MeasureValue,
31 SampleMeasure,
32 SampleMeasureBase,
33)
# Any tensor description this module handles: input or output tensors of
# either the v0_4 or the v0_5 model description format.
TensorDescr = Union[
    v0_4.InputTensorDescr,
    v0_4.OutputTensorDescr,
    v0_5.InputTensorDescr,
    v0_5.OutputTensorDescr,
]
class PreAndPostprocessing(NamedTuple):
    """Pair of processing operator lists returned by `setup_pre_and_postprocessing`."""

    # operators of the preprocessing chain, in list order
    pre: List[Processing]
    # operators of the postprocessing chain, in list order
    post: List[Processing]
class _SetupProcessing(NamedTuple):
    """Result of `_prepare_setup_pre_and_postprocessing`.

    Bundles the instantiated processing operators with the measures they
    require, split by whether a measure refers to a model input (`pre_*`)
    or a model output (`post_*`).
    """

    # operators created from the input tensor descriptions
    pre: List[Processing]
    # operators created from the output tensor descriptions
    post: List[Processing]
    # required measures whose member id is a model input id
    pre_measures: Set[Measure]
    # required measures whose member id is a model output id
    post_measures: Set[Measure]
def setup_pre_and_postprocessing(
    model: AnyModelDescr,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> PreAndPostprocessing:
    """Get pre- and postprocessing operators for a `model` description.

    Used in `bioimageio.core.create_prediction_pipeline`.

    Args:
        model: The model description whose tensor descriptions define the
            processing operators.
        dataset_for_initial_statistics: Samples fed to a `StatsCalculator` to
            obtain initial values for all required measures that are not
            contained in `fixed_dataset_stats`. Only iterated if at least one
            such measure exists.
        keep_updating_initial_dataset_stats: Forwarded unchanged to the
            `UpdateStats` operators.
        fixed_dataset_stats: Optional known measure values. Measures present
            in this mapping are not computed from
            `dataset_for_initial_statistics`. If non-empty, an
            `AddKnownDatasetStats` operator is prepended to both chains.

    Returns:
        The preprocessing and postprocessing operator lists.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_ops, post_ops = setup.pre, setup.post
    pre_measures, post_measures = setup.pre_measures, setup.post_measures

    # Determine which required measures still need to be computed.
    required = pre_measures | post_measures
    if fixed_dataset_stats is None:
        to_compute = required
    else:
        to_compute = {m for m in required if m not in fixed_dataset_stats}

    if not to_compute:
        initial_stats = {}
    else:
        calculator = StatsCalculator(to_compute)
        for sample in dataset_for_initial_statistics:
            calculator.update(sample)

        initial_stats = calculator.finalize()

    # The preprocessing chain always gets an `UpdateStats` operator ...
    pre_update = UpdateStats(
        StatsCalculator(pre_measures, initial_stats),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
    )
    pre_ops.insert(0, pre_update)

    # ... the postprocessing chain only if it requires any measure.
    if post_measures:
        post_update = UpdateStats(
            StatsCalculator(post_measures, initial_stats),
            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        )
        post_ops.insert(0, post_update)

    # Known stats are inserted last so that they end up first in each chain,
    # i.e. ahead of the `UpdateStats` operators.
    if fixed_dataset_stats:
        pre_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
        post_ops.insert(0, AddKnownDatasetStats(fixed_dataset_stats))

    return PreAndPostprocessing(pre=pre_ops, post=post_ops)
class RequiredMeasures(NamedTuple):
    """Measures required by a model's processing, as returned by `get_requried_measures`."""

    # measures collected as `pre_measures` during processing setup
    pre: Set[Measure]
    # measures collected as `post_measures` during processing setup
    post: Set[Measure]
class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required by a model's processing.

    Returned by `get_required_dataset_measures`.
    """

    # dataset measures among the collected `pre_measures`
    pre: Set[DatasetMeasure]
    # dataset measures among the collected `post_measures`
    post: Set[DatasetMeasure]
class RequiredSampleMeasures(NamedTuple):
    """Sample measures required by a model's processing.

    Returned by `get_requried_sample_measures`.
    """

    # sample measures among the collected `pre_measures`
    pre: Set[SampleMeasure]
    # sample measures among the collected `post_measures`
    post: Set[SampleMeasure]
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return all measures required by the pre- and postprocessing of `model`."""
    setup = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(pre=setup.pre_measures, post=setup.post_measures)
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by the processing of `model`.

    Only measures that are instances of `DatasetMeasureBase` are kept.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_dataset = {
        measure
        for measure in setup.pre_measures
        if isinstance(measure, DatasetMeasureBase)
    }
    post_dataset = {
        measure
        for measure in setup.post_measures
        if isinstance(measure, DatasetMeasureBase)
    }
    return RequiredDatasetMeasures(pre=pre_dataset, post=post_dataset)
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by the processing of `model`.

    Only measures that are instances of `SampleMeasureBase` are kept.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_sample = {
        measure
        for measure in setup.pre_measures
        if isinstance(measure, SampleMeasureBase)
    }
    post_sample = {
        measure
        for measure in setup.post_measures
        if isinstance(measure, SampleMeasureBase)
    }
    return RequiredSampleMeasures(pre=pre_sample, post=post_sample)
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
    """Instantiate the processing operators of `model` and collect their measures.

    For every input tensor description the preprocessing descriptions, and
    for every output tensor description the postprocessing descriptions, are
    turned into operators via `get_proc_class(...).from_proc_descr(...)`.
    Each measure required by an operator is sorted into `pre_measures` or
    `post_measures` depending on whether its `member_id` is a model input id
    or a model output id.

    Args:
        model: The model description to set up processing for.

    Returns:
        The operators for inputs (`pre`) and outputs (`post`) together with
        the measures they require.

    Raises:
        ValueError: If an operator requires a measure whose `member_id` is
            neither an input id nor an output id of `model`.
    """
    pre_measures: Set[Measure] = set()
    post_measures: Set[Measure] = set()

    input_ids = set(get_member_ids(model.inputs))
    output_ids = set(get_member_ids(model.outputs))

    def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
        """Create the operators for `tensor_descrs`; records required measures
        in the enclosing `pre_measures`/`post_measures` sets as a side effect."""
        procs: List[Processing] = []
        for t_descr in tensor_descrs:
            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
                proc_descrs: List[
                    Union[
                        v0_4.PreprocessingDescr,
                        v0_5.PreprocessingDescr,
                        v0_4.PostprocessingDescr,
                        v0_5.PostprocessingDescr,
                    ]
                ] = list(t_descr.preprocessing)
            elif isinstance(
                t_descr,
                (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
            ):
                proc_descrs = list(t_descr.postprocessing)
            else:
                assert_never(t_descr)

            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
                # v0_4 tensor descriptions get an explicit dtype-enforcing
                # step built from their `data_type`.
                ensure_dtype = v0_5.EnsureDtypeDescr(
                    kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
                )
                # v0_4 inputs that have preprocessing additionally enforce the
                # dtype before the first preprocessing step ...
                if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
                    proc_descrs.insert(0, ensure_dtype)

                # ... and every v0_4 tensor enforces it as the final step.
                proc_descrs.append(ensure_dtype)

            # The member id depends only on the tensor description, so it is
            # computed once per tensor rather than once per operator.
            member_id = (
                TensorId(str(t_descr.name))
                if isinstance(t_descr, v0_4.TensorDescrBase)
                else t_descr.id
            )
            for proc_d in proc_descrs:
                proc_class = get_proc_class(proc_d)
                req = proc_class.from_proc_descr(
                    proc_d, member_id  # pyright: ignore[reportArgumentType]
                )
                for m in req.required_measures:
                    if m.member_id in input_ids:
                        pre_measures.add(m)
                    elif m.member_id in output_ids:
                        post_measures.add(m)
                    else:
                        raise ValueError(
                            f"Processing of tensor '{member_id}' requires measure"
                            f" {m!r}, but its member id '{m.member_id}' is neither"
                            " a model input id nor a model output id."
                        )
                procs.append(req)
        return procs

    return _SetupProcessing(
        pre=prepare_procs(model.inputs),
        post=prepare_procs(model.outputs),
        pre_measures=pre_measures,
        post_measures=post_measures,
    )