bioimageio.core.proc_ops

  1import collections.abc
  2from abc import ABC, abstractmethod
  3from dataclasses import InitVar, dataclass, field
  4from typing import (
  5    Collection,
  6    Literal,
  7    Mapping,
  8    Optional,
  9    Sequence,
 10    Set,
 11    Tuple,
 12    Union,
 13)
 14
 15import numpy as np
 16import scipy  # pyright: ignore[reportMissingTypeStubs]
 17import xarray as xr
 18from typing_extensions import Self, assert_never
 19
 20from bioimageio.core.digest_spec import get_member_id
 21from bioimageio.spec.model import v0_4, v0_5
 22from bioimageio.spec.model.v0_5 import (
 23    _convert_proc,  # pyright: ignore [reportPrivateUsage]
 24)
 25
 26from ._op_base import BlockedOperator, Operator
 27from .axis import AxisId, PerAxis
 28from .block import Block
 29from .common import DTypeStr, MemberId
 30from .sample import Sample, SampleBlock, SampleBlockWithOrigin
 31from .stat_calculators import StatsCalculator
 32from .stat_measures import (
 33    DatasetMean,
 34    DatasetMeasure,
 35    DatasetPercentile,
 36    DatasetStd,
 37    MeanMeasure,
 38    Measure,
 39    MeasureValue,
 40    SampleMean,
 41    SampleQuantile,
 42    SampleStd,
 43    Stat,
 44    StdMeasure,
 45)
 46from .tensor import Tensor
 47
 48
 49def _convert_axis_ids(
 50    axes: v0_4.AxesInCZYX,
 51    mode: Literal["per_sample", "per_dataset"],
 52) -> Tuple[AxisId, ...]:
 53    if not isinstance(axes, str):
 54        return tuple(axes)
 55
 56    if mode == "per_sample":
 57        ret = []
 58    elif mode == "per_dataset":
 59        ret = [v0_5.BATCH_AXIS_ID]
 60    else:
 61        assert_never(mode)
 62
 63    ret.extend([AxisId(a) for a in axes])
 64    return tuple(ret)
 65
 66
 67@dataclass
 68class _SimpleOperator(BlockedOperator, ABC):
 69    input: MemberId
 70    output: MemberId
 71
 72    @property
 73    def required_measures(self) -> Collection[Measure]:
 74        return set()
 75
 76    @abstractmethod
 77    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...
 78
 79    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
 80        if self.input not in sample.members:
 81            return
 82
 83        input_tensor = sample.members[self.input]
 84        output_tensor = self._apply(input_tensor, sample.stat)
 85
 86        if self.output in sample.members:
 87            assert (
 88                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
 89            )
 90
 91        if isinstance(sample, Sample):
 92            sample.members[self.output] = output_tensor
 93        elif isinstance(sample, SampleBlock):
 94            b = sample.blocks[self.input]
 95            sample.blocks[self.output] = Block(
 96                sample_shape=self.get_output_shape(sample.shape[self.input]),
 97                data=output_tensor,
 98                inner_slice=b.inner_slice,
 99                halo=b.halo,
100                block_index=b.block_index,
101                blocks_in_sample=b.blocks_in_sample,
102            )
103        else:
104            assert_never(sample)
105
106    @abstractmethod
107    def _apply(self, x: Tensor, stat: Stat) -> Tensor: ...
108
109
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Copy the given dataset statistics into the `stat` of every processed sample/block."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator provides statistics; it does not consume any
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known = self.dataset_stats.items()
        sample.stat.update(known)
120
121
122# @dataclass
123# class UpdateStats(Operator):
124#     """Calculates sample and/or dataset measures"""
125
126#     measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
 127#     """sample and dataset `measures` to be calculated by this operator. Initial/fixed
128#     dataset measure values may be given, see `keep_updating_dataset_stats` for details.
129#     """
130#     keep_updating_dataset_stats: Optional[bool] = None
131#     """indicates if operator calls should keep updating dataset statistics or not
132
133#     default (None): if `measures` is a `Mapping` (i.e. initial measure values are
 134#     given) no further updates to dataset statistics are conducted, otherwise (w.o.
135#     initial measure values) dataset statistics are updated by each processed sample.
136#     """
137#     _keep_updating_dataset_stats: bool = field(init=False)
138#     _stats_calculator: StatsCalculator = field(init=False)
139
140#     @property
141#     def required_measures(self) -> Set[Measure]:
142#         return set()
143
144#     def __post_init__(self):
145#         self._stats_calculator = StatsCalculator(self.measures)
146#         if self.keep_updating_dataset_stats is None:
147#             self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
148#         else:
149#             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
150
151#     def __call__(self, sample_block: SampleBlockWithOrigin> None:
152#         if self._keep_updating_dataset_stats:
153#             sample.stat.update(self._stats_calculator.update_and_get_all(sample))
154#         else:
155#             sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
156
157
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        # an explicit request to keep updating wins; otherwise only keep updating
        # when the calculator holds no dataset measures yet
        if self.keep_updating_initial_dataset_stats:
            self._keep_updating_dataset_stats = self.keep_updating_initial_dataset_stats
        else:
            self._keep_updating_dataset_stats = (
                not self.stats_calculator.has_dataset_measures
            )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if not isinstance(sample, SampleBlockWithOrigin):
            source = sample
        elif sample.block_index == 0:
            # for blocked processing, stats come from the whole origin sample,
            # computed once when its first block is seen
            source = sample.origin
        else:
            return

        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            computed = calculator.update_and_get_all(source)
        else:
            computed = calculator.skip_update_and_get_all(source)

        sample.stat.update(computed)
195
196
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is given and `threshold` is a sequence, each entry of `threshold`
    applies to the corresponding index along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        threshold = self.threshold
        if self.axis is not None and not isinstance(threshold, (int, float)):
            # Label per-axis thresholds with their axis so they broadcast along
            # `axis` by name. A bare list would be broadcast positionally against
            # the last dimension instead.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))

        return x > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise comparison: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
229
230
@dataclass
class Clip(_SimpleOperator):
    """Limit tensor values to the range given by `min` and/or `max`."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        # at least one bound is required; if both are given they must be ordered
        assert not (self.min is None and self.max is None), "missing min or max value"
        assert (
            self.min is None or self.max is None or self.min < self.max
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower_bound, upper_bound = self.min, self.max
        return x.clip(lower_bound, upper_bound)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        clip_kwargs = descr.kwargs
        return cls(
            input=member_id,
            output=member_id,
            min=clip_kwargs.min,
            max=clip_kwargs.max,
        )
262
263
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        dtype_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=dtype_kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # a cast never changes the shape
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)
282
283
@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine transform: `x * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        scaled = x * self.gain
        return scaled + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Build the operator from a v0.4/v0.5 scale-linear description.

        Raises:
            NotImplementedError: for v0.4 kwargs that specify `axes`.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if not axis:
            # scalar parameters may also be given as single-element lists
            raw_gain, raw_offset = kwargs.gain, kwargs.offset
            assert isinstance(raw_gain, (float, int)) or len(raw_gain) == 1, raw_gain
            gain = raw_gain if isinstance(raw_gain, (float, int)) else raw_gain[0]
            assert isinstance(raw_offset, (float, int)) or len(raw_offset) == 1
            offset = (
                raw_offset if isinstance(raw_offset, (float, int)) else raw_offset[0]
            )
        else:
            # per-axis parameters become DataArrays labelled with that axis
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
338
339
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match mean and std of the input to those of a reference tensor.

    Computes `(x - mean) / (std + eps) * (ref_std + eps) + ref_mean`. Sample
    statistics are used unless `axes` contains the batch axis, in which case
    dataset statistics are used. Without `reference_tensor` the input itself
    serves as reference.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = tuple(self.axes) if self.axes is not None else None
        ref_member = self.reference_tensor or self.input

        use_dataset_stats = axes is not None and AxisId("batch") in axes
        if use_dataset_stats:
            mean_cls, std_cls = DatasetMean, DatasetStd
        else:
            mean_cls, std_cls = SampleMean, SampleStd

        self.mean = mean_cls(member_id=self.input, axes=axes)
        self.std = std_cls(member_id=self.input, axes=axes)
        self.ref_mean = mean_cls(member_id=ref_member, axes=axes)
        self.ref_std = std_cls(member_id=ref_member, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mu = stat[self.mean]
        sigma = stat[self.std] + self.eps
        ref_mu = stat[self.ref_mean]
        ref_sigma = stat[self.ref_std] + self.eps
        standardized = (x - mu) / sigma
        return standardized * ref_sigma + ref_mu

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        # the dataset/sample decision is re-derived from `axes` in __post_init__
        _dataset_mode, axes = _get_axes(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
397
398
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes from processing kwargs.

    Returns:
        `(dataset_mode, axes)`. If `kwargs.axes` is None, this is `(True, None)`.
        Otherwise `axes` is a tuple of axis ids, and `dataset_mode` tells
        whether the batch axis is among them.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4 axes string. In "per_dataset" mode `_convert_axis_ids` prepends
        # `v0_5.BATCH_AXIS_ID`, so test for that id. The former `AxisId("b")`
        # was never present, so per_dataset was misreported as sample mode.
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return v0_5.BATCH_AXIS_ID in axes, axes
    else:
        assert_never(kwargs.axes)
419
420
@dataclass
class ScaleRange(_SimpleOperator):
    """Scale the input by a lower and an upper percentile.

    Computes `(x - lower) / (upper - lower + eps)`, with `lower` and `upper`
    looked up in the sample's statistics.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        # missing bounds default to dataset percentiles q=0.0 (lower) / q=1.0 (upper)
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # forward the described epsilon (get_descr writes it back) instead
            # of silently falling back to the default
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
506
507
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # No statistics needed. Return an empty *set* like the base class;
        # the former `{}` literal was an empty dict.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()
533
534
@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function along `axis`."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        axis_idx = x.dims.index(self.axis)
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # No statistics needed. Return an empty *set* like the base class;
        # the former `{}` literal was an empty dict.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # softmax preserves the shape
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
563
564
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    Computes `(x - mean) / (std + eps)` with `mean` and `std` looked up in the
    sample's statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        dataset_mode, axes = _get_axes(kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # forward the described epsilon (as ScaleMeanVariance does) instead
            # of silently falling back to the default
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
617
618
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    Computes `(x - mean) / (std + eps)`. `mean`/`std` are scalars (plain floats
    or 0-d `xr.DataArray`s) or 1-d `xr.DataArray`s along a single axis.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        # scalar statistics end up as 0-d DataArrays (dims=None)
        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        mean, std = self.mean, self.std
        # `from_proc_descr` stores scalar statistics as 0-d DataArrays; treat
        # them as floats so the descr round-trips instead of failing the
        # 1-d assertion below
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # `.values.tolist()` yields plain floats; iterating the DataArray
                # directly would yield 0-d DataArrays
                mean=mean.values.tolist(),
                std=std.values.tolist(),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)
677
678
# Any v0.4/v0.5 pre- or postprocessing description accepted by `get_proc`.
ProcDescr = Union[
    v0_4.PreprocessingDescr, v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr, v0_5.PostprocessingDescr,
]

# Union of the operator classes defined in this module; `get_proc` returns one of them.
Processing = Union[
    AddKnownDatasetStats, Binarize, Clip, EnsureDtype,
    FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange,
    Sigmoid, Softmax, UpdateStats, ZeroMeanUnitVariance,
]
700
701
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the operator implementing `proc_descr` for `tensor_descr`.

    The operator reads and writes the member id derived from `tensor_descr`
    via `get_member_id`.

    Raises:
        TypeError: if a v0.4 zero-mean-unit-variance description in "fixed"
            mode is combined with a non-v0.4 tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    # Must precede the generic zero-mean-unit-variance branch below: v0.4
    # "fixed" mode is converted to a v0.5 FixedZeroMeanUnitVarianceDescr, and
    # the conversion needs the v0.4 tensor axes.
    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)
@dataclass
class AddKnownDatasetStats(bioimageio.core._op_base.BlockedOperator):
111@dataclass
112class AddKnownDatasetStats(BlockedOperator):
113    dataset_stats: Mapping[DatasetMeasure, MeasureValue]
114
115    @property
116    def required_measures(self) -> Set[Measure]:
117        return set()
118
119    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
120        sample.stat.update(self.dataset_stats.items())
AddKnownDatasetStats( dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]])
dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f0219527100>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f02195272e0>, return_type=PydanticUndefined, when_used='always')]]]
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
115    @property
116    def required_measures(self) -> Set[Measure]:
117        return set()
@dataclass
class UpdateStats(bioimageio.core._op_base.Operator):
159@dataclass
160class UpdateStats(Operator):
161    """Calculates sample and/or dataset measures"""
162
163    stats_calculator: StatsCalculator
164    """`StatsCalculator` to be used by this operator."""
165    keep_updating_initial_dataset_stats: bool = False
166    """indicates if operator calls should keep updating initial dataset statistics or not;
167    if the `stats_calculator` was not provided with any initial dataset statistics,
168    these are always updated with every new sample.
169    """
170    _keep_updating_dataset_stats: bool = field(init=False)
171
172    @property
173    def required_measures(self) -> Set[Measure]:
174        return set()
175
176    def __post_init__(self):
177        self._keep_updating_dataset_stats = (
178            self.keep_updating_initial_dataset_stats
179            or not self.stats_calculator.has_dataset_measures
180        )
181
182    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
183        if isinstance(sample, SampleBlockWithOrigin):
184            # update stats with whole sample on first block
185            if sample.block_index != 0:
186                return
187
188            origin = sample.origin
189        else:
190            origin = sample
191
192        if self._keep_updating_dataset_stats:
193            sample.stat.update(self.stats_calculator.update_and_get_all(origin))
194        else:
195            sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin))

Calculates sample and/or dataset measures

UpdateStats( stats_calculator: bioimageio.core.stat_calculators.StatsCalculator, keep_updating_initial_dataset_stats: bool = False)

StatsCalculator to be used by this operator.

keep_updating_initial_dataset_stats: bool = False

indicates if operator calls should keep updating initial dataset statistics or not; if the stats_calculator was not provided with any initial dataset statistics, these are always updated with every new sample.

required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
172    @property
173    def required_measures(self) -> Set[Measure]:
174        return set()
@dataclass
class Binarize(_SimpleOperator):
198@dataclass
199class Binarize(_SimpleOperator):
200    """'output = tensor > threshold'."""
201
202    threshold: Union[float, Sequence[float]]
203    axis: Optional[AxisId] = None
204
205    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
206        return x > self.threshold
207
208    def get_output_shape(
209        self, input_shape: Mapping[AxisId, int]
210    ) -> Mapping[AxisId, int]:
211        return input_shape
212
213    @classmethod
214    def from_proc_descr(
215        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
216    ) -> Self:
217        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
218            return cls(
219                input=member_id, output=member_id, threshold=descr.kwargs.threshold
220            )
221        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
222            return cls(
223                input=member_id,
224                output=member_id,
225                threshold=descr.kwargs.threshold,
226                axis=descr.kwargs.axis,
227            )
228        else:
229            assert_never(descr.kwargs)

Binarize a tensor element-wise: `output = tensor > threshold`.

Binarize( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, threshold: Union[float, Sequence[float]], axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None)
threshold: Union[float, Sequence[float]]
axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
208    def get_output_shape(
209        self, input_shape: Mapping[AxisId, int]
210    ) -> Mapping[AxisId, int]:
211        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_5.BinarizeDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
213    @classmethod
214    def from_proc_descr(
215        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
216    ) -> Self:
217        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
218            return cls(
219                input=member_id, output=member_id, threshold=descr.kwargs.threshold
220            )
221        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
222            return cls(
223                input=member_id,
224                output=member_id,
225                threshold=descr.kwargs.threshold,
226                axis=descr.kwargs.axis,
227            )
228        else:
229            assert_never(descr.kwargs)
@dataclass
class Clip(_SimpleOperator):
232@dataclass
233class Clip(_SimpleOperator):
234    min: Optional[float] = None
235    """minimum value for clipping"""
236    max: Optional[float] = None
237    """maximum value for clipping"""
238
239    def __post_init__(self):
240        assert self.min is not None or self.max is not None, "missing min or max value"
241        assert self.min is None or self.max is None or self.min < self.max, (
242            f"expected min < max, but {self.min} !< {self.max}"
243        )
244
245    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
246        return x.clip(self.min, self.max)
247
248    def get_output_shape(
249        self, input_shape: Mapping[AxisId, int]
250    ) -> Mapping[AxisId, int]:
251        return input_shape
252
253    @classmethod
254    def from_proc_descr(
255        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
256    ) -> Self:
257        return cls(
258            input=member_id,
259            output=member_id,
260            min=descr.kwargs.min,
261            max=descr.kwargs.max,
262        )
Clip( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, min: Optional[float] = None, max: Optional[float] = None)
min: Optional[float] = None

Minimum value for clipping (optional; at least one of `min` and `max` must be given).

max: Optional[float] = None

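Maximum value for clipping (optional; at least one of `min` and `max` must be given, and if both are given, `min < max` is required).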

def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
248    def get_output_shape(
249        self, input_shape: Mapping[AxisId, int]
250    ) -> Mapping[AxisId, int]:
251        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_5.ClipDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
253    @classmethod
254    def from_proc_descr(
255        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
256    ) -> Self:
257        return cls(
258            input=member_id,
259            output=member_id,
260            min=descr.kwargs.min,
261            max=descr.kwargs.max,
262        )
@dataclass
class EnsureDtype(_SimpleOperator):
265@dataclass
266class EnsureDtype(_SimpleOperator):
267    dtype: DTypeStr
268
269    @classmethod
270    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
271        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
272
273    def get_descr(self):
274        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
275
276    def get_output_shape(
277        self, input_shape: Mapping[AxisId, int]
278    ) -> Mapping[AxisId, int]:
279        return input_shape
280
281    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
282        return x.astype(self.dtype)
EnsureDtype( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.EnsureDtypeDescr, member_id: bioimageio.spec.model.v0_5.TensorId):
269    @classmethod
270    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
271        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
def get_descr(self):
273    def get_descr(self):
274        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
276    def get_output_shape(
277        self, input_shape: Mapping[AxisId, int]
278    ) -> Mapping[AxisId, int]:
279        return input_shape
@dataclass
class ScaleLinear(_SimpleOperator):
285@dataclass
286class ScaleLinear(_SimpleOperator):
287    gain: Union[float, xr.DataArray] = 1.0
288    """multiplicative factor"""
289
290    offset: Union[float, xr.DataArray] = 0.0
291    """additive term"""
292
293    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
294        return x * self.gain + self.offset
295
296    def get_output_shape(
297        self, input_shape: Mapping[AxisId, int]
298    ) -> Mapping[AxisId, int]:
299        return input_shape
300
301    @classmethod
302    def from_proc_descr(
303        cls,
304        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
305        member_id: MemberId,
306    ) -> Self:
307        kwargs = descr.kwargs
308        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
309            axis = None
310        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
311            axis = kwargs.axis
312        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
313            if kwargs.axes is not None:
314                raise NotImplementedError(
315                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
316                )
317            axis = None
318        else:
319            assert_never(kwargs)
320
321        if axis:
322            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
323            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
324        else:
325            assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
326                kwargs.gain
327            )
328            gain = (
329                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
330            )
331            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
332            offset = (
333                kwargs.offset
334                if isinstance(kwargs.offset, (float, int))
335                else kwargs.offset[0]
336            )
337
338        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
ScaleLinear( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, gain: Union[float, xarray.core.dataarray.DataArray] = 1.0, offset: Union[float, xarray.core.dataarray.DataArray] = 0.0)
gain: Union[float, xarray.core.dataarray.DataArray] = 1.0

Multiplicative factor: `output = input * gain + offset`.

offset: Union[float, xarray.core.dataarray.DataArray] = 0.0

Additive term: `output = input * gain + offset`.

def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
296    def get_output_shape(
297        self, input_shape: Mapping[AxisId, int]
298    ) -> Mapping[AxisId, int]:
299        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
301    @classmethod
302    def from_proc_descr(
303        cls,
304        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
305        member_id: MemberId,
306    ) -> Self:
307        kwargs = descr.kwargs
308        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
309            axis = None
310        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
311            axis = kwargs.axis
312        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
313            if kwargs.axes is not None:
314                raise NotImplementedError(
315                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
316                )
317            axis = None
318        else:
319            assert_never(kwargs)
320
321        if axis:
322            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
323            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
324        else:
325            assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
326                kwargs.gain
327            )
328            gain = (
329                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
330            )
331            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
332            offset = (
333                kwargs.offset
334                if isinstance(kwargs.offset, (float, int))
335                else kwargs.offset[0]
336            )
337
338        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(_SimpleOperator):
341@dataclass
342class ScaleMeanVariance(_SimpleOperator):
343    axes: Optional[Sequence[AxisId]] = None
344    reference_tensor: Optional[MemberId] = None
345    eps: float = 1e-6
346    mean: Union[SampleMean, DatasetMean] = field(init=False)
347    std: Union[SampleStd, DatasetStd] = field(init=False)
348    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
349    ref_std: Union[SampleStd, DatasetStd] = field(init=False)
350
351    @property
352    def required_measures(self):
353        return {self.mean, self.std, self.ref_mean, self.ref_std}
354
355    def __post_init__(self):
356        axes = None if self.axes is None else tuple(self.axes)
357        ref_tensor = self.reference_tensor or self.input
358        if axes is None or AxisId("batch") not in axes:
359            Mean = SampleMean
360            Std = SampleStd
361        else:
362            Mean = DatasetMean
363            Std = DatasetStd
364
365        self.mean = Mean(member_id=self.input, axes=axes)
366        self.std = Std(member_id=self.input, axes=axes)
367        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
368        self.ref_std = Std(member_id=ref_tensor, axes=axes)
369
370    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
371        mean = stat[self.mean]
372        std = stat[self.std] + self.eps
373        ref_mean = stat[self.ref_mean]
374        ref_std = stat[self.ref_std] + self.eps
375        return (x - mean) / std * ref_std + ref_mean
376
377    def get_output_shape(
378        self, input_shape: Mapping[AxisId, int]
379    ) -> Mapping[AxisId, int]:
380        return input_shape
381
382    @classmethod
383    def from_proc_descr(
384        cls,
385        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
386        member_id: MemberId,
387    ) -> Self:
388        kwargs = descr.kwargs
389        _, axes = _get_axes(descr.kwargs)
390
391        return cls(
392            input=member_id,
393            output=member_id,
394            reference_tensor=MemberId(str(kwargs.reference_tensor)),
395            axes=axes,
396            eps=kwargs.eps,
397        )
ScaleMeanVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None, reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None, eps: float = 1e-06)
axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None
reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None
eps: float = 1e-06
required_measures
351    @property
352    def required_measures(self):
353        return {self.mean, self.std, self.ref_mean, self.ref_std}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
377    def get_output_shape(
378        self, input_shape: Mapping[AxisId, int]
379    ) -> Mapping[AxisId, int]:
380        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
382    @classmethod
383    def from_proc_descr(
384        cls,
385        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
386        member_id: MemberId,
387    ) -> Self:
388        kwargs = descr.kwargs
389        _, axes = _get_axes(descr.kwargs)
390
391        return cls(
392            input=member_id,
393            output=member_id,
394            reference_tensor=MemberId(str(kwargs.reference_tensor)),
395            axes=axes,
396            eps=kwargs.eps,
397        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class ScaleRange(_SimpleOperator):
422@dataclass
423class ScaleRange(_SimpleOperator):
424    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
425    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
426    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
427    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)
428
429    eps: float = 1e-6
430
431    def __post_init__(
432        self,
433        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
434        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
435    ):
436        if lower_percentile is None:
437            tid = self.input if upper_percentile is None else upper_percentile.member_id
438            self.lower = DatasetPercentile(q=0.0, member_id=tid)
439        else:
440            self.lower = lower_percentile
441
442        if upper_percentile is None:
443            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
444        else:
445            self.upper = upper_percentile
446
447        assert self.lower.member_id == self.upper.member_id
448        assert self.lower.q < self.upper.q
449        assert self.lower.axes == self.upper.axes
450
451    @property
452    def required_measures(self):
453        return {self.lower, self.upper}
454
455    def get_output_shape(
456        self, input_shape: Mapping[AxisId, int]
457    ) -> Mapping[AxisId, int]:
458        return input_shape
459
460    @classmethod
461    def from_proc_descr(
462        cls,
463        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
464        member_id: MemberId,
465    ):
466        kwargs = descr.kwargs
467        ref_tensor = (
468            member_id
469            if kwargs.reference_tensor is None
470            else MemberId(str(kwargs.reference_tensor))
471        )
472        dataset_mode, axes = _get_axes(descr.kwargs)
473        if dataset_mode:
474            Percentile = DatasetPercentile
475        else:
476            Percentile = SampleQuantile
477
478        return cls(
479            input=member_id,
480            output=member_id,
481            lower_percentile=Percentile(
482                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
483            ),
484            upper_percentile=Percentile(
485                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
486            ),
487        )
488
489    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
490        lower = stat[self.lower]
491        upper = stat[self.upper]
492        return (x - lower) / (upper - lower + self.eps)
493
494    def get_descr(self):
495        assert self.lower.axes == self.upper.axes
496        assert self.lower.member_id == self.upper.member_id
497
498        return v0_5.ScaleRangeDescr(
499            kwargs=v0_5.ScaleRangeKwargs(
500                axes=self.lower.axes,
501                min_percentile=self.lower.q * 100,
502                max_percentile=self.upper.q * 100,
503                eps=self.eps,
504                reference_tensor=self.lower.member_id,
505            )
506        )
ScaleRange( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, eps: float = 1e-06)
lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None
upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None
eps: float = 1e-06
required_measures
451    @property
452    def required_measures(self):
453        return {self.lower, self.upper}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
455    def get_output_shape(
456        self, input_shape: Mapping[AxisId, int]
457    ) -> Mapping[AxisId, int]:
458        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
460    @classmethod
461    def from_proc_descr(
462        cls,
463        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
464        member_id: MemberId,
465    ):
466        kwargs = descr.kwargs
467        ref_tensor = (
468            member_id
469            if kwargs.reference_tensor is None
470            else MemberId(str(kwargs.reference_tensor))
471        )
472        dataset_mode, axes = _get_axes(descr.kwargs)
473        if dataset_mode:
474            Percentile = DatasetPercentile
475        else:
476            Percentile = SampleQuantile
477
478        return cls(
479            input=member_id,
480            output=member_id,
481            lower_percentile=Percentile(
482                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
483            ),
484            upper_percentile=Percentile(
485                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
486            ),
487        )
def get_descr(self):
494    def get_descr(self):
495        assert self.lower.axes == self.upper.axes
496        assert self.lower.member_id == self.upper.member_id
497
498        return v0_5.ScaleRangeDescr(
499            kwargs=v0_5.ScaleRangeKwargs(
500                axes=self.lower.axes,
501                min_percentile=self.lower.q * 100,
502                max_percentile=self.upper.q * 100,
503                eps=self.eps,
504                reference_tensor=self.lower.member_id,
505            )
506        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class Sigmoid(_SimpleOperator):
509@dataclass
510class Sigmoid(_SimpleOperator):
511    """1 / (1 + e^(-input))."""
512
513    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
514        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)
515
516    @property
517    def required_measures(self) -> Collection[Measure]:
518        return {}
519
520    def get_output_shape(
521        self, input_shape: Mapping[AxisId, int]
522    ) -> Mapping[AxisId, int]:
523        return input_shape
524
525    @classmethod
526    def from_proc_descr(
527        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
528    ) -> Self:
529        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
530        return cls(input=member_id, output=member_id)
531
532    def get_descr(self):
533        return v0_5.SigmoidDescr()

Element-wise sigmoid: `output = 1 / (1 + exp(-input))`.

required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
516    @property
517    def required_measures(self) -> Collection[Measure]:
518        return {}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
520    def get_output_shape(
521        self, input_shape: Mapping[AxisId, int]
522    ) -> Mapping[AxisId, int]:
523        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_5.SigmoidDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
525    @classmethod
526    def from_proc_descr(
527        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
528    ) -> Self:
529        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
530        return cls(input=member_id, output=member_id)
def get_descr(self):
532    def get_descr(self):
533        return v0_5.SigmoidDescr()
Inherited Members
_SimpleOperator
input
output
@dataclass
class Softmax(_SimpleOperator):
536@dataclass
537class Softmax(_SimpleOperator):
538    """Softmax activation function."""
539
540    axis: AxisId = AxisId("channel")
541
542    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
543        axis_idx = x.dims.index(self.axis)
544        result = scipy.special.softmax(x.data, axis=axis_idx)
545        result_xr = xr.DataArray(result, dims=x.dims)
546        return Tensor.from_xarray(result_xr)
547
548    @property
549    def required_measures(self) -> Collection[Measure]:
550        return {}
551
552    def get_output_shape(
553        self, input_shape: Mapping[AxisId, int]
554    ) -> Mapping[AxisId, int]:
555        return input_shape
556
557    @classmethod
558    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
559        assert isinstance(descr, v0_5.SoftmaxDescr)
560        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)
561
562    def get_descr(self):
563        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))

Softmax activation function, applied along `axis` (the channel axis by default).

required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
548    @property
549    def required_measures(self) -> Collection[Measure]:
550        return {}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
552    def get_output_shape(
553        self, input_shape: Mapping[AxisId, int]
554    ) -> Mapping[AxisId, int]:
555        return input_shape
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.SoftmaxDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
557    @classmethod
558    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
559        assert isinstance(descr, v0_5.SoftmaxDescr)
560        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)
def get_descr(self):
562    def get_descr(self):
563        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
Inherited Members
_SimpleOperator
input
output
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
566@dataclass
567class ZeroMeanUnitVariance(_SimpleOperator):
568    """normalize to zero mean, unit variance."""
569
570    mean: MeanMeasure
571    std: StdMeasure
572
573    eps: float = 1e-6
574
575    def __post_init__(self):
576        assert self.mean.axes == self.std.axes
577
578    @property
579    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
580        return {self.mean, self.std}
581
582    def get_output_shape(
583        self, input_shape: Mapping[AxisId, int]
584    ) -> Mapping[AxisId, int]:
585        return input_shape
586
587    @classmethod
588    def from_proc_descr(
589        cls,
590        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
591        member_id: MemberId,
592    ):
593        dataset_mode, axes = _get_axes(descr.kwargs)
594
595        if dataset_mode:
596            Mean = DatasetMean
597            Std = DatasetStd
598        else:
599            Mean = SampleMean
600            Std = SampleStd
601
602        return cls(
603            input=member_id,
604            output=member_id,
605            mean=Mean(axes=axes, member_id=member_id),
606            std=Std(axes=axes, member_id=member_id),
607        )
608
609    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
610        mean = stat[self.mean]
611        std = stat[self.std]
612        return (x - mean) / (std + self.eps)
613
614    def get_descr(self):
615        return v0_5.ZeroMeanUnitVarianceDescr(
616            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
617        )

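Normalize to zero mean and unit variance: `output = (input - mean) / (std + eps)`, where `mean` and `std` are sample or dataset statistics.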

eps: float = 1e-06
578    @property
579    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
580        return {self.mean, self.std}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
582    def get_output_shape(
583        self, input_shape: Mapping[AxisId, int]
584    ) -> Mapping[AxisId, int]:
585        return input_shape
587    @classmethod
588    def from_proc_descr(
589        cls,
590        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
591        member_id: MemberId,
592    ):
593        dataset_mode, axes = _get_axes(descr.kwargs)
594
595        if dataset_mode:
596            Mean = DatasetMean
597            Std = DatasetStd
598        else:
599            Mean = SampleMean
600            Std = SampleStd
601
602        return cls(
603            input=member_id,
604            output=member_id,
605            mean=Mean(axes=axes, member_id=member_id),
606            std=Std(axes=axes, member_id=member_id),
607        )
def get_descr(self):
614    def get_descr(self):
615        return v0_5.ZeroMeanUnitVarianceDescr(
616            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
617        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
620@dataclass
621class FixedZeroMeanUnitVariance(_SimpleOperator):
622    """normalize to zero mean, unit variance with precomputed values."""
623
624    mean: Union[float, xr.DataArray]
625    std: Union[float, xr.DataArray]
626
627    eps: float = 1e-6
628
629    def __post_init__(self):
630        assert (
631            isinstance(self.mean, (int, float))
632            or isinstance(self.std, (int, float))
633            or self.mean.dims == self.std.dims
634        )
635
636    def get_output_shape(
637        self, input_shape: Mapping[AxisId, int]
638    ) -> Mapping[AxisId, int]:
639        return input_shape
640
641    @classmethod
642    def from_proc_descr(
643        cls,
644        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
645        member_id: MemberId,
646    ) -> Self:
647        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
648            dims = None
649        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
650            dims = (AxisId(descr.kwargs.axis),)
651        else:
652            assert_never(descr.kwargs)
653
654        return cls(
655            input=member_id,
656            output=member_id,
657            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
658            std=xr.DataArray(descr.kwargs.std, dims=dims),
659        )
660
661    def get_descr(self):
662        if isinstance(self.mean, (int, float)):
663            assert isinstance(self.std, (int, float))
664            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
665        else:
666            assert isinstance(self.std, xr.DataArray)
667            assert len(self.mean.dims) == 1
668            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
669                axis=AxisId(str(self.mean.dims[0])),
670                mean=list(self.mean),
671                std=list(self.std),
672            )
673
674        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
675
676    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
677        return (x - self.mean) / (self.std + self.eps)

Normalize to zero mean and unit variance using precomputed `mean` and `std` values: `output = (input - mean) / (std + eps)`.

FixedZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[float, xarray.core.dataarray.DataArray], std: Union[float, xarray.core.dataarray.DataArray], eps: float = 1e-06)
mean: Union[float, xarray.core.dataarray.DataArray]
std: Union[float, xarray.core.dataarray.DataArray]
eps: float = 1e-06
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
def get_output_shape(
    self, input_shape: Mapping[AxisId, int]
) -> Mapping[AxisId, int]:
    """Return the output shape, which is `input_shape` itself.

    The normalization in `_apply` is elementwise arithmetic, so it does not
    change the shape.
    """
    output_shape = input_shape
    return output_shape
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: v0_5.FixedZeroMeanUnitVarianceDescr,
    member_id: MemberId,
) -> Self:
    """Create the operator from a v0.5 `FixedZeroMeanUnitVarianceDescr`.

    The tensor member `member_id` is used as both input and output.
    For `FixedZeroMeanUnitVarianceKwargs` the mean/std arrays are built with
    `dims=None`; for `FixedZeroMeanUnitVarianceAlongAxisKwargs` they are built
    along the single described axis.
    """
    proc_kwargs = descr.kwargs
    if isinstance(proc_kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
        dims = None
    elif isinstance(proc_kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
        dims = (AxisId(proc_kwargs.axis),)
    else:
        assert_never(proc_kwargs)

    mean_arr = xr.DataArray(proc_kwargs.mean, dims=dims)
    std_arr = xr.DataArray(proc_kwargs.std, dims=dims)
    return cls(input=member_id, output=member_id, mean=mean_arr, std=std_arr)
def get_descr(self):
def get_descr(self):
    """Return a v0.5 `FixedZeroMeanUnitVarianceDescr` describing this operator.

    Scalar `mean`/`std` produce `FixedZeroMeanUnitVarianceKwargs`. Scalars may
    be Python numbers or 0-d `xr.DataArray`s, which is what `from_proc_descr`
    creates for scalar kwargs. Otherwise `mean` and `std` are broadcast against
    each other. The result must be 1-d and produces
    `FixedZeroMeanUnitVarianceAlongAxisKwargs` along that dimension.

    Raises:
        AssertionError: if mean/std do not reduce to scalars or to a single
            shared axis.
    """
    mean = self.mean
    std = self.std
    # `from_proc_descr` wraps scalar kwargs in 0-d DataArrays; unwrap them so a
    # from_proc_descr -> get_descr round trip yields the scalar kwargs again.
    if isinstance(mean, xr.DataArray) and mean.ndim == 0:
        mean = mean.item()
    if isinstance(std, xr.DataArray) and std.ndim == 0:
        std = std.item()

    if isinstance(mean, (int, float)) and isinstance(std, (int, float)):
        kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
    else:
        # Broadcasting lets one of mean/std be a scalar while the other varies
        # along an axis; two different axes would yield a 2-d result.
        mean_arr, std_arr = xr.broadcast(xr.DataArray(mean), xr.DataArray(std))
        assert mean_arr.ndim == 1, "expected mean/std along exactly one axis"
        kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
            axis=AxisId(str(mean_arr.dims[0])),
            # Iterating a DataArray yields 0-d DataArrays; pass plain numbers.
            mean=mean_arr.values.tolist(),
            std=std_arr.values.tolist(),
        )

    return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
ProcDescr = typing.Union[typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', 
custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def get_proc( proc_descr: Union[Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, 
custom_error_context=None)]], tensor_descr: Union[bioimageio.spec.model.v0_4.InputTensorDescr, bioimageio.spec.model.v0_4.OutputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr, bioimageio.spec.model.v0_5.OutputTensorDescr]) -> Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, Softmax, UpdateStats, ZeroMeanUnitVariance]:
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the processing operator described by `proc_descr`.

    The operator's input/output member is derived from `tensor_descr` via
    `get_member_id`.

    Raises:
        TypeError: if a v0.4 zero-mean-unit-variance description in "fixed"
            mode is paired with a tensor description that is not v0.4.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    is_v0_4_fixed_zmuv = (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    )
    if is_v0_4_fixed_zmuv:
        # v0.4 "fixed" mode is converted to the v0.5 fixed-values description;
        # the conversion needs the v0.4 tensor axes.
        is_v0_4_tensor = isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        )
        if not is_v0_4_tensor:
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)