Coverage for src / bioimageio / core / proc_ops.py: 76%
395 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
1import collections.abc
2from abc import ABC, abstractmethod
3from dataclasses import InitVar, dataclass, field
4from functools import partial
5from typing import (
6 Collection,
7 Literal,
8 Mapping,
9 Optional,
10 Sequence,
11 Set,
12 Tuple,
13 Union,
14)
16import numpy as np
17import scipy # pyright: ignore[reportMissingTypeStubs]
18import xarray as xr
19from typing_extensions import Self, assert_never
21from bioimageio.core.digest_spec import get_member_id
22from bioimageio.spec.model import v0_4, v0_5
23from bioimageio.spec.model.v0_5 import (
24 _convert_proc, # pyright: ignore [reportPrivateUsage]
25)
27from ._op_base import BlockedOperator, Operator
28from .axis import AxisId, PerAxis
29from .block import Block
30from .common import DTypeStr, MemberId
31from .sample import Sample, SampleBlock, SampleBlockWithOrigin
32from .stat_calculators import StatsCalculator
33from .stat_measures import (
34 DatasetMean,
35 DatasetMeasure,
36 DatasetQuantile,
37 DatasetStd,
38 MeanMeasure,
39 Measure,
40 MeasureValue,
41 SampleMean,
42 SampleQuantile,
43 SampleStd,
44 Stat,
45 StdMeasure,
46)
47from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate a v0_4 axes specification into a tuple of `AxisId`s.

    A non-string `axes` is returned as-is (as a tuple). For an axes string,
    "per_dataset" mode prepends the batch axis, since dataset statistics
    aggregate over samples.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        leading = []
    elif mode == "per_dataset":
        leading = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(leading + [AxisId(a) for a in axes])
@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators mapping one input member to one output member."""

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # nothing to do if the input member is absent from this sample/block
        if self.input not in sample.members:
            return

        result = self._apply(sample.members[self.input], sample.stat)

        if self.output in sample.members:
            # a pre-existing output member must agree with the computed shape
            assert sample.members[self.output].tagged_shape == result.tagged_shape

        if isinstance(sample, Sample):
            sample.members[self.output] = result
        elif isinstance(sample, SampleBlock):
            src = sample.blocks[self.input]
            # carry over the input block's geometry for the output block
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=result,
                inner_slice=src.inner_slice,
                halo=src.halo,
                block_index=src.block_index,
                blocks_in_sample=src.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor: ...
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Inject precomputed dataset measure values into each sample's statistics."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # nothing needs computing; all values are supplied up front
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # merge the known dataset measures into the sample's stat mapping
        sample.stat.update(self.dataset_stats.items())
123# @dataclass
124# class UpdateStats(Operator):
125# """Calculates sample and/or dataset measures"""
127# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
128# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed
129# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
130# """
131# keep_updating_dataset_stats: Optional[bool] = None
132# """indicates if operator calls should keep updating dataset statistics or not
134# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
135# given) no further updates to dataset statistics is conducted, otherwise (w.o.
136# initial measure values) dataset statistics are updated by each processed sample.
137# """
138# _keep_updating_dataset_stats: bool = field(init=False)
139# _stats_calculator: StatsCalculator = field(init=False)
141# @property
142# def required_measures(self) -> Set[Measure]:
143# return set()
145# def __post_init__(self):
146# self._stats_calculator = StatsCalculator(self.measures)
147# if self.keep_updating_dataset_stats is None:
148# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
149# else:
150# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
# def __call__(self, sample_block: SampleBlockWithOrigin) -> None:
153# if self._keep_updating_dataset_stats:
154# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
155# else:
156# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        # dataset stats keep updating unless fixed initial values were given
        # (and the user did not opt in to keep updating them)
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            if sample.block_index != 0:
                # stats were already computed from the whole sample on block 0
                return
            origin = sample.origin
        else:
            origin = sample

        if self._keep_updating_dataset_stats:
            new_stats = self.stats_calculator.update_and_get_all(origin)
        else:
            new_stats = self.stats_calculator.skip_update_and_get_all(origin)

        sample.stat.update(new_stats)
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise comparison; a sequence threshold broadcasts along `axis`
        return x > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # thresholding is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create a `Binarize` operator from a v0_4 or v0_5 description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        assert_never(kwargs)
@dataclass
class Clip(_SimpleOperator):
    """Clip tensor values to bounds given as constants and/or quantile measures."""

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        """Validate that the configured bounds are present and consistent."""
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        if (
            isinstance(self.min, float)
            and isinstance(self.max, float)
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        # only quantile bounds need to be computed by the stats calculator;
        # plain float/None bounds require no measurements
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Apply the configured bounds to `x`.

        Scalar bounds are collected and applied via a single `clip` call;
        non-scalar measure values (e.g. per-axis quantiles) are applied
        immediately via `where`, since `clip` only supports scalars.
        """
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            # either a plain float bound or None (no lower bound)
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create a `Clip` operator from a v0_4 or v0_5 clip description."""
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            # an explicit min/max value takes precedence over a percentile bound
            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create an `EnsureDtype` operator from its v0_5 description."""
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        """Return the equivalent v0_5 processing description."""
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # dtype casts do not alter the shape
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)
@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine rescaling: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # affine rescaling is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleLinear` operator from a v0_4 or v0_5 description.

        Raises:
            NotImplementedError: for v0_4 kwargs with `axes` set.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis gain/offset: wrap as data arrays along the given axis
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            # scalar case: a length-1 sequence is unwrapped to its single value
            assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
                kwargs.gain
            )
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match this tensor's mean/variance to those of a reference tensor."""

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref = self.reference_tensor or self.input

        # aggregating over the batch axis requires dataset-level statistics;
        # otherwise per-sample statistics suffice
        if axes is not None and AxisId("batch") in axes:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref, axes=axes)
        self.ref_std = Std(member_id=ref, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # standardize, then rescale to the reference distribution;
        # eps guards against division by (near-)zero std
        std = stat[self.std] + self.eps
        ref_std = stat[self.ref_std] + self.eps
        return (x - stat[self.mean]) / std * ref_std + stat[self.ref_mean]

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleMeanVariance` operator from its description."""
        kwargs = descr.kwargs
        _, axes = _get_axes(kwargs)
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Return `(dataset_mode, axes)` extracted from processing kwargs.

    `dataset_mode` is True when statistics should aggregate over the whole
    dataset: either no axes are given, or the axes include the batch axis.
    """
    raw_axes = kwargs.axes
    if raw_axes is None:
        return True, None

    if isinstance(raw_axes, str):
        # v0_4-style axes string such as "czyx"
        axes = _convert_axis_ids(raw_axes, kwargs["mode"])
        return AxisId("b") in axes, axes

    if isinstance(raw_axes, collections.abc.Sequence):
        axes = tuple(raw_axes)
        return AxisId("batch") in axes, axes

    assert_never(raw_axes)
@dataclass
class ScaleRange(_SimpleOperator):
    """Normalize to a quantile range: `(x - lower) / (upper - lower + eps)`."""

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        if lower_quantile is not None:
            self.lower = lower_quantile
        else:
            # default lower bound: dataset minimum (q=0) of whichever member
            # the upper bound refers to (falling back to our own input)
            member = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=member)

        # default upper bound: dataset maximum (q=1) of the lower bound's member
        self.upper = (
            upper_quantile
            if upper_quantile is not None
            else DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        )

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # range normalization is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create a `ScaleRange` operator from a v0_4 or v0_5 description."""
        kwargs = descr.kwargs
        if kwargs.reference_tensor is None:
            ref_tensor = member_id
        else:
            ref_tensor = MemberId(str(kwargs.reference_tensor))

        dataset_mode, axes = _get_axes(kwargs)
        Quantile = (
            DatasetQuantile
            if dataset_mode
            else partial(SampleQuantile, method="linear")
        )

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lo, hi = stat[self.lower], stat[self.upper]
        return (x - lo) / (hi - lo + self.eps)

    def get_descr(self):
        """Return the equivalent v0_5 processing description."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise logistic function; preserves the input's dims
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: was `return {}` — an empty *dict*, not a set; return an empty
        # set as the base class and every other operator in this module do
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise activation: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create a `Sigmoid` operator (the description carries no kwargs)."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Return the equivalent v0_5 processing description."""
        return v0_5.SigmoidDescr()
@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # normalize along the configured axis; dims are preserved
        axis_idx = x.dims.index(self.axis)
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: was `return {}` — an empty *dict*, not a set; return an empty
        # set as the base class and every other operator in this module do
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # softmax normalizes values without changing the shape
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        """Create a `Softmax` operator from its v0_5 description."""
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        """Return the equivalent v0_5 processing description."""
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # mean and std must be computed over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # standardization is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create a `ZeroMeanUnitVariance` operator from its description."""
        dataset_mode, axes = _get_axes(descr.kwargs)

        # dataset mode aggregates the statistics across samples
        if dataset_mode:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # eps guards against division by (near-)zero std
        centered = x - stat[self.mean]
        return centered / (stat[self.std] + self.eps)

    def get_descr(self):
        """Return the equivalent v0_5 processing description."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # scalar mean/std always broadcast; arrays must agree on their dims
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # standardization is elementwise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create a `FixedZeroMeanUnitVariance` operator from its description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(kwargs.axis),)
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(kwargs.mean, dims=dims),
            std=xr.DataArray(kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Return the equivalent v0_5 processing description."""
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # eps guards against division by (near-)zero std
        return (x - self.mean) / (self.std + self.eps)
# any pre-/postprocessing description from the v0_4 or v0_5 model spec
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# union of all operator implementations that `get_proc` may return
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the processing operator for `proc_descr` applied to the
    tensor (member) described by `tensor_descr`.

    Raises:
        TypeError: if a v0_4 fixed-mode zero-mean/unit-variance description is
            paired with a non-v0_4 tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    # NOTE: the fixed-mode v0_4 case must be checked before the general
    # ZeroMeanUnitVariance case below, which would otherwise match it too.
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        # convert to a v0_5 fixed description to reuse its operator
        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    else:
        assert_never(proc_descr)