bioimageio.core.proc_ops

  1import collections.abc
  2from abc import ABC, abstractmethod
  3from dataclasses import InitVar, dataclass, field
  4from typing import (
  5    Collection,
  6    Literal,
  7    Mapping,
  8    Optional,
  9    Sequence,
 10    Set,
 11    Tuple,
 12    Union,
 13)
 14
 15import numpy as np
 16import xarray as xr
 17from typing_extensions import Self, assert_never
 18
 19from bioimageio.core.digest_spec import get_member_id
 20from bioimageio.spec.model import v0_4, v0_5
 21from bioimageio.spec.model.v0_5 import (
 22    _convert_proc,  # pyright: ignore [reportPrivateUsage]
 23)
 24
 25from ._op_base import BlockedOperator, Operator
 26from .axis import AxisId, PerAxis
 27from .block import Block
 28from .common import DTypeStr, MemberId
 29from .sample import Sample, SampleBlock, SampleBlockWithOrigin
 30from .stat_calculators import StatsCalculator
 31from .stat_measures import (
 32    DatasetMean,
 33    DatasetMeasure,
 34    DatasetPercentile,
 35    DatasetStd,
 36    MeanMeasure,
 37    Measure,
 38    MeasureValue,
 39    SampleMean,
 40    SampleQuantile,
 41    SampleStd,
 42    Stat,
 43    StdMeasure,
 44)
 45from .tensor import Tensor
 46
 47
 48def _convert_axis_ids(
 49    axes: v0_4.AxesInCZYX,
 50    mode: Literal["per_sample", "per_dataset"],
 51) -> Tuple[AxisId, ...]:
 52    if not isinstance(axes, str):
 53        return tuple(axes)
 54
 55    if mode == "per_sample":
 56        ret = []
 57    elif mode == "per_dataset":
 58        ret = [v0_5.BATCH_AXIS_ID]
 59    else:
 60        assert_never(mode)
 61
 62    ret.extend([AxisId(a) for a in axes])
 63    return tuple(ret)
 64
 65
 66@dataclass
 67class _SimpleOperator(BlockedOperator, ABC):
 68    input: MemberId
 69    output: MemberId
 70
 71    @property
 72    def required_measures(self) -> Collection[Measure]:
 73        return set()
 74
 75    @abstractmethod
 76    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...
 77
 78    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
 79        if self.input not in sample.members:
 80            return
 81
 82        input_tensor = sample.members[self.input]
 83        output_tensor = self._apply(input_tensor, sample.stat)
 84
 85        if self.output in sample.members:
 86            assert (
 87                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
 88            )
 89
 90        if isinstance(sample, Sample):
 91            sample.members[self.output] = output_tensor
 92        elif isinstance(sample, SampleBlock):
 93            b = sample.blocks[self.input]
 94            sample.blocks[self.output] = Block(
 95                sample_shape=self.get_output_shape(sample.shape[self.input]),
 96                data=output_tensor,
 97                inner_slice=b.inner_slice,
 98                halo=b.halo,
 99                block_index=b.block_index,
100                blocks_in_sample=b.blocks_in_sample,
101            )
102        else:
103            assert_never(sample)
104
105    @abstractmethod
106    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...
107
108
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Add precomputed dataset statistics to the `stat` mapping of each sample."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known_items = self.dataset_stats.items()
        sample.stat.update(known_items)
119
120
121# @dataclass
122# class UpdateStats(Operator):
123#     """Calculates sample and/or dataset measures"""
124
125#     measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
126#     """sample and dataset `measuers` to be calculated by this operator. Initial/fixed
127#     dataset measure values may be given, see `keep_updating_dataset_stats` for details.
128#     """
129#     keep_updating_dataset_stats: Optional[bool] = None
130#     """indicates if operator calls should keep updating dataset statistics or not
131
132#     default (None): if `measures` is a `Mapping` (i.e. initial measure values are
133#     given) no further updates to dataset statistics is conducted, otherwise (w.o.
134#     initial measure values) dataset statistics are updated by each processed sample.
135#     """
136#     _keep_updating_dataset_stats: bool = field(init=False)
137#     _stats_calculator: StatsCalculator = field(init=False)
138
139#     @property
140#     def required_measures(self) -> Set[Measure]:
141#         return set()
142
143#     def __post_init__(self):
144#         self._stats_calculator = StatsCalculator(self.measures)
145#         if self.keep_updating_dataset_stats is None:
146#             self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
147#         else:
148#             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
149
150#     def __call__(self, sample_block: SampleBlockWithOrigin> None:
151#         if self._keep_updating_dataset_stats:
152#             sample.stat.update(self._stats_calculator.update_and_get_all(sample))
153#         else:
154#             sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
155
156
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        if self.keep_updating_initial_dataset_stats:
            self._keep_updating_dataset_stats = self.keep_updating_initial_dataset_stats
        else:
            # without initial dataset measures there is nothing to preserve
            self._keep_updating_dataset_stats = (
                not self.stats_calculator.has_dataset_measures
            )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            # statistics are computed on the whole origin sample, once per sample:
            # only its first block triggers the computation
            if sample.block_index != 0:
                return
            whole_sample = sample.origin
        else:
            whole_sample = sample

        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            compute = calculator.update_and_get_all
        else:
            compute = calculator.skip_update_and_get_all

        sample.stat.update(compute(whole_sample))
194
195
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    A scalar `threshold` is compared against every element. A sequence of
    thresholds together with `axis` applies one threshold per index along
    that (named) axis.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        threshold = self.threshold
        if self.axis is not None and np.ndim(threshold) > 0:
            # Align per-axis thresholds with the named axis. Comparing against a
            # plain sequence would broadcast positionally against the trailing
            # dimension and ignore `axis` entirely.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place (`input == output`) operator from a spec description."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
228
229
@dataclass
class Clip(_SimpleOperator):
    """Clip tensor values to `[min, max]`; one of the two bounds may be omitted."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        assert not (self.min is None and self.max is None), "missing min or max value"
        assert (
            self.min is None or self.max is None or self.min < self.max
        ), f"expected min < max, but {self.min} !< {self.max}"

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        clip_kwargs = descr.kwargs
        return cls(
            input=member_id,
            output=member_id,
            min=clip_kwargs.min,
            max=clip_kwargs.max,
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower, upper = self.min, self.max
        return input.clip(lower, upper)
261
262
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        target_dtype = self.dtype
        return input.astype(target_dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        dtype_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=dtype_kwargs)
281
282
@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine rescaling: `output = input * gain + offset`.

    `gain` and `offset` are scalars, or `xr.DataArray`s along a single axis when
    created from an along-axis description.
    """

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input * self.gain + self.offset

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        def first_if_sequence(value):
            # scalar parameters may be given as a length-1 sequence
            return value if isinstance(value, (float, int)) else value[0]

        if not axis:
            assert (
                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
            ), kwargs.gain
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            gain = first_if_sequence(kwargs.gain)
            offset = first_if_sequence(kwargs.offset)
        else:
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
337
338
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match the mean/std of `input` to those of `reference_tensor`.

    `output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean`

    Dataset measures are used if `axes` contains the "batch" axis, otherwise
    sample measures. Without `reference_tensor`, `input` is its own reference.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        reduce_axes = None if self.axes is None else tuple(self.axes)
        ref_member = self.reference_tensor or self.input

        per_dataset = reduce_axes is not None and AxisId("batch") in reduce_axes
        if per_dataset:
            mean_cls, std_cls = DatasetMean, DatasetStd
        else:
            mean_cls, std_cls = SampleMean, SampleStd

        self.mean = mean_cls(member_id=self.input, axes=reduce_axes)
        self.std = std_cls(member_id=self.input, axes=reduce_axes)
        self.ref_mean = mean_cls(member_id=ref_member, axes=reduce_axes)
        self.ref_std = std_cls(member_id=ref_member, axes=reduce_axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (input - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        reduce_axes = _get_axes(descr.kwargs)[1]

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=reduce_axes,
            eps=kwargs.eps,
        )
396
397
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes from processing kwargs.

    Returns:
        `(dataset_mode, axes)`. `dataset_mode` is True if the batch axis is part
        of the reduction axes, i.e. dataset measures are required. `axes=None`
        (no axes given) is treated as dataset mode.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4 axes string: `_convert_axis_ids` prepends `v0_5.BATCH_AXIS_ID` in
        # "per_dataset" mode, so check for exactly that id. Checking AxisId("b")
        # never matched, and silently selected sample measures.
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return v0_5.BATCH_AXIS_ID in axes, axes
    else:
        assert_never(kwargs.axes)
418
419
@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale values using a lower and an upper percentile.

    `output = (input - lower) / (upper - lower + eps)`

    An omitted `lower_percentile` defaults to the dataset percentile q=0.0, and an
    omitted `upper_percentile` to the dataset percentile q=1.0. Each default uses
    the member of the other bound if that bound is given; otherwise the lower
    bound falls back to `input`.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        # both bounds must describe the same tensor/axes, and be ordered
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create an in-place operator from a spec description.

        Percentiles given in [0, 100] are converted to quantiles in [0, 1]. The
        described `eps` is forwarded (it was previously dropped, so the default
        was always used).
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
505
506
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # An empty *set*, like all other operators. The previous `{}` was an empty
        # dict, which breaks set operations such as `a | b` on required measures.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()
532
533
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    `output = (input - mean) / (std + eps)`, using the measures `mean` and `std`
    looked up in the sample's `stat` mapping.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create an in-place operator from a spec description.

        Dataset or sample measures are chosen via `_get_axes`. The described `eps`
        is forwarded (it was previously dropped, so the default was always used).
        """
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
586
587
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    `output = (input - mean) / (std + eps)`

    `mean`/`std` are scalars or `xr.DataArray`s. `from_proc_descr` creates 0-d
    DataArrays for scalar descriptions and 1-d DataArrays along `axis` otherwise.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Describe this operator as a v0.5 `FixedZeroMeanUnitVarianceDescr`."""
        mean = self.mean
        std = self.std
        # `from_proc_descr` wraps scalar values in 0-d DataArrays. Unwrap them so
        # that they map back onto the scalar kwargs instead of failing the 1-d
        # assertion below.
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean.item())
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std.item())

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain Python floats; iterating a DataArray directly would yield
                # 0-d DataArrays
                mean=[float(m) for m in mean.values],
                std=[float(s) for s in std.values],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)
646
647
# All pre- and postprocessing descriptions (model spec v0.4 and v0.5) accepted
# by `get_proc`.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# All operator types defined in this module.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]
668
669
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the operator for a processing description of a tensor.

    The operator is constructed via the matching class's `from_proc_descr`,
    acting on the member identified by `tensor_descr`.

    Raises:
        TypeError: if a v0.4 ZeroMeanUnitVariance description in "fixed" mode is
            paired with a tensor description that is not v0.4.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        # v0.4 "fixed" mode is converted to the v0.5 FixedZeroMeanUnitVariance
        # description. The conversion needs the v0.4 tensor axes. This branch must
        # stay before the generic ZeroMeanUnitVariance branch below.
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    else:
        assert_never(proc_descr)
@dataclass
class AddKnownDatasetStats(bioimageio.core._op_base.BlockedOperator):
110@dataclass
111class AddKnownDatasetStats(BlockedOperator):
112    dataset_stats: Mapping[DatasetMeasure, MeasureValue]
113
114    @property
115    def required_measures(self) -> Set[Measure]:
116        return set()
117
118    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
119        sample.stat.update(self.dataset_stats.items())
AddKnownDatasetStats( dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]])
dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
114    @property
115    def required_measures(self) -> Set[Measure]:
116        return set()
@dataclass
class UpdateStats(bioimageio.core._op_base.Operator):
158@dataclass
159class UpdateStats(Operator):
160    """Calculates sample and/or dataset measures"""
161
162    stats_calculator: StatsCalculator
163    """`StatsCalculator` to be used by this operator."""
164    keep_updating_initial_dataset_stats: bool = False
165    """indicates if operator calls should keep updating initial dataset statistics or not;
166    if the `stats_calculator` was not provided with any initial dataset statistics,
167    these are always updated with every new sample.
168    """
169    _keep_updating_dataset_stats: bool = field(init=False)
170
171    @property
172    def required_measures(self) -> Set[Measure]:
173        return set()
174
175    def __post_init__(self):
176        self._keep_updating_dataset_stats = (
177            self.keep_updating_initial_dataset_stats
178            or not self.stats_calculator.has_dataset_measures
179        )
180
181    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
182        if isinstance(sample, SampleBlockWithOrigin):
183            # update stats with whole sample on first block
184            if sample.block_index != 0:
185                return
186
187            origin = sample.origin
188        else:
189            origin = sample
190
191        if self._keep_updating_dataset_stats:
192            sample.stat.update(self.stats_calculator.update_and_get_all(origin))
193        else:
194            sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin))

Calculates sample and/or dataset measures

UpdateStats( stats_calculator: bioimageio.core.stat_calculators.StatsCalculator, keep_updating_initial_dataset_stats: bool = False)

StatsCalculator to be used by this operator.

keep_updating_initial_dataset_stats: bool = False

indicates if operator calls should keep updating initial dataset statistics or not; if the stats_calculator was not provided with any initial dataset statistics, these are always updated with every new sample.

required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
171    @property
172    def required_measures(self) -> Set[Measure]:
173        return set()
@dataclass
class Binarize(_SimpleOperator):
197@dataclass
198class Binarize(_SimpleOperator):
199    """'output = tensor > threshold'."""
200
201    threshold: Union[float, Sequence[float]]
202    axis: Optional[AxisId] = None
203
204    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
205        return input > self.threshold
206
207    def get_output_shape(
208        self, input_shape: Mapping[AxisId, int]
209    ) -> Mapping[AxisId, int]:
210        return input_shape
211
212    @classmethod
213    def from_proc_descr(
214        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
215    ) -> Self:
216        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
217            return cls(
218                input=member_id, output=member_id, threshold=descr.kwargs.threshold
219            )
220        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
221            return cls(
222                input=member_id,
223                output=member_id,
224                threshold=descr.kwargs.threshold,
225                axis=descr.kwargs.axis,
226            )
227        else:
228            assert_never(descr.kwargs)

'output = tensor > threshold'.

Binarize( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, threshold: Union[float, Sequence[float]], axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None)
threshold: Union[float, Sequence[float]]
axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
207    def get_output_shape(
208        self, input_shape: Mapping[AxisId, int]
209    ) -> Mapping[AxisId, int]:
210        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_5.BinarizeDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
212    @classmethod
213    def from_proc_descr(
214        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
215    ) -> Self:
216        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
217            return cls(
218                input=member_id, output=member_id, threshold=descr.kwargs.threshold
219            )
220        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
221            return cls(
222                input=member_id,
223                output=member_id,
224                threshold=descr.kwargs.threshold,
225                axis=descr.kwargs.axis,
226            )
227        else:
228            assert_never(descr.kwargs)
@dataclass
class Clip(_SimpleOperator):
231@dataclass
232class Clip(_SimpleOperator):
233    min: Optional[float] = None
234    """minimum value for clipping"""
235    max: Optional[float] = None
236    """maximum value for clipping"""
237
238    def __post_init__(self):
239        assert self.min is not None or self.max is not None, "missing min or max value"
240        assert (
241            self.min is None or self.max is None or self.min < self.max
242        ), f"expected min < max, but {self.min} !< {self.max}"
243
244    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
245        return input.clip(self.min, self.max)
246
247    def get_output_shape(
248        self, input_shape: Mapping[AxisId, int]
249    ) -> Mapping[AxisId, int]:
250        return input_shape
251
252    @classmethod
253    def from_proc_descr(
254        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
255    ) -> Self:
256        return cls(
257            input=member_id,
258            output=member_id,
259            min=descr.kwargs.min,
260            max=descr.kwargs.max,
261        )
Clip( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, min: Optional[float] = None, max: Optional[float] = None)
min: Optional[float] = None

Minimum value for clipping (`None` means no lower bound).

max: Optional[float] = None

Maximum value for clipping (`None` means no upper bound).

def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
247    def get_output_shape(
248        self, input_shape: Mapping[AxisId, int]
249    ) -> Mapping[AxisId, int]:
250        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_5.ClipDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
252    @classmethod
253    def from_proc_descr(
254        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
255    ) -> Self:
256        return cls(
257            input=member_id,
258            output=member_id,
259            min=descr.kwargs.min,
260            max=descr.kwargs.max,
261        )
@dataclass
class EnsureDtype(_SimpleOperator):
264@dataclass
265class EnsureDtype(_SimpleOperator):
266    dtype: DTypeStr
267
268    @classmethod
269    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
270        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
271
272    def get_descr(self):
273        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
274
275    def get_output_shape(
276        self, input_shape: Mapping[AxisId, int]
277    ) -> Mapping[AxisId, int]:
278        return input_shape
279
280    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
281        return input.astype(self.dtype)
EnsureDtype( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.EnsureDtypeDescr, member_id: bioimageio.spec.model.v0_5.TensorId):
268    @classmethod
269    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
270        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
def get_descr(self):
272    def get_descr(self):
273        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
275    def get_output_shape(
276        self, input_shape: Mapping[AxisId, int]
277    ) -> Mapping[AxisId, int]:
278        return input_shape
@dataclass
class ScaleLinear(_SimpleOperator):
284@dataclass
285class ScaleLinear(_SimpleOperator):
286    gain: Union[float, xr.DataArray] = 1.0
287    """multiplicative factor"""
288
289    offset: Union[float, xr.DataArray] = 0.0
290    """additive term"""
291
292    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
293        return input * self.gain + self.offset
294
295    def get_output_shape(
296        self, input_shape: Mapping[AxisId, int]
297    ) -> Mapping[AxisId, int]:
298        return input_shape
299
300    @classmethod
301    def from_proc_descr(
302        cls,
303        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
304        member_id: MemberId,
305    ) -> Self:
306        kwargs = descr.kwargs
307        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
308            axis = None
309        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
310            axis = kwargs.axis
311        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
312            if kwargs.axes is not None:
313                raise NotImplementedError(
314                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
315                )
316            axis = None
317        else:
318            assert_never(kwargs)
319
320        if axis:
321            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
322            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
323        else:
324            assert (
325                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
326            ), kwargs.gain
327            gain = (
328                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
329            )
330            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
331            offset = (
332                kwargs.offset
333                if isinstance(kwargs.offset, (float, int))
334                else kwargs.offset[0]
335            )
336
337        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
ScaleLinear( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, gain: Union[float, xarray.core.dataarray.DataArray] = 1.0, offset: Union[float, xarray.core.dataarray.DataArray] = 0.0)
gain: Union[float, xarray.core.dataarray.DataArray] = 1.0

Multiplicative factor (`output = input * gain + offset`).

offset: Union[float, xarray.core.dataarray.DataArray] = 0.0

Additive term, applied after multiplying by `gain`.

def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
295    def get_output_shape(
296        self, input_shape: Mapping[AxisId, int]
297    ) -> Mapping[AxisId, int]:
298        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
300    @classmethod
301    def from_proc_descr(
302        cls,
303        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
304        member_id: MemberId,
305    ) -> Self:
306        kwargs = descr.kwargs
307        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
308            axis = None
309        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
310            axis = kwargs.axis
311        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
312            if kwargs.axes is not None:
313                raise NotImplementedError(
314                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
315                )
316            axis = None
317        else:
318            assert_never(kwargs)
319
320        if axis:
321            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
322            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
323        else:
324            assert (
325                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
326            ), kwargs.gain
327            gain = (
328                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
329            )
330            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
331            offset = (
332                kwargs.offset
333                if isinstance(kwargs.offset, (float, int))
334                else kwargs.offset[0]
335            )
336
337        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(_SimpleOperator):
340@dataclass
341class ScaleMeanVariance(_SimpleOperator):
342    axes: Optional[Sequence[AxisId]] = None
343    reference_tensor: Optional[MemberId] = None
344    eps: float = 1e-6
345    mean: Union[SampleMean, DatasetMean] = field(init=False)
346    std: Union[SampleStd, DatasetStd] = field(init=False)
347    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
348    ref_std: Union[SampleStd, DatasetStd] = field(init=False)
349
350    @property
351    def required_measures(self):
352        return {self.mean, self.std, self.ref_mean, self.ref_std}
353
354    def __post_init__(self):
355        axes = None if self.axes is None else tuple(self.axes)
356        ref_tensor = self.reference_tensor or self.input
357        if axes is None or AxisId("batch") not in axes:
358            Mean = SampleMean
359            Std = SampleStd
360        else:
361            Mean = DatasetMean
362            Std = DatasetStd
363
364        self.mean = Mean(member_id=self.input, axes=axes)
365        self.std = Std(member_id=self.input, axes=axes)
366        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
367        self.ref_std = Std(member_id=ref_tensor, axes=axes)
368
369    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
370        mean = stat[self.mean]
371        std = stat[self.std] + self.eps
372        ref_mean = stat[self.ref_mean]
373        ref_std = stat[self.ref_std] + self.eps
374        return (input - mean) / std * ref_std + ref_mean
375
376    def get_output_shape(
377        self, input_shape: Mapping[AxisId, int]
378    ) -> Mapping[AxisId, int]:
379        return input_shape
380
381    @classmethod
382    def from_proc_descr(
383        cls,
384        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
385        member_id: MemberId,
386    ) -> Self:
387        kwargs = descr.kwargs
388        _, axes = _get_axes(descr.kwargs)
389
390        return cls(
391            input=member_id,
392            output=member_id,
393            reference_tensor=MemberId(str(kwargs.reference_tensor)),
394            axes=axes,
395            eps=kwargs.eps,
396        )
ScaleMeanVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None, reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None, eps: float = 1e-06)
axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None
reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None
eps: float = 1e-06
required_measures
350    @property
351    def required_measures(self):
352        return {self.mean, self.std, self.ref_mean, self.ref_std}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
376    def get_output_shape(
377        self, input_shape: Mapping[AxisId, int]
378    ) -> Mapping[AxisId, int]:
379        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
381    @classmethod
382    def from_proc_descr(
383        cls,
384        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
385        member_id: MemberId,
386    ) -> Self:
387        kwargs = descr.kwargs
388        _, axes = _get_axes(descr.kwargs)
389
390        return cls(
391            input=member_id,
392            output=member_id,
393            reference_tensor=MemberId(str(kwargs.reference_tensor)),
394            axes=axes,
395            eps=kwargs.eps,
396        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class ScaleRange(_SimpleOperator):
421@dataclass
422class ScaleRange(_SimpleOperator):
423    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
424    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
425    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
426    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)
427
428    eps: float = 1e-6
429
430    def __post_init__(
431        self,
432        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
433        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
434    ):
435        if lower_percentile is None:
436            tid = self.input if upper_percentile is None else upper_percentile.member_id
437            self.lower = DatasetPercentile(q=0.0, member_id=tid)
438        else:
439            self.lower = lower_percentile
440
441        if upper_percentile is None:
442            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
443        else:
444            self.upper = upper_percentile
445
446        assert self.lower.member_id == self.upper.member_id
447        assert self.lower.q < self.upper.q
448        assert self.lower.axes == self.upper.axes
449
450    @property
451    def required_measures(self):
452        return {self.lower, self.upper}
453
454    def get_output_shape(
455        self, input_shape: Mapping[AxisId, int]
456    ) -> Mapping[AxisId, int]:
457        return input_shape
458
459    @classmethod
460    def from_proc_descr(
461        cls,
462        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
463        member_id: MemberId,
464    ):
465        kwargs = descr.kwargs
466        ref_tensor = (
467            member_id
468            if kwargs.reference_tensor is None
469            else MemberId(str(kwargs.reference_tensor))
470        )
471        dataset_mode, axes = _get_axes(descr.kwargs)
472        if dataset_mode:
473            Percentile = DatasetPercentile
474        else:
475            Percentile = SampleQuantile
476
477        return cls(
478            input=member_id,
479            output=member_id,
480            lower_percentile=Percentile(
481                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
482            ),
483            upper_percentile=Percentile(
484                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
485            ),
486        )
487
488    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
489        lower = stat[self.lower]
490        upper = stat[self.upper]
491        return (input - lower) / (upper - lower + self.eps)
492
493    def get_descr(self):
494        assert self.lower.axes == self.upper.axes
495        assert self.lower.member_id == self.upper.member_id
496
497        return v0_5.ScaleRangeDescr(
498            kwargs=v0_5.ScaleRangeKwargs(
499                axes=self.lower.axes,
500                min_percentile=self.lower.q * 100,
501                max_percentile=self.upper.q * 100,
502                eps=self.eps,
503                reference_tensor=self.lower.member_id,
504            )
505        )
ScaleRange( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, eps: float = 1e-06)
lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None
upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None
eps: float = 1e-06
required_measures
450    @property
451    def required_measures(self):
452        return {self.lower, self.upper}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
454    def get_output_shape(
455        self, input_shape: Mapping[AxisId, int]
456    ) -> Mapping[AxisId, int]:
457        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
459    @classmethod
460    def from_proc_descr(
461        cls,
462        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
463        member_id: MemberId,
464    ):
465        kwargs = descr.kwargs
466        ref_tensor = (
467            member_id
468            if kwargs.reference_tensor is None
469            else MemberId(str(kwargs.reference_tensor))
470        )
471        dataset_mode, axes = _get_axes(descr.kwargs)
472        if dataset_mode:
473            Percentile = DatasetPercentile
474        else:
475            Percentile = SampleQuantile
476
477        return cls(
478            input=member_id,
479            output=member_id,
480            lower_percentile=Percentile(
481                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
482            ),
483            upper_percentile=Percentile(
484                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
485            ),
486        )
def get_descr(self):
493    def get_descr(self):
494        assert self.lower.axes == self.upper.axes
495        assert self.lower.member_id == self.upper.member_id
496
497        return v0_5.ScaleRangeDescr(
498            kwargs=v0_5.ScaleRangeKwargs(
499                axes=self.lower.axes,
500                min_percentile=self.lower.q * 100,
501                max_percentile=self.upper.q * 100,
502                eps=self.eps,
503                reference_tensor=self.lower.member_id,
504            )
505        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class Sigmoid(_SimpleOperator):
508@dataclass
509class Sigmoid(_SimpleOperator):
510    """1 / (1 + e^(-input))."""
511
512    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
513        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)
514
515    @property
516    def required_measures(self) -> Collection[Measure]:
517        return {}
518
519    def get_output_shape(
520        self, input_shape: Mapping[AxisId, int]
521    ) -> Mapping[AxisId, int]:
522        return input_shape
523
524    @classmethod
525    def from_proc_descr(
526        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
527    ) -> Self:
528        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
529        return cls(input=member_id, output=member_id)
530
531    def get_descr(self):
532        return v0_5.SigmoidDescr()

Element-wise logistic sigmoid: $\sigma(x) = \frac{1}{1 + e^{-x}}$.

required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
515    @property
516    def required_measures(self) -> Collection[Measure]:
517        return {}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
519    def get_output_shape(
520        self, input_shape: Mapping[AxisId, int]
521    ) -> Mapping[AxisId, int]:
522        return input_shape
@classmethod
def from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_5.SigmoidDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
524    @classmethod
525    def from_proc_descr(
526        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
527    ) -> Self:
528        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
529        return cls(input=member_id, output=member_id)
def get_descr(self):
531    def get_descr(self):
532        return v0_5.SigmoidDescr()
Inherited Members
_SimpleOperator
input
output
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
535@dataclass
536class ZeroMeanUnitVariance(_SimpleOperator):
537    """normalize to zero mean, unit variance."""
538
539    mean: MeanMeasure
540    std: StdMeasure
541
542    eps: float = 1e-6
543
544    def __post_init__(self):
545        assert self.mean.axes == self.std.axes
546
547    @property
548    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
549        return {self.mean, self.std}
550
551    def get_output_shape(
552        self, input_shape: Mapping[AxisId, int]
553    ) -> Mapping[AxisId, int]:
554        return input_shape
555
556    @classmethod
557    def from_proc_descr(
558        cls,
559        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
560        member_id: MemberId,
561    ):
562        dataset_mode, axes = _get_axes(descr.kwargs)
563
564        if dataset_mode:
565            Mean = DatasetMean
566            Std = DatasetStd
567        else:
568            Mean = SampleMean
569            Std = SampleStd
570
571        return cls(
572            input=member_id,
573            output=member_id,
574            mean=Mean(axes=axes, member_id=member_id),
575            std=Std(axes=axes, member_id=member_id),
576        )
577
578    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
579        mean = stat[self.mean]
580        std = stat[self.std]
581        return (input - mean) / (std + self.eps)
582
583    def get_descr(self):
584        return v0_5.ZeroMeanUnitVarianceDescr(
585            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
586        )

Normalize to zero mean and unit variance (a small `eps` is added to the standard deviation).

eps: float = 1e-06
547    @property
548    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
549        return {self.mean, self.std}
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
551    def get_output_shape(
552        self, input_shape: Mapping[AxisId, int]
553    ) -> Mapping[AxisId, int]:
554        return input_shape
556    @classmethod
557    def from_proc_descr(
558        cls,
559        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
560        member_id: MemberId,
561    ):
562        dataset_mode, axes = _get_axes(descr.kwargs)
563
564        if dataset_mode:
565            Mean = DatasetMean
566            Std = DatasetStd
567        else:
568            Mean = SampleMean
569            Std = SampleStd
570
571        return cls(
572            input=member_id,
573            output=member_id,
574            mean=Mean(axes=axes, member_id=member_id),
575            std=Std(axes=axes, member_id=member_id),
576        )
def get_descr(self):
583    def get_descr(self):
584        return v0_5.ZeroMeanUnitVarianceDescr(
585            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
586        )
Inherited Members
_SimpleOperator
input
output
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
589@dataclass
590class FixedZeroMeanUnitVariance(_SimpleOperator):
591    """normalize to zero mean, unit variance with precomputed values."""
592
593    mean: Union[float, xr.DataArray]
594    std: Union[float, xr.DataArray]
595
596    eps: float = 1e-6
597
598    def __post_init__(self):
599        assert (
600            isinstance(self.mean, (int, float))
601            or isinstance(self.std, (int, float))
602            or self.mean.dims == self.std.dims
603        )
604
605    def get_output_shape(
606        self, input_shape: Mapping[AxisId, int]
607    ) -> Mapping[AxisId, int]:
608        return input_shape
609
610    @classmethod
611    def from_proc_descr(
612        cls,
613        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
614        member_id: MemberId,
615    ) -> Self:
616        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
617            dims = None
618        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
619            dims = (AxisId(descr.kwargs.axis),)
620        else:
621            assert_never(descr.kwargs)
622
623        return cls(
624            input=member_id,
625            output=member_id,
626            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
627            std=xr.DataArray(descr.kwargs.std, dims=dims),
628        )
629
630    def get_descr(self):
631        if isinstance(self.mean, (int, float)):
632            assert isinstance(self.std, (int, float))
633            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
634        else:
635            assert isinstance(self.std, xr.DataArray)
636            assert len(self.mean.dims) == 1
637            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
638                axis=AxisId(str(self.mean.dims[0])),
639                mean=list(self.mean),
640                std=list(self.std),
641            )
642
643        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
644
645    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
646        return (input - self.mean) / (self.std + self.eps)

Normalize to zero mean and unit variance using a precomputed mean and standard deviation.

FixedZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[float, xarray.core.dataarray.DataArray], std: Union[float, xarray.core.dataarray.DataArray], eps: float = 1e-06)
mean: Union[float, xarray.core.dataarray.DataArray]
std: Union[float, xarray.core.dataarray.DataArray]
eps: float = 1e-06
def get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
605    def get_output_shape(
606        self, input_shape: Mapping[AxisId, int]
607    ) -> Mapping[AxisId, int]:
608        return input_shape
@classmethod
def from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
610    @classmethod
611    def from_proc_descr(
612        cls,
613        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
614        member_id: MemberId,
615    ) -> Self:
616        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
617            dims = None
618        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
619            dims = (AxisId(descr.kwargs.axis),)
620        else:
621            assert_never(descr.kwargs)
622
623        return cls(
624            input=member_id,
625            output=member_id,
626            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
627            std=xr.DataArray(descr.kwargs.std, dims=dims),
628        )
def get_descr(self):
630    def get_descr(self):
631        if isinstance(self.mean, (int, float)):
632            assert isinstance(self.std, (int, float))
633            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
634        else:
635            assert isinstance(self.std, xr.DataArray)
636            assert len(self.mean.dims) == 1
637            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
638                axis=AxisId(str(self.mean.dims[0])),
639                mean=list(self.mean),
640                std=list(self.std),
641            )
642
643        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
ProcDescr = typing.Union[typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def get_proc( proc_descr: Union[Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], tensor_descr: Union[bioimageio.spec.model.v0_4.InputTensorDescr, 
bioimageio.spec.model.v0_4.OutputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr, bioimageio.spec.model.v0_5.OutputTensorDescr]) -> Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, UpdateStats, ZeroMeanUnitVariance]:
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Build the processing operator described by `proc_descr`.

    The operator is bound to the tensor member identified by `tensor_descr`
    (resolved with `get_member_id`).

    Args:
        proc_descr: A v0_4 or v0_5 processing description (any member of
            `ProcDescr`).
        tensor_descr: The v0_4/v0_5 input or output tensor description that
            the processing applies to.

    Returns:
        The operator created by the matching operator class's
        `from_proc_descr`.

    Raises:
        TypeError: If a v0_4 `ZeroMeanUnitVarianceDescr` in "fixed" mode is
            paired with a non-v0_4 tensor description, or if converting it to
            v0_5 does not yield a `v0_5.FixedZeroMeanUnitVarianceDescr`.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        # v0_4 expresses fixed normalization as a mode of ZeroMeanUnitVariance.
        # It is converted to the dedicated v0_5 description, which requires
        # the v0_4 tensor axes; hence the v0_4 tensor description check.
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        # Explicit check instead of `assert`: it must still run under `python -O`.
        if not isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
            raise TypeError(
                "Expected conversion of a fixed-mode v0_4 ZeroMeanUnitVarianceDescr"
                " to yield a v0_5.FixedZeroMeanUnitVarianceDescr,"
                f" got {type(v5_proc_descr)}"
            )
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        # Any non-"fixed" v0_4 mode, or a v0_5 ZeroMeanUnitVarianceDescr.
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    else:
        # Exhaustiveness check for type checkers over all ProcDescr members.
        assert_never(proc_descr)