bioimageio.core.proc_ops

  1import collections.abc
  2from abc import ABC, abstractmethod
  3from dataclasses import InitVar, dataclass, field
  4from typing import (
  5    Collection,
  6    Literal,
  7    Mapping,
  8    Optional,
  9    Sequence,
 10    Set,
 11    Tuple,
 12    Union,
 13)
 14
 15import numpy as np
 16import xarray as xr
 17from typing_extensions import Self, assert_never
 18
 19from bioimageio.spec.model import v0_4, v0_5
 20
 21from ._op_base import BlockedOperator, Operator
 22from .axis import AxisId, PerAxis
 23from .block import Block
 24from .common import DTypeStr, MemberId
 25from .sample import Sample, SampleBlock, SampleBlockWithOrigin
 26from .stat_calculators import StatsCalculator
 27from .stat_measures import (
 28    DatasetMean,
 29    DatasetMeasure,
 30    DatasetPercentile,
 31    DatasetStd,
 32    MeanMeasure,
 33    Measure,
 34    MeasureValue,
 35    SampleMean,
 36    SampleQuantile,
 37    SampleStd,
 38    Stat,
 39    StdMeasure,
 40)
 41from .tensor import Tensor
 42
 43
 44def _convert_axis_ids(
 45    axes: v0_4.AxesInCZYX,
 46    mode: Literal["per_sample", "per_dataset"],
 47) -> Tuple[AxisId, ...]:
 48    if not isinstance(axes, str):
 49        return tuple(axes)
 50
 51    if mode == "per_sample":
 52        ret = []
 53    elif mode == "per_dataset":
 54        ret = [AxisId("b")]
 55    else:
 56        assert_never(mode)
 57
 58    ret.extend([AxisId(a) for a in axes])
 59    return tuple(ret)
 60
 61
 62@dataclass
 63class _SimpleOperator(BlockedOperator, ABC):
 64    input: MemberId
 65    output: MemberId
 66
 67    @property
 68    def required_measures(self) -> Collection[Measure]:
 69        return set()
 70
 71    @abstractmethod
 72    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...
 73
 74    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
 75        if self.input not in sample.members:
 76            return
 77
 78        input_tensor = sample.members[self.input]
 79        output_tensor = self._apply(input_tensor, sample.stat)
 80
 81        if self.output in sample.members:
 82            assert (
 83                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
 84            )
 85
 86        if isinstance(sample, Sample):
 87            sample.members[self.output] = output_tensor
 88        elif isinstance(sample, SampleBlock):
 89            b = sample.blocks[self.input]
 90            sample.blocks[self.output] = Block(
 91                sample_shape=self.get_output_shape(sample.shape[self.input]),
 92                data=output_tensor,
 93                inner_slice=b.inner_slice,
 94                halo=b.halo,
 95                block_index=b.block_index,
 96                blocks_in_sample=b.blocks_in_sample,
 97            )
 98        else:
 99            assert_never(sample)
100
101    @abstractmethod
102    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...
103
104
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Write fixed, already known dataset statistics into each sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator supplies measures; it does not consume any
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value
115
116
117# @dataclass
118# class UpdateStats(Operator):
119#     """Calculates sample and/or dataset measures"""
120
121#     measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
122#     """sample and dataset `measures` to be calculated by this operator. Initial/fixed
123#     dataset measure values may be given, see `keep_updating_dataset_stats` for details.
124#     """
125#     keep_updating_dataset_stats: Optional[bool] = None
126#     """indicates if operator calls should keep updating dataset statistics or not
127
128#     default (None): if `measures` is a `Mapping` (i.e. initial measure values are
129#     given) no further updates to dataset statistics are conducted, otherwise (w.o.
130#     initial measure values) dataset statistics are updated by each processed sample.
131#     """
132#     _keep_updating_dataset_stats: bool = field(init=False)
133#     _stats_calculator: StatsCalculator = field(init=False)
134
135#     @property
136#     def required_measures(self) -> Set[Measure]:
137#         return set()
138
139#     def __post_init__(self):
140#         self._stats_calculator = StatsCalculator(self.measures)
141#         if self.keep_updating_dataset_stats is None:
142#             self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
143#         else:
144#             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
145
146#     def __call__(self, sample: SampleBlockWithOrigin) -> None:
147#         if self._keep_updating_dataset_stats:
148#             sample.stat.update(self._stats_calculator.update_and_get_all(sample))
149#         else:
150#             sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
151
152
@dataclass
class UpdateStats(Operator):
    """Compute sample and/or dataset measures and store them in `sample.stat`.

    For sample blocks, the measures are computed once, from the whole origin
    sample, when the first block (``block_index == 0``) is processed.
    """

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator computes measures itself; nothing must be provided upfront
        return set()

    def __post_init__(self):
        # without initial dataset measures there is nothing to preserve,
        # so dataset stats are then always updated
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if not isinstance(sample, SampleBlockWithOrigin):
            source = sample
        elif sample.block_index == 0:
            source = sample.origin
        else:
            # later blocks: stats were already computed from the whole origin sample
            return

        calculator = self.stats_calculator
        compute = (
            calculator.update_and_get_all
            if self._keep_updating_dataset_stats
            else calculator.skip_update_and_get_all
        )
        sample.stat.update(compute(source))
190
191
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    `threshold` is either one value for all elements or, together with `axis`,
    a sequence holding one value per entry along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        threshold: Union[float, Sequence[float], xr.DataArray] = self.threshold
        if self.axis is not None and isinstance(
            threshold, collections.abc.Sequence
        ):
            # Label per-axis thresholds with their axis so each value is compared
            # against the matching entries along `axis` (same approach as ScaleLinear);
            # previously `axis` was ignored here.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise comparison: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place (input == output == `member_id`) operator from `descr`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
224
225
@dataclass
class Clip(_SimpleOperator):
    """Clip values of the input tensor via `Tensor.clip` using the bounds `min`/`max`.

    At least one bound must be given; if both are given, `min` must be smaller than `max`.
    """

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        lo, hi = self.min, self.max
        assert not (lo is None and hi is None), "missing min or max value"
        assert (
            lo is None or hi is None or lo < hi
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        bounds = descr.kwargs
        return cls(input=member_id, output=member_id, min=bounds.min, max=bounds.max)
257
258
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype` via `Tensor.astype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # a dtype cast never changes the shape
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)
277
278
@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine transform: ``output = input * gain + offset``.

    `gain`/`offset` are scalars, or `xr.DataArray`s labelled with a single axis
    when the description scales along an axis.
    """

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        scaled = input * self.gain
        return scaled + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kw = descr.kwargs
        if isinstance(kw, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kw.axis
        elif isinstance(kw, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
            axis = None
        else:
            assert_never(kw)

        if axis:
            # one gain/offset entry per index along `axis`
            gain = xr.DataArray(np.atleast_1d(kw.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kw.offset), dims=axis)
        else:
            # without an axis only a single (possibly list-wrapped) value is supported
            gain = kw.gain
            assert isinstance(gain, (float, int)) or len(gain) == 1, kw.gain
            if not isinstance(gain, (float, int)):
                gain = gain[0]

            offset = kw.offset
            assert isinstance(offset, (float, int)) or len(offset) == 1
            if not isinstance(offset, (float, int)):
                offset = offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
327
328
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match mean/std of `input` to those of `reference_tensor` (default: `input`).

    ``output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean``.
    Dataset statistics are used only if `axes` contains ``AxisId("batch")``;
    otherwise (including ``axes=None``) per-sample statistics are used.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        use_dataset_stats = axes is not None and AxisId("batch") in axes
        if use_dataset_stats:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        own_mean = stat[self.mean]
        own_std = stat[self.std] + self.eps
        target_mean = stat[self.ref_mean]
        target_std = stat[self.ref_std] + self.eps
        standardized = (input - own_mean) / own_std
        return standardized * target_std + target_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
386
387
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Return ``(dataset_mode, axes)`` for normalization-like kwargs.

    ``axes=None`` yields ``(True, None)``. v0.4 string axes are converted with the
    kwargs' ``mode`` and count as dataset mode if they contain ``AxisId("b")``.
    v0.5 axis sequences count as dataset mode if they contain ``AxisId("batch")``.
    """
    raw_axes = kwargs.axes
    if raw_axes is None:
        return True, None

    if isinstance(raw_axes, str):
        converted = _convert_axis_ids(raw_axes, kwargs["mode"])
        batch_axis = AxisId("b")
    elif isinstance(raw_axes, collections.abc.Sequence):
        converted = tuple(raw_axes)
        batch_axis = AxisId("batch")
    else:
        assert_never(raw_axes)

    return batch_axis in converted, converted
408
409
@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale using percentiles: ``output = (input - lower) / (upper - lower + eps)``.

    If `lower_percentile` is omitted, a `DatasetPercentile` with ``q=0.0`` is used;
    if `upper_percentile` is omitted, a `DatasetPercentile` with ``q=1.0`` is used.
    Both percentiles must refer to the same member and axes, with lower q < upper q.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            # default lower bound refers to the same member as the upper one, if given
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create an in-place operator from `descr`.

        Percentiles are given in percent in the description and converted to
        quantiles in [0, 1]. `eps` is taken from the description.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # previously dropped: honour the epsilon configured in the description
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
495
496
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # Must be a set like every other operator's measures; the previous `{}`
        # literal is an empty *dict*, which breaks set unions of required measures.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()
522
523
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    ``output = (input - mean) / (std + eps)`` with `mean`/`std` looked up in the
    sample's `stat`. Both measures must use the same axes.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create an in-place operator from `descr` (dataset or sample statistics)."""
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # previously dropped: honour the epsilon configured in the description
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
576
577
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    ``output = (input - mean) / (std + eps)``. `mean`/`std` are scalars or
    `xr.DataArray`s; if both are arrays they must have the same dims.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create an in-place operator from `descr`.

        Scalar kwargs become 0-d `DataArray`s; along-axis kwargs become 1-d
        `DataArray`s labelled with that axis.
        """
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (descr.kwargs.axis,)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Return the equivalent v0.5 description of this operator."""
        mean = self.mean
        std = self.std
        # `from_proc_descr` stores scalars as 0-d DataArrays; treat them as scalars
        # (previously they failed the `len(dims) == 1` assertion below)
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # `.values.tolist()` yields plain Python numbers; `list(DataArray)`
                # would yield 0-d DataArrays instead
                mean=mean.values.tolist(),
                std=std.values.tolist(),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)
636
637
# Any pre- or postprocessing description of spec versions 0.4 and 0.5;
# this is the input type of `get_proc_class`.
ProcDescr = Union[
    v0_4.PreprocessingDescr, v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr, v0_5.PostprocessingDescr,
]

# Any of the concrete operator classes defined in this module.
Processing = Union[
    AddKnownDatasetStats, Binarize, Clip, EnsureDtype,
    FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance,
    ScaleRange, Sigmoid, UpdateStats, ZeroMeanUnitVariance,
]
658
659
def get_proc_class(proc_spec: ProcDescr):
    """Return the operator class implementing the processing description `proc_spec`.

    v0.4 and v0.5 descriptions of the same kind map to the same class. A v0.4
    zero-mean-unit-variance description with ``mode == "fixed"`` maps to
    `FixedZeroMeanUnitVariance`; other zero-mean-unit-variance descriptions map to
    `ZeroMeanUnitVariance`. Anything else falls through to `assert_never`.
    """
    if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize
    if isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip
    if isinstance(proc_spec, v0_5.EnsureDtypeDescr):
        return EnsureDtype
    if isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance
    if isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear
    if isinstance(
        proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance
    if isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange
    if isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid
    if isinstance(
        proc_spec,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        # v0.4 'fixed' mode carries precomputed mean/std values
        fixed_v0_4 = (
            isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr)
            and proc_spec.kwargs.mode == "fixed"
        )
        return FixedZeroMeanUnitVariance if fixed_v0_4 else ZeroMeanUnitVariance

    assert_never(proc_spec)
@dataclass
class AddKnownDatasetStats(bioimageio.core._op_base.BlockedOperator):
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Write fixed, already known dataset statistics into each sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator supplies measures; it does not consume any
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value
AddKnownDatasetStats( dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]])
dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
110    @property
111    def required_measures(self) -> Set[Measure]:
112        return set()
@dataclass
class UpdateStats(bioimageio.core._op_base.Operator):
@dataclass
class UpdateStats(Operator):
    """Compute sample and/or dataset measures and store them in `sample.stat`.

    For sample blocks, the measures are computed once, from the whole origin
    sample, when the first block (``block_index == 0``) is processed.
    """

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator computes measures itself; nothing must be provided upfront
        return set()

    def __post_init__(self):
        # without initial dataset measures there is nothing to preserve,
        # so dataset stats are then always updated
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if not isinstance(sample, SampleBlockWithOrigin):
            source = sample
        elif sample.block_index == 0:
            source = sample.origin
        else:
            # later blocks: stats were already computed from the whole origin sample
            return

        calculator = self.stats_calculator
        compute = (
            calculator.update_and_get_all
            if self._keep_updating_dataset_stats
            else calculator.skip_update_and_get_all
        )
        sample.stat.update(compute(source))

Calculates sample and/or dataset measures

UpdateStats( stats_calculator: bioimageio.core.stat_calculators.StatsCalculator, keep_updating_initial_dataset_stats: bool = False)

StatsCalculator to be used by this operator.

keep_updating_initial_dataset_stats: bool = False

indicates if operator calls should keep updating initial dataset statistics or not; if the stats_calculator was not provided with any initial dataset statistics, these are always updated with every new sample.

required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
167    @property
168    def required_measures(self) -> Set[Measure]:
169        return set()
@dataclass
class Binarize(_SimpleOperator):
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    `threshold` is either one value for all elements or, together with `axis`,
    a sequence holding one value per entry along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        threshold: Union[float, Sequence[float], xr.DataArray] = self.threshold
        if self.axis is not None and isinstance(
            threshold, collections.abc.Sequence
        ):
            # Label per-axis thresholds with their axis so each value is compared
            # against the matching entries along `axis` (same approach as ScaleLinear);
            # previously `axis` was ignored here.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise comparison: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place (input == output == `member_id`) operator from `descr`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)

'output = tensor > threshold'.

Binarize( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, threshold: Union[float, Sequence[float]], axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None)
threshold: Union[float, Sequence[float]]
axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
203    def get_output_shape(
204        self, input_shape: Mapping[AxisId, int]
205    ) -> Mapping[AxisId, int]:
206        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_5.BinarizeDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
208    @classmethod
209    def from_proc_descr(
210        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
211    ) -> Self:
212        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
213            return cls(
214                input=member_id, output=member_id, threshold=descr.kwargs.threshold
215            )
216        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
217            return cls(
218                input=member_id,
219                output=member_id,
220                threshold=descr.kwargs.threshold,
221                axis=descr.kwargs.axis,
222            )
223        else:
224            assert_never(descr.kwargs)
@dataclass
class Clip(_SimpleOperator):
227@dataclass
228class Clip(_SimpleOperator):
229    min: Optional[float] = None
230    """minimum value for clipping"""
231    max: Optional[float] = None
232    """maximum value for clipping"""
233
234    def __post_init__(self):
235        assert self.min is not None or self.max is not None, "missing min or max value"
236        assert (
237            self.min is None or self.max is None or self.min < self.max
238        ), f"expected min < max, but {self.min} !< {self.max}"
239
240    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
241        return input.clip(self.min, self.max)
242
243    def get_output_shape(
244        self, input_shape: Mapping[AxisId, int]
245    ) -> Mapping[AxisId, int]:
246        return input_shape
247
248    @classmethod
249    def from_proc_descr(
250        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
251    ) -> Self:
252        return cls(
253            input=member_id,
254            output=member_id,
255            min=descr.kwargs.min,
256            max=descr.kwargs.max,
257        )
Clip( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, min: Optional[float] = None, max: Optional[float] = None)
min: Optional[float] = None

Minimum value for clipping (optional, but at least one of `min` and `max` must be given).

max: Optional[float] = None

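Maximum value for clipping (optional, but at least one of `min` and `max` must be given; if both are set, `min < max` is required).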

def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
243    def get_output_shape(
244        self, input_shape: Mapping[AxisId, int]
245    ) -> Mapping[AxisId, int]:
246        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_5.ClipDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
248    @classmethod
249    def from_proc_descr(
250        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
251    ) -> Self:
252        return cls(
253            input=member_id,
254            output=member_id,
255            min=descr.kwargs.min,
256            max=descr.kwargs.max,
257        )
@dataclass
class EnsureDtype(_SimpleOperator):
260@dataclass
261class EnsureDtype(_SimpleOperator):
262    dtype: DTypeStr
263
264    @classmethod
265    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
266        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
267
268    def get_descr(self):
269        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
270
271    def get_output_shape(
272        self, input_shape: Mapping[AxisId, int]
273    ) -> Mapping[AxisId, int]:
274        return input_shape
275
276    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
277        return input.astype(self.dtype)
EnsureDtype( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.EnsureDtypeDescr, member_id: bioimageio.spec.model.v0_5.TensorId):
264    @classmethod
265    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
266        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
def get_descr(self):
268    def get_descr(self):
269        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
271    def get_output_shape(
272        self, input_shape: Mapping[AxisId, int]
273    ) -> Mapping[AxisId, int]:
274        return input_shape
@dataclass
class ScaleLinear(_SimpleOperator):
280@dataclass
281class ScaleLinear(_SimpleOperator):
282    gain: Union[float, xr.DataArray] = 1.0
283    """multiplicative factor"""
284
285    offset: Union[float, xr.DataArray] = 0.0
286    """additive term"""
287
288    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
289        return input * self.gain + self.offset
290
291    def get_output_shape(
292        self, input_shape: Mapping[AxisId, int]
293    ) -> Mapping[AxisId, int]:
294        return input_shape
295
296    @classmethod
297    def from_proc_descr(
298        cls,
299        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
300        member_id: MemberId,
301    ) -> Self:
302        kwargs = descr.kwargs
303        if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
304            axis = kwargs.axis
305        elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
306            axis = None
307        else:
308            assert_never(kwargs)
309
310        if axis:
311            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
312            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
313        else:
314            assert (
315                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
316            ), kwargs.gain
317            gain = (
318                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
319            )
320            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
321            offset = (
322                kwargs.offset
323                if isinstance(kwargs.offset, (float, int))
324                else kwargs.offset[0]
325            )
326
327        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
ScaleLinear( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, gain: Union[float, xarray.core.dataarray.DataArray] = 1.0, offset: Union[float, xarray.core.dataarray.DataArray] = 0.0)
gain: Union[float, xarray.core.dataarray.DataArray] = 1.0

Multiplicative factor: a scalar, or an `xr.DataArray` holding one gain per entry along an axis.

offset: Union[float, xarray.core.dataarray.DataArray] = 0.0

Additive term: a scalar, or an `xr.DataArray` holding one offset per entry along an axis.

def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
291    def get_output_shape(
292        self, input_shape: Mapping[AxisId, int]
293    ) -> Mapping[AxisId, int]:
294        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
296    @classmethod
297    def from_proc_descr(
298        cls,
299        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
300        member_id: MemberId,
301    ) -> Self:
302        kwargs = descr.kwargs
303        if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
304            axis = kwargs.axis
305        elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
306            axis = None
307        else:
308            assert_never(kwargs)
309
310        if axis:
311            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
312            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
313        else:
314            assert (
315                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
316            ), kwargs.gain
317            gain = (
318                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
319            )
320            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
321            offset = (
322                kwargs.offset
323                if isinstance(kwargs.offset, (float, int))
324                else kwargs.offset[0]
325            )
326
327        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(_SimpleOperator):
330@dataclass
331class ScaleMeanVariance(_SimpleOperator):
332    axes: Optional[Sequence[AxisId]] = None
333    reference_tensor: Optional[MemberId] = None
334    eps: float = 1e-6
335    mean: Union[SampleMean, DatasetMean] = field(init=False)
336    std: Union[SampleStd, DatasetStd] = field(init=False)
337    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
338    ref_std: Union[SampleStd, DatasetStd] = field(init=False)
339
340    @property
341    def required_measures(self):
342        return {self.mean, self.std, self.ref_mean, self.ref_std}
343
344    def __post_init__(self):
345        axes = None if self.axes is None else tuple(self.axes)
346        ref_tensor = self.reference_tensor or self.input
347        if axes is None or AxisId("batch") not in axes:
348            Mean = SampleMean
349            Std = SampleStd
350        else:
351            Mean = DatasetMean
352            Std = DatasetStd
353
354        self.mean = Mean(member_id=self.input, axes=axes)
355        self.std = Std(member_id=self.input, axes=axes)
356        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
357        self.ref_std = Std(member_id=ref_tensor, axes=axes)
358
359    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
360        mean = stat[self.mean]
361        std = stat[self.std] + self.eps
362        ref_mean = stat[self.ref_mean]
363        ref_std = stat[self.ref_std] + self.eps
364        return (input - mean) / std * ref_std + ref_mean
365
366    def get_output_shape(
367        self, input_shape: Mapping[AxisId, int]
368    ) -> Mapping[AxisId, int]:
369        return input_shape
370
371    @classmethod
372    def from_proc_descr(
373        cls,
374        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
375        member_id: MemberId,
376    ) -> Self:
377        kwargs = descr.kwargs
378        _, axes = _get_axes(descr.kwargs)
379
380        return cls(
381            input=member_id,
382            output=member_id,
383            reference_tensor=MemberId(str(kwargs.reference_tensor)),
384            axes=axes,
385            eps=kwargs.eps,
386        )
ScaleMeanVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None, reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None, eps: float = 1e-06)
axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None
reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None
eps: float = 1e-06
required_measures
340    @property
341    def required_measures(self):
342        return {self.mean, self.std, self.ref_mean, self.ref_std}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
366    def get_output_shape(
367        self, input_shape: Mapping[AxisId, int]
368    ) -> Mapping[AxisId, int]:
369        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
371    @classmethod
372    def from_proc_descr(
373        cls,
374        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
375        member_id: MemberId,
376    ) -> Self:
377        kwargs = descr.kwargs
378        _, axes = _get_axes(descr.kwargs)
379
380        return cls(
381            input=member_id,
382            output=member_id,
383            reference_tensor=MemberId(str(kwargs.reference_tensor)),
384            axes=axes,
385            eps=kwargs.eps,
386        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class ScaleRange(_SimpleOperator):
411@dataclass
412class ScaleRange(_SimpleOperator):
413    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
414    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
415    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
416    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)
417
418    eps: float = 1e-6
419
420    def __post_init__(
421        self,
422        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
423        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
424    ):
425        if lower_percentile is None:
426            tid = self.input if upper_percentile is None else upper_percentile.member_id
427            self.lower = DatasetPercentile(q=0.0, member_id=tid)
428        else:
429            self.lower = lower_percentile
430
431        if upper_percentile is None:
432            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
433        else:
434            self.upper = upper_percentile
435
436        assert self.lower.member_id == self.upper.member_id
437        assert self.lower.q < self.upper.q
438        assert self.lower.axes == self.upper.axes
439
440    @property
441    def required_measures(self):
442        return {self.lower, self.upper}
443
444    def get_output_shape(
445        self, input_shape: Mapping[AxisId, int]
446    ) -> Mapping[AxisId, int]:
447        return input_shape
448
449    @classmethod
450    def from_proc_descr(
451        cls,
452        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
453        member_id: MemberId,
454    ):
455        kwargs = descr.kwargs
456        ref_tensor = (
457            member_id
458            if kwargs.reference_tensor is None
459            else MemberId(str(kwargs.reference_tensor))
460        )
461        dataset_mode, axes = _get_axes(descr.kwargs)
462        if dataset_mode:
463            Percentile = DatasetPercentile
464        else:
465            Percentile = SampleQuantile
466
467        return cls(
468            input=member_id,
469            output=member_id,
470            lower_percentile=Percentile(
471                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
472            ),
473            upper_percentile=Percentile(
474                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
475            ),
476        )
477
478    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
479        lower = stat[self.lower]
480        upper = stat[self.upper]
481        return (input - lower) / (upper - lower + self.eps)
482
483    def get_descr(self):
484        assert self.lower.axes == self.upper.axes
485        assert self.lower.member_id == self.upper.member_id
486
487        return v0_5.ScaleRangeDescr(
488            kwargs=v0_5.ScaleRangeKwargs(
489                axes=self.lower.axes,
490                min_percentile=self.lower.q * 100,
491                max_percentile=self.upper.q * 100,
492                eps=self.eps,
493                reference_tensor=self.lower.member_id,
494            )
495        )
ScaleRange( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, eps: float = 1e-06)
lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None
upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None
eps: float = 1e-06
required_measures
440    @property
441    def required_measures(self):
442        return {self.lower, self.upper}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
444    def get_output_shape(
445        self, input_shape: Mapping[AxisId, int]
446    ) -> Mapping[AxisId, int]:
447        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
449    @classmethod
450    def from_proc_descr(
451        cls,
452        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
453        member_id: MemberId,
454    ):
455        kwargs = descr.kwargs
456        ref_tensor = (
457            member_id
458            if kwargs.reference_tensor is None
459            else MemberId(str(kwargs.reference_tensor))
460        )
461        dataset_mode, axes = _get_axes(descr.kwargs)
462        if dataset_mode:
463            Percentile = DatasetPercentile
464        else:
465            Percentile = SampleQuantile
466
467        return cls(
468            input=member_id,
469            output=member_id,
470            lower_percentile=Percentile(
471                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
472            ),
473            upper_percentile=Percentile(
474                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
475            ),
476        )
def get_descr(self):
483    def get_descr(self):
484        assert self.lower.axes == self.upper.axes
485        assert self.lower.member_id == self.upper.member_id
486
487        return v0_5.ScaleRangeDescr(
488            kwargs=v0_5.ScaleRangeKwargs(
489                axes=self.lower.axes,
490                min_percentile=self.lower.q * 100,
491                max_percentile=self.upper.q * 100,
492                eps=self.eps,
493                reference_tensor=self.lower.member_id,
494            )
495        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class Sigmoid(_SimpleOperator):
498@dataclass
499class Sigmoid(_SimpleOperator):
500    """1 / (1 + e^(-input))."""
501
502    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
503        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)
504
505    @property
506    def required_measures(self) -> Collection[Measure]:
507        return {}
508
509    def get_output_shape(
510        self, input_shape: Mapping[AxisId, int]
511    ) -> Mapping[AxisId, int]:
512        return input_shape
513
514    @classmethod
515    def from_proc_descr(
516        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
517    ) -> Self:
518        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
519        return cls(input=member_id, output=member_id)
520
521    def get_descr(self):
522        return v0_5.SigmoidDescr()

Element-wise logistic sigmoid: $\text{output} = \frac{1}{1 + e^{-\text{input}}}$.

required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
505    @property
506    def required_measures(self) -> Collection[Measure]:
507        return {}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
509    def get_output_shape(
510        self, input_shape: Mapping[AxisId, int]
511    ) -> Mapping[AxisId, int]:
512        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_5.SigmoidDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
514    @classmethod
515    def from_proc_descr(
516        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
517    ) -> Self:
518        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
519        return cls(input=member_id, output=member_id)
def get_descr(self):
521    def get_descr(self):
522        return v0_5.SigmoidDescr()
Inherited Members
_SimpleOperator
input
output
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
525@dataclass
526class ZeroMeanUnitVariance(_SimpleOperator):
527    """normalize to zero mean, unit variance."""
528
529    mean: MeanMeasure
530    std: StdMeasure
531
532    eps: float = 1e-6
533
534    def __post_init__(self):
535        assert self.mean.axes == self.std.axes
536
537    @property
538    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
539        return {self.mean, self.std}
540
541    def get_output_shape(
542        self, input_shape: Mapping[AxisId, int]
543    ) -> Mapping[AxisId, int]:
544        return input_shape
545
546    @classmethod
547    def from_proc_descr(
548        cls,
549        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
550        member_id: MemberId,
551    ):
552        dataset_mode, axes = _get_axes(descr.kwargs)
553
554        if dataset_mode:
555            Mean = DatasetMean
556            Std = DatasetStd
557        else:
558            Mean = SampleMean
559            Std = SampleStd
560
561        return cls(
562            input=member_id,
563            output=member_id,
564            mean=Mean(axes=axes, member_id=member_id),
565            std=Std(axes=axes, member_id=member_id),
566        )
567
568    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
569        mean = stat[self.mean]
570        std = stat[self.std]
571        return (input - mean) / (std + self.eps)
572
573    def get_descr(self):
574        return v0_5.ZeroMeanUnitVarianceDescr(
575            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
576        )

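Normalize to zero mean and unit variance: $\text{output} = (\text{input} - \text{mean}) / (\text{std} + \epsilon)$, where mean and std are measured statistics of the input tensor.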

eps: float = 1e-06
537    @property
538    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
539        return {self.mean, self.std}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
541    def get_output_shape(
542        self, input_shape: Mapping[AxisId, int]
543    ) -> Mapping[AxisId, int]:
544        return input_shape
546    @classmethod
547    def from_proc_descr(
548        cls,
549        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
550        member_id: MemberId,
551    ):
552        dataset_mode, axes = _get_axes(descr.kwargs)
553
554        if dataset_mode:
555            Mean = DatasetMean
556            Std = DatasetStd
557        else:
558            Mean = SampleMean
559            Std = SampleStd
560
561        return cls(
562            input=member_id,
563            output=member_id,
564            mean=Mean(axes=axes, member_id=member_id),
565            std=Std(axes=axes, member_id=member_id),
566        )
def get_descr(self):
573    def get_descr(self):
574        return v0_5.ZeroMeanUnitVarianceDescr(
575            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
576        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
579@dataclass
580class FixedZeroMeanUnitVariance(_SimpleOperator):
581    """normalize to zero mean, unit variance with precomputed values."""
582
583    mean: Union[float, xr.DataArray]
584    std: Union[float, xr.DataArray]
585
586    eps: float = 1e-6
587
588    def __post_init__(self):
589        assert (
590            isinstance(self.mean, (int, float))
591            or isinstance(self.std, (int, float))
592            or self.mean.dims == self.std.dims
593        )
594
595    def get_output_shape(
596        self, input_shape: Mapping[AxisId, int]
597    ) -> Mapping[AxisId, int]:
598        return input_shape
599
600    @classmethod
601    def from_proc_descr(
602        cls,
603        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
604        member_id: MemberId,
605    ) -> Self:
606        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
607            dims = None
608        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
609            dims = (descr.kwargs.axis,)
610        else:
611            assert_never(descr.kwargs)
612
613        return cls(
614            input=member_id,
615            output=member_id,
616            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
617            std=xr.DataArray(descr.kwargs.std, dims=dims),
618        )
619
620    def get_descr(self):
621        if isinstance(self.mean, (int, float)):
622            assert isinstance(self.std, (int, float))
623            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
624        else:
625            assert isinstance(self.std, xr.DataArray)
626            assert len(self.mean.dims) == 1
627            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
628                axis=AxisId(str(self.mean.dims[0])),
629                mean=list(self.mean),
630                std=list(self.std),
631            )
632
633        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
634
635    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
636        return (input - self.mean) / (self.std + self.eps)

Normalize to zero mean and unit variance using precomputed (fixed) mean and std values: $\text{output} = (\text{input} - \text{mean}) / (\text{std} + \epsilon)$.

FixedZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[float, xarray.core.dataarray.DataArray], std: Union[float, xarray.core.dataarray.DataArray], eps: float = 1e-06)
mean: Union[float, xarray.core.dataarray.DataArray]
std: Union[float, xarray.core.dataarray.DataArray]
eps: float = 1e-06
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
595    def get_output_shape(
596        self, input_shape: Mapping[AxisId, int]
597    ) -> Mapping[AxisId, int]:
598        return input_shape
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
600    @classmethod
601    def from_proc_descr(
602        cls,
603        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
604        member_id: MemberId,
605    ) -> Self:
606        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
607            dims = None
608        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
609            dims = (descr.kwargs.axis,)
610        else:
611            assert_never(descr.kwargs)
612
613        return cls(
614            input=member_id,
615            output=member_id,
616            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
617            std=xr.DataArray(descr.kwargs.std, dims=dims),
618        )
def get_descr(self):
620    def get_descr(self):
621        if isinstance(self.mean, (int, float)):
622            assert isinstance(self.std, (int, float))
623            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
624        else:
625            assert isinstance(self.std, xr.DataArray)
626            assert len(self.mean.dims) == 1
627            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
628                axis=AxisId(str(self.mean.dims[0])),
629                mean=list(self.mean),
630                std=list(self.std),
631            )
632
633        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
ProcDescr = typing.Union[typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def get_proc_class(proc_spec: ProcDescr):
    """Map a processing description onto the operator class that implements it.

    Both v0_4 and v0_5 description classes are accepted; equivalent
    descriptions from either spec version resolve to the same operator class.

    Args:
        proc_spec: A pre- or postprocessing description from the model spec.

    Returns:
        The operator class (not an instance), e.g. ``Binarize`` or ``Clip``.

    Notes:
        Checks are evaluated in a fixed order and the first match wins.
        Anything not covered falls through to ``assert_never``, which is
        meant to be unreachable for a well-typed ``ProcDescr``.
    """
    if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize
    if isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip
    if isinstance(proc_spec, v0_5.EnsureDtypeDescr):
        return EnsureDtype
    if isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance
    if isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear
    if isinstance(
        proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance
    if isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange
    if isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid

    if isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
        # v0_4 has no dedicated "fixed" description class; the fixed variant
        # is selected through kwargs.mode instead.
        if proc_spec.kwargs.mode == "fixed":
            return FixedZeroMeanUnitVariance
        return ZeroMeanUnitVariance
    if isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
        return ZeroMeanUnitVariance

    assert_never(proc_spec)