bioimageio.core.proc_setup

  1from typing import (
  2    Callable,
  3    Iterable,
  4    List,
  5    Mapping,
  6    NamedTuple,
  7    Optional,
  8    Sequence,
  9    Set,
 10    Union,
 11)
 12
 13from typing_extensions import assert_never
 14
 15from bioimageio.core.digest_spec import get_member_id
 16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 17
 18from .proc_ops import (
 19    AddKnownDatasetStats,
 20    EnsureDtype,
 21    Processing,
 22    UpdateStats,
 23    get_proc,
 24)
 25from .sample import Sample
 26from .stat_calculators import StatsCalculator
 27from .stat_measures import (
 28    DatasetMeasure,
 29    DatasetMeasureBase,
 30    Measure,
 31    MeasureValue,
 32    SampleMeasure,
 33    SampleMeasureBase,
 34)
 35
 36TensorDescr = Union[
 37    v0_4.InputTensorDescr,
 38    v0_4.OutputTensorDescr,
 39    v0_5.InputTensorDescr,
 40    v0_5.OutputTensorDescr,
 41]
 42
 43
 44class PreAndPostprocessing(NamedTuple):
 45    pre: List[Processing]
 46    post: List[Processing]
 47
 48
 49class _ProcessingCallables(NamedTuple):
 50    pre: Callable[[Sample], None]
 51    post: Callable[[Sample], None]
 52
 53
 54class _SetupProcessing(NamedTuple):
 55    pre: List[Processing]
 56    post: List[Processing]
 57    pre_measures: Set[Measure]
 58    post_measures: Set[Measure]
 59
 60
 61class _ApplyProcs:
 62    def __init__(self, procs: Sequence[Processing]):
 63        super().__init__()
 64        self._procs = procs
 65
 66    def __call__(self, sample: Sample) -> None:
 67        for op in self._procs:
 68            op(sample)
 69
 70
 71def get_pre_and_postprocessing(
 72    model: AnyModelDescr,
 73    *,
 74    dataset_for_initial_statistics: Iterable[Sample],
 75    keep_updating_initial_dataset_stats: bool = False,
 76    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
 77) -> _ProcessingCallables:
 78    """Creates callables to apply pre- and postprocessing in-place to a sample"""
 79
 80    setup = setup_pre_and_postprocessing(
 81        model=model,
 82        dataset_for_initial_statistics=dataset_for_initial_statistics,
 83        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
 84        fixed_dataset_stats=fixed_dataset_stats,
 85    )
 86    return _ProcessingCallables(_ApplyProcs(setup.pre), _ApplyProcs(setup.post))
 87
 88
 89def setup_pre_and_postprocessing(
 90    model: AnyModelDescr,
 91    dataset_for_initial_statistics: Iterable[Sample],
 92    keep_updating_initial_dataset_stats: bool = False,
 93    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
 94) -> PreAndPostprocessing:
 95    """
 96    Get pre- and postprocessing operators for a `model` description.
 97    Used in `bioimageio.core.create_prediction_pipeline"""
 98    prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)
 99
100    missing_dataset_stats = {
101        m
102        for m in prep_meas | post_meas
103        if fixed_dataset_stats is None or m not in fixed_dataset_stats
104    }
105    if missing_dataset_stats:
106        initial_stats_calc = StatsCalculator(missing_dataset_stats)
107        for sample in dataset_for_initial_statistics:
108            initial_stats_calc.update(sample)
109
110        initial_stats = initial_stats_calc.finalize()
111    else:
112        initial_stats = {}
113
114    prep.insert(
115        0,
116        UpdateStats(
117            StatsCalculator(prep_meas, initial_stats),
118            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
119        ),
120    )
121    if post_meas:
122        post.insert(
123            0,
124            UpdateStats(
125                StatsCalculator(post_meas, initial_stats),
126                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
127            ),
128        )
129
130    if fixed_dataset_stats:
131        prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
132        post.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
133
134    return PreAndPostprocessing(prep, post)
135
136
class RequiredMeasures(NamedTuple):
    """Measures required for preprocessing and for postprocessing."""

    # measures required by preprocessing
    pre: Set[Measure]
    # measures required by postprocessing
    post: Set[Measure]
140
141
class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required for preprocessing and for postprocessing."""

    # dataset measures required by preprocessing
    pre: Set[DatasetMeasure]
    # dataset measures required by postprocessing
    post: Set[DatasetMeasure]
145
146
class RequiredSampleMeasures(NamedTuple):
    """Sample measures required for preprocessing and for postprocessing."""

    # sample measures required by preprocessing
    pre: Set[SampleMeasure]
    # sample measures required by postprocessing
    post: Set[SampleMeasure]
150
151
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return all measures required by the pre- and postprocessing of `model`.

    NOTE: the function name is misspelled ("requried"); it is left unchanged
    so existing callers keep working.
    """
    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(pre=pre_measures, post=post_measures)
155
156
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by the processing of `model`.

    Of all required measures, only instances of `DatasetMeasureBase` are kept.
    """

    def dataset_only(measures):
        return {m for m in measures if isinstance(m, DatasetMeasureBase)}

    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredDatasetMeasures(
        pre=dataset_only(pre_measures),
        post=dataset_only(post_measures),
    )
163
164
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by the processing of `model`.

    Of all required measures, only instances of `SampleMeasureBase` are kept.

    NOTE: the function name is misspelled ("requried"); it is left unchanged
    so existing callers keep working.
    """

    def sample_only(measures):
        return {m for m in measures if isinstance(m, SampleMeasureBase)}

    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredSampleMeasures(
        pre=sample_only(pre_measures),
        post=sample_only(post_measures),
    )
171
172
def _prepare_procs(
    tensor_descrs: Union[
        Sequence[v0_4.InputTensorDescr],
        Sequence[v0_5.InputTensorDescr],
        Sequence[v0_4.OutputTensorDescr],
        Sequence[v0_5.OutputTensorDescr],
    ],
) -> List[Processing]:
    """Build the flat list of processing operators for `tensor_descrs`.

    For every tensor description the operators obtained via `get_proc` from
    its `preprocessing` (input descriptions) or `postprocessing` (output
    descriptions) entries are added in order.

    For 0.4 tensor descriptions two `EnsureDtype` operators are added as well:
    - one casting to the described `data_type` in front of the tensor's
      operators, but only if the tensor has any such operators, and
    - one casting to "float32" after them (always).

    The "has any operators" check is done per tensor, so a 0.4 tensor without
    processing steps gets exactly one (float32) cast regardless of its
    position in `tensor_descrs`.

    Returns:
        Operators of all tensors, concatenated in the order of `tensor_descrs`.
    """
    procs: List[Processing] = []
    for t_descr in tensor_descrs:
        t_procs: List[Processing]
        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
            t_procs = [get_proc(proc_d, t_descr) for proc_d in t_descr.preprocessing]
        elif isinstance(t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr)):
            t_procs = [get_proc(proc_d, t_descr) for proc_d in t_descr.postprocessing]
        else:
            assert_never(t_descr)

        if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            member_id = get_member_id(t_descr)
            if t_procs:
                # the initial ensure_dtype is only needed if there are other
                # processing steps for this tensor
                t_procs.insert(
                    0,
                    EnsureDtype(
                        input=member_id, output=member_id, dtype=t_descr.data_type
                    ),
                )

            # ensure 0.4 tensors end up as float32,
            # which has been the implicit assumption for 0.4
            t_procs.append(
                EnsureDtype(input=member_id, output=member_id, dtype="float32")
            )

        procs.extend(t_procs)

    return procs
215
216
def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
    """Derive processing operators and their required measures from `model`.

    Preprocessing operators are built from `model.inputs`, postprocessing
    operators from `model.outputs`. The measure sets are the union of
    `required_measures` over the respective operators.
    """
    if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
        assert_never(model)

    def required_by(ops: List[Processing]) -> Set[Measure]:
        return {measure for op in ops for measure in op.required_measures}

    pre_ops = _prepare_procs(model.inputs)
    post_ops = _prepare_procs(model.outputs)
    return _SetupProcessing(
        pre=pre_ops,
        post=post_ops,
        pre_measures=required_by(pre_ops),
        post_measures=required_by(post_ops),
    )
class PreAndPostprocessing(typing.NamedTuple):
class PreAndPostprocessing(NamedTuple):
    """Pair of ordered operator lists: preprocessing and postprocessing."""

    # preprocessing operators, in application order
    pre: List[Processing]
    # postprocessing operators, in application order
    post: List[Processing]

PreAndPostprocessing(pre, post)

def get_pre_and_postprocessing( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], *, dataset_for_initial_statistics: Iterable[bioimageio.core.Sample], keep_updating_initial_dataset_stats: bool = False, fixed_dataset_stats: Optional[Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None) -> bioimageio.core.proc_setup._ProcessingCallables:
def get_pre_and_postprocessing(
    model: AnyModelDescr,
    *,
    dataset_for_initial_statistics: Iterable[Sample],
    keep_updating_initial_dataset_stats: bool = False,
    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
) -> _ProcessingCallables:
    """Create callables that apply pre- and postprocessing in-place to a sample.

    All arguments are passed on to `setup_pre_and_postprocessing`; each of the
    two operator lists it returns is wrapped in an `_ApplyProcs` callable.

    Returns:
        A `_ProcessingCallables` tuple holding the `pre` and `post` callables.
    """
    pre_ops, post_ops = setup_pre_and_postprocessing(
        model=model,
        dataset_for_initial_statistics=dataset_for_initial_statistics,
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
        fixed_dataset_stats=fixed_dataset_stats,
    )
    return _ProcessingCallables(pre=_ApplyProcs(pre_ops), post=_ApplyProcs(post_ops))

Creates callables to apply pre- and postprocessing in-place to a sample

def setup_pre_and_postprocessing( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], dataset_for_initial_statistics: Iterable[bioimageio.core.Sample], keep_updating_initial_dataset_stats: bool = False, fixed_dataset_stats: Optional[Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None) -> PreAndPostprocessing:
 90def setup_pre_and_postprocessing(
 91    model: AnyModelDescr,
 92    dataset_for_initial_statistics: Iterable[Sample],
 93    keep_updating_initial_dataset_stats: bool = False,
 94    fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
 95) -> PreAndPostprocessing:
 96    """
 97    Get pre- and postprocessing operators for a `model` description.
 98    Used in `bioimageio.core.create_prediction_pipeline"""
 99    prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)
100
101    missing_dataset_stats = {
102        m
103        for m in prep_meas | post_meas
104        if fixed_dataset_stats is None or m not in fixed_dataset_stats
105    }
106    if missing_dataset_stats:
107        initial_stats_calc = StatsCalculator(missing_dataset_stats)
108        for sample in dataset_for_initial_statistics:
109            initial_stats_calc.update(sample)
110
111        initial_stats = initial_stats_calc.finalize()
112    else:
113        initial_stats = {}
114
115    prep.insert(
116        0,
117        UpdateStats(
118            StatsCalculator(prep_meas, initial_stats),
119            keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
120        ),
121    )
122    if post_meas:
123        post.insert(
124            0,
125            UpdateStats(
126                StatsCalculator(post_meas, initial_stats),
127                keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats,
128            ),
129        )
130
131    if fixed_dataset_stats:
132        prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
133        post.insert(0, AddKnownDatasetStats(fixed_dataset_stats))
134
135    return PreAndPostprocessing(prep, post)

Get pre- and postprocessing operators for a model description. Used in `bioimageio.core.create_prediction_pipeline`.

class RequiredMeasures(typing.NamedTuple):
class RequiredMeasures(NamedTuple):
    """Measures required for preprocessing and for postprocessing."""

    # measures required by preprocessing
    pre: Set[Measure]
    # measures required by postprocessing
    post: Set[Measure]

RequiredMeasures(pre, post)

RequiredMeasures( pre: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], post: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]])

Create new instance of RequiredMeasures(pre, post)

pre: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 0

post: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 1

class RequiredDatasetMeasures(typing.NamedTuple):
class RequiredDatasetMeasures(NamedTuple):
    """Dataset measures required for preprocessing and for postprocessing."""

    # dataset measures required by preprocessing
    pre: Set[DatasetMeasure]
    # dataset measures required by postprocessing
    post: Set[DatasetMeasure]

RequiredDatasetMeasures(pre, post)

RequiredDatasetMeasures( pre: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], post: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]])

Create new instance of RequiredDatasetMeasures(pre, post)

pre: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 0

post: Set[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 1

class RequiredSampleMeasures(typing.NamedTuple):
class RequiredSampleMeasures(NamedTuple):
    """Sample measures required for preprocessing and for postprocessing."""

    # sample measures required by preprocessing
    pre: Set[SampleMeasure]
    # sample measures required by postprocessing
    post: Set[SampleMeasure]

RequiredSampleMeasures(pre, post)

RequiredSampleMeasures( pre: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], post: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]])

Create new instance of RequiredSampleMeasures(pre, post)

pre: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 0

post: Set[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Alias for field number 1

def get_requried_measures( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> RequiredMeasures:
def get_requried_measures(model: AnyModelDescr) -> RequiredMeasures:
    """Return all measures required by the pre- and postprocessing of `model`.

    NOTE: the function name is misspelled ("requried"); it is left unchanged
    so existing callers keep working.
    """
    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredMeasures(pre=pre_measures, post=post_measures)
def get_required_dataset_measures( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> RequiredDatasetMeasures:
def get_required_dataset_measures(model: AnyModelDescr) -> RequiredDatasetMeasures:
    """Return the dataset measures required by the processing of `model`.

    Of all required measures, only instances of `DatasetMeasureBase` are kept.
    """

    def dataset_only(measures):
        return {m for m in measures if isinstance(m, DatasetMeasureBase)}

    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredDatasetMeasures(
        pre=dataset_only(pre_measures),
        post=dataset_only(post_measures),
    )
def get_requried_sample_measures( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> RequiredSampleMeasures:
def get_requried_sample_measures(model: AnyModelDescr) -> RequiredSampleMeasures:
    """Return the sample measures required by the processing of `model`.

    Of all required measures, only instances of `SampleMeasureBase` are kept.

    NOTE: the function name is misspelled ("requried"); it is left unchanged
    so existing callers keep working.
    """

    def sample_only(measures):
        return {m for m in measures if isinstance(m, SampleMeasureBase)}

    _, _, pre_measures, post_measures = _prepare_setup_pre_and_postprocessing(model)
    return RequiredSampleMeasures(
        pre=sample_only(pre_measures),
        post=sample_only(post_measures),
    )