bioimageio.core.proc_setup

  1from typing import (
  2    Iterable,
  3    List,
  4    Mapping,
  5    NamedTuple,
  6    Optional,
  7    Sequence,
  8    Set,
  9    Union,
 10)
 11
 12from typing_extensions import assert_never
 13
 14from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 15from bioimageio.spec.model.v0_5 import TensorId
 16
 17from .digest_spec import get_member_ids
 18from .proc_ops import (
 19    AddKnownDatasetStats,
 20    Processing,
 21    UpdateStats,
 22    get_proc_class,
 23)
 24from .sample import Sample
 25from .stat_calculators import StatsCalculator
 26from .stat_measures import (
 27    DatasetMeasure,
 28    DatasetMeasureBase,
 29    Measure,
 30    MeasureValue,
 31    SampleMeasure,
 32    SampleMeasureBase,
 33)
 34
 35TensorDescr = Union[
 36    v0_4.InputTensorDescr,
 37    v0_4.OutputTensorDescr,
 38    v0_5.InputTensorDescr,
 39    v0_5.OutputTensorDescr,
 40]
 41
 42
 43class PreAndPostprocessing(NamedTuple):
 44    pre: List[Processing]
 45    post: List[Processing]
 46
 47
 48class _SetupProcessing(NamedTuple):
 49    pre: List[Processing]
 50    post: List[Processing]
 51    pre_measures: Set[Measure]
 52    post_measures: Set[Measure]
 53
 54
 55def setup_pre_and_postprocessing(
 56    model: AnyModelDescr,
 57    dataset_for_initial_statistics: Iterable[Sample],
 58    keep_updating_initial_dataset_stats: bool = False,
 59    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
 60) -> PreAndPostprocessing:
 61    """
 62    Get pre- and postprocessing operators for a `model` description.
 63    userd in `bioimageio.core.create_prediction_pipeline"""
 64    prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)
 65
 66    missing_dataset_stats = {
 67        m
 68        for m in prep_meas | post_meas
 69        if fixed_dataset_stats is None or m not in fixed_dataset_stats
 70    }
 71    if missing_dataset_stats:
 72        initial_stats_calc = StatsCalculator(missing_dataset_stats)
 73        for sample in dataset_for_initial_statistics:
 74            initial_stats_calc.update(sample)
 75
 76        initial_stats = initial_stats_calc.finalize()
 77    else:
 78        initial_stats = {}
 79
 80    prep.insert(
 81        0,
 82        UpdateStats(
 83            StatsCalculator(prep_meas, initial_stats),
 84            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
 85        ),
 86    )
 87    if post_meas:
 88        post.insert(
 89            0,
 90            UpdateStats(
 91                StatsCalculator(post_meas, initial_stats),
 92                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
 93            ),
 94        )
 95
 96    if fixed_dataset_stats:
 97        prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
 98        post.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
 99
100    return PreAndPostprocessing(prep, post)
101
102
class RequiredMeasures(NamedTuple):
    """All measures required by a model's pre- and postprocessing."""

    # measures referring to a model input tensor
    pre: Set[Measure]
    # measures referring to a model output tensor
    post: Set[Measure]
106
107
class RequiredDatasetMeasures(NamedTuple):
    """Dataset-level measures required by a model's pre- and postprocessing."""

    # dataset measures referring to a model input tensor
    pre: Set[DatasetMeasure]
    # dataset measures referring to a model output tensor
    post: Set[DatasetMeasure]
111
112
class RequiredSampleMeasures(NamedTuple):
    """Sample-level measures required by a model's pre- and postprocessing."""

    # sample measures referring to a model input tensor
    pre: Set[SampleMeasure]
    # sample measures referring to a model output tensor
    post: Set[SampleMeasure]
116
117
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return every measure required by the pre- and postprocessing of `model`.

    Measures are grouped by whether they refer to a model input (`pre`) or a
    model output (`post`) tensor.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(pre=setup.pre_measures, post=setup.post_measures)
121
122
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by the processing of `model`,
    grouped by input (`pre`) and output (`post`) tensors."""
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_dataset, post_dataset = (
        {measure for measure in group if isinstance(measure, DatasetMeasureBase)}
        for group in (setup.pre_measures, setup.post_measures)
    )
    return RequiredDatasetMeasures(pre_dataset, post_dataset)
129
130
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by the processing of `model`,
    grouped by input (`pre`) and output (`post`) tensors."""
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_sample, post_sample = (
        {measure for measure in group if isinstance(measure, SampleMeasureBase)}
        for group in (setup.pre_measures, setup.post_measures)
    )
    return RequiredSampleMeasures(pre_sample, post_sample)
137
138
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
    """Build processing operators for `model` and collect their required measures.

    The `preprocessing` steps of every input tensor description and the
    `postprocessing` steps of every output tensor description are turned into
    processing operators. For format 0.4 tensor descriptions, `EnsureDtype`
    steps targeting the tensor's `data_type` are added:
    - inputs: one before the preprocessing (only if there are other steps) and
      one after it;
    - outputs: one after the postprocessing.

    Returns:
        The pre-/postprocessing operators plus the measures they require, split
        by whether a measure refers to a model input (`pre_measures`) or a
        model output (`post_measures`).

    Raises:
        ValueError: If an operator requires a measure of a tensor that is
            neither a model input nor a model output.
    """
    pre_measures: Set[Measure] = set()
    post_measures: Set[Measure] = set()

    input_ids = set(get_member_ids(model.inputs))
    output_ids = set(get_member_ids(model.outputs))

    def prepare_procs(tensor_descrs: Sequence[TensorDescr]) -> List[Processing]:
        # Creates operators for `tensor_descrs` and records required measures
        # in `pre_measures`/`post_measures` of the enclosing scope.
        procs: List[Processing] = []
        for t_descr in tensor_descrs:
            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
                proc_descrs: List[
                    Union[
                        v0_4.PreprocessingDescr,
                        v0_5.PreprocessingDescr,
                        v0_4.PostprocessingDescr,
                        v0_5.PostprocessingDescr,
                    ]
                ] = list(t_descr.preprocessing)
            elif isinstance(
                t_descr,
                (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
            ):
                proc_descrs = list(t_descr.postprocessing)
            else:
                assert_never(t_descr)

            if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
                # format 0.4: cast to the declared `data_type` via EnsureDtype steps
                ensure_dtype = v0_5.EnsureDtypeDescr(
                    kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type)
                )
                if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs:
                    proc_descrs.insert(0, ensure_dtype)

                proc_descrs.append(ensure_dtype)

            if not proc_descrs:
                continue

            # the member id is the same for every step of this tensor
            member_id = (
                TensorId(str(t_descr.name))
                if isinstance(t_descr, v0_4.TensorDescrBase)
                else t_descr.id
            )
            for proc_d in proc_descrs:
                proc_class = get_proc_class(proc_d)
                req = proc_class.from_proc_descr(
                    proc_d, member_id  # pyright: ignore[reportArgumentType]
                )
                for m in req.required_measures:
                    if m.member_id in input_ids:
                        pre_measures.add(m)
                    elif m.member_id in output_ids:
                        post_measures.add(m)
                    else:
                        raise ValueError(
                            f"Processing step {type(proc_d).__name__} of tensor"
                            f" '{member_id}' requires measure {m} of tensor"
                            f" '{m.member_id}', which is neither a model input"
                            f" ({sorted(map(str, input_ids))}) nor a model"
                            f" output ({sorted(map(str, output_ids))})."
                        )
                procs.append(req)
        return procs

    return _SetupProcessing(
        pre=prepare_procs(model.inputs),
        post=prepare_procs(model.outputs),
        pre_measures=pre_measures,
        post_measures=post_measures,
    )
class PreAndPostprocessing(typing.NamedTuple):
class PreAndPostprocessing(NamedTuple):
    """Processing operators returned by `setup_pre_and_postprocessing`."""

    # operators derived from the model's input tensor descriptions
    pre: List[Processing]
    # operators derived from the model's output tensor descriptions
    post: List[Processing]

PreAndPostprocessing(pre, post)

def setup_pre_and_postprocessing( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)], dataset_for_initial_statistics: Iterable[bioimageio.core.Sample], keep_updating_initial_dataset_stats: bool = False, fixed_dataset_stats: Optional[Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None) -> PreAndPostprocessing:
 56def setup_pre_and_postprocessing(
 57    model: AnyModelDescr,
 58    dataset_for_initial_statistics: Iterable[Sample],
 59    keep_updating_initial_dataset_stats: bool = False,
 60    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
 61) -> PreAndPostprocessing:
 62    """
 63    Get pre- and postprocessing operators for a `model` description.
 64    userd in `bioimageio.core.create_prediction_pipeline"""
 65    prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)
 66
 67    missing_dataset_stats = {
 68        m
 69        for m in prep_meas | post_meas
 70        if fixed_dataset_stats is None or m not in fixed_dataset_stats
 71    }
 72    if missing_dataset_stats:
 73        initial_stats_calc = StatsCalculator(missing_dataset_stats)
 74        for sample in dataset_for_initial_statistics:
 75            initial_stats_calc.update(sample)
 76
 77        initial_stats = initial_stats_calc.finalize()
 78    else:
 79        initial_stats = {}
 80
 81    prep.insert(
 82        0,
 83        UpdateStats(
 84            StatsCalculator(prep_meas, initial_stats),
 85            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
 86        ),
 87    )
 88    if post_meas:
 89        post.insert(
 90            0,
 91            UpdateStats(
 92                StatsCalculator(post_meas, initial_stats),
 93                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
 94            ),
 95        )
 96
 97    if fixed_dataset_stats:
 98        prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
 99        post.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
100
101    return PreAndPostprocessing(prep, post)

Get pre- and postprocessing operators for a model description. Used in `bioimageio.core.create_prediction_pipeline`.

class RequiredMeasures(typing.NamedTuple):
class RequiredMeasures(NamedTuple):
    """All measures required by a model's pre- and postprocessing."""

    # measures referring to a model input tensor
    pre: Set[Measure]
    # measures referring to a model output tensor
    post: Set[Measure]

RequiredMeasures(pre, post)

RequiredMeasures( pre: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], post: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]])

Create new instance of RequiredMeasures(pre, post)

pre: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 0

post: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 1

class RequiredDatasetMeasures(typing.NamedTuple):
class RequiredDatasetMeasures(NamedTuple):
    """Dataset-level measures required by a model's pre- and postprocessing."""

    # dataset measures referring to a model input tensor
    pre: Set[DatasetMeasure]
    # dataset measures referring to a model output tensor
    post: Set[DatasetMeasure]

RequiredDatasetMeasures(pre, post)

RequiredDatasetMeasures( pre: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], post: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]])

Create new instance of RequiredDatasetMeasures(pre, post)

pre: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 0

post: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 1

class RequiredSampleMeasures(typing.NamedTuple):
class RequiredSampleMeasures(NamedTuple):
    """Sample-level measures required by a model's pre- and postprocessing."""

    # sample measures referring to a model input tensor
    pre: Set[SampleMeasure]
    # sample measures referring to a model output tensor
    post: Set[SampleMeasure]

RequiredSampleMeasures(pre, post)

RequiredSampleMeasures( pre: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], post: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]])

Create new instance of RequiredSampleMeasures(pre, post)

pre: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 0

post: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 1

def get_requried_measures( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)]) -> RequiredMeasures:
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return every measure required by the pre- and postprocessing of `model`.

    Measures are grouped by whether they refer to a model input (`pre`) or a
    model output (`post`) tensor.
    """
    setup = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(pre=setup.pre_measures, post=setup.post_measures)
def get_required_dataset_measures( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)]) -> RequiredDatasetMeasures:
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by the processing of `model`,
    grouped by input (`pre`) and output (`post`) tensors."""
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_dataset, post_dataset = (
        {measure for measure in group if isinstance(measure, DatasetMeasureBase)}
        for group in (setup.pre_measures, setup.post_measures)
    )
    return RequiredDatasetMeasures(pre_dataset, post_dataset)
def get_requried_sample_measures( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)]) -> RequiredSampleMeasures:
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by the processing of `model`,
    grouped by input (`pre`) and output (`post`) tensors."""
    setup = _prepare_setup_pre_and_postprocessing(model)
    pre_sample, post_sample = (
        {measure for measure in group if isinstance(measure, SampleMeasureBase)}
        for group in (setup.pre_measures, setup.post_measures)
    )
    return RequiredSampleMeasures(pre_sample, post_sample)