Coverage for src/bioimageio/core/proc_ops.py: 77%
350 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
1import collections.abc
2from abc import ABC, abstractmethod
3from dataclasses import InitVar, dataclass, field
4from typing import (
5 Collection,
6 Literal,
7 Mapping,
8 Optional,
9 Sequence,
10 Set,
11 Tuple,
12 Union,
13)
15import numpy as np
16import scipy # pyright: ignore[reportMissingTypeStubs]
17import xarray as xr
18from typing_extensions import Self, assert_never
20from bioimageio.core.digest_spec import get_member_id
21from bioimageio.spec.model import v0_4, v0_5
22from bioimageio.spec.model.v0_5 import (
23 _convert_proc, # pyright: ignore [reportPrivateUsage]
24)
26from ._op_base import BlockedOperator, Operator
27from .axis import AxisId, PerAxis
28from .block import Block
29from .common import DTypeStr, MemberId
30from .sample import Sample, SampleBlock, SampleBlockWithOrigin
31from .stat_calculators import StatsCalculator
32from .stat_measures import (
33 DatasetMean,
34 DatasetMeasure,
35 DatasetPercentile,
36 DatasetStd,
37 MeanMeasure,
38 Measure,
39 MeasureValue,
40 SampleMean,
41 SampleQuantile,
42 SampleStd,
43 Stat,
44 StdMeasure,
45)
46from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate a v0.4 axes string (e.g. "czyx") into a tuple of `AxisId`s.

    In "per_dataset" mode the batch axis id is prepended (dataset statistics
    aggregate over samples); in "per_sample" mode it is not.
    """
    if not isinstance(axes, str):
        # already a sequence of axis ids
        return tuple(axes)

    if mode == "per_sample":
        prefix = []
    elif mode == "per_dataset":
        prefix = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(prefix + [AxisId(letter) for letter in axes])
@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base for operators mapping a single input member to a single output member."""

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # simple operators need no statistics by default
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply the operator, writing member `self.output` into *sample* in place."""
        if self.input not in sample.members:
            # nothing to do if the input member is absent
            return

        result = self._apply(sample.members[self.input], sample.stat)

        if self.output in sample.members:
            # overwriting an existing member must not change its shape
            assert sample.members[self.output].tagged_shape == result.tagged_shape

        if isinstance(sample, Sample):
            sample.members[self.output] = result
        elif isinstance(sample, SampleBlock):
            src = sample.blocks[self.input]
            # mirror the input block's placement for the output block
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=result,
                inner_slice=src.inner_slice,
                halo=src.halo,
                block_index=src.block_index,
                blocks_in_sample=src.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor: ...
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Inject precomputed dataset measure values into a sample's statistics."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator only provides measures, it does not require any
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value
122# @dataclass
123# class UpdateStats(Operator):
124# """Calculates sample and/or dataset measures"""
126# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
127# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed
128# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
129# """
130# keep_updating_dataset_stats: Optional[bool] = None
131# """indicates if operator calls should keep updating dataset statistics or not
133# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
134# given) no further updates to dataset statistics is conducted, otherwise (w.o.
135# initial measure values) dataset statistics are updated by each processed sample.
136# """
137# _keep_updating_dataset_stats: bool = field(init=False)
138# _stats_calculator: StatsCalculator = field(init=False)
140# @property
141# def required_measures(self) -> Set[Measure]:
142# return set()
144# def __post_init__(self):
145# self._stats_calculator = StatsCalculator(self.measures)
146# if self.keep_updating_dataset_stats is None:
147# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
148# else:
149# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
# def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
152# if self._keep_updating_dataset_stats:
153# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
154# else:
155# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        # keep updating unless the calculator already holds fixed dataset
        # measures and we were not explicitly asked to keep updating them
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            if sample.block_index != 0:
                # stats were already updated with the whole origin sample
                # when this sample's first block came through
                return
            origin = sample.origin
        else:
            origin = sample

        # pick the calculator entry point once, then update the sample's stats
        if self._keep_updating_dataset_stats:
            get_all = self.stats_calculator.update_and_get_all
        else:
            get_all = self.stats_calculator.skip_update_and_get_all

        sample.stat.update(get_all(origin))
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # thresholding is elementwise; the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create a `Binarize` operator from a processing description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            # a single global threshold
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            # one threshold per entry along `axis`
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        assert_never(kwargs)
@dataclass
class Clip(_SimpleOperator):
    """Clamp tensor values to the closed interval [min, max]."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        # at least one bound must be given, and both (if given) must be ordered
        assert not (self.min is None and self.max is None), "missing min or max value"
        both_given = self.min is not None and self.max is not None
        assert (
            not both_given or self.min < self.max
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # delegate to the tensor's own clip implementation
        return x.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is elementwise; the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        kw = descr.kwargs
        return cls(input=member_id, output=member_id, min=kw.min, max=kw.max)
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # dtype casts never change the shape
        return input_shape

    def get_descr(self):
        # round-trip back to a v0.5 processing description
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine rescaling: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise (broadcasting) operation; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis gains/offsets become 1d arrays along `axis`
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            # without an axis, only scalar (or length-1 sequence) values are valid
            if isinstance(kwargs.gain, (float, int)):
                gain = kwargs.gain
            else:
                assert len(kwargs.gain) == 1, kwargs.gain
                gain = kwargs.gain[0]

            if isinstance(kwargs.offset, (float, int)):
                offset = kwargs.offset
            else:
                assert len(kwargs.offset) == 1
                offset = kwargs.offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Scale the input to match the mean/variance of a reference tensor:

    `output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean`
    """

    # axes to compute statistics over; None means all axes
    axes: Optional[Sequence[AxisId]] = None
    # tensor whose statistics the input is mapped onto; defaults to the input itself
    reference_tensor: Optional[MemberId] = None
    # numerical stabilizer added to both std values
    eps: float = 1e-6
    # measures resolved in __post_init__ from `axes`/`reference_tensor`
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        # Sample statistics unless the batch axis is explicitly included, in
        # which case statistics aggregate over the whole dataset.
        # NOTE(review): for axes=None this differs from `_get_axes`, which
        # treats axes=None as dataset mode -- confirm this is intended.
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        # standardize the input, then shift/scale onto the reference statistics
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        # dataset/sample mode is discarded here; __post_init__ re-derives it from `axes`
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the statistics axes from processing kwargs.

    Returns:
        `(dataset_mode, axes)` where `dataset_mode` indicates that statistics
        aggregate over the whole dataset (batch axis included, or all axes),
        and `axes` is the tuple of axis ids to reduce over (None = all axes).
    """
    if kwargs.axes is None:
        # no axes given: reduce over all axes, including batch -> dataset mode
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4-style axes letter string; `_convert_axis_ids` prepends the batch
        # axis iff the kwargs' mode is "per_dataset"
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        # BUGFIX: `_convert_axis_ids` prepends `v0_5.BATCH_AXIS_ID`
        # (== AxisId("batch")), never AxisId("b"); the previous check
        # `AxisId("b") in axes` could therefore never detect dataset mode.
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)
@dataclass
class ScaleRange(_SimpleOperator):
    """Normalize the input to the range spanned by two percentile measures:

    `output = (input - lower) / (upper - lower + eps)`
    """

    # init-only: optional lower/upper percentile measures; resolved into
    # `self.lower`/`self.upper` (with defaults q=0.0/q=1.0) in __post_init__
    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            # default lower bound: the minimum (q=0) of the same tensor the
            # upper bound refers to (or of the input if neither is given)
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            # default upper bound: the maximum (q=1) of the lower bound's tensor
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        # both bounds must refer to the same tensor/axes and be ordered
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            # spec percentiles are given in [0, 100]; measures use q in [0, 1]
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        # a round-trip description is only well-defined for matching bounds
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: `{}` is an empty *dict*, not a set; return `set()` for
        # consistency with every other operator in this module
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise activation; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        # sigmoid takes no kwargs
        return v0_5.SigmoidDescr()
@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function."""

    # axis along which values are normalized to sum to 1
    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # raises ValueError if `self.axis` is not among the tensor's dims
        axis_idx = x.dims.index(self.axis)
        # NOTE(review): relies on `import scipy` exposing `scipy.special`
        # (lazy subpackage loading in scipy >= 1.8) -- confirm minimum scipy
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: `{}` is an empty *dict*, not a set; return `set()` for
        # consistency with every other operator in this module
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # normalization along an existing axis; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # mean and std must be computed over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # standardize: subtract mean, divide by (stabilized) std
        return (x - stat[self.mean]) / (stat[self.std] + self.eps)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)

        # dataset mode aggregates statistics over all samples
        Mean, Std = (
            (DatasetMean, DatasetStd) if dataset_mode else (SampleMean, SampleStd)
        )

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def get_descr(self):
        # round-trip back to a v0.5 processing description
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # when both are arrays, they must live on the same dims
        mean_is_scalar = isinstance(self.mean, (int, float))
        std_is_scalar = isinstance(self.std, (int, float))
        assert mean_is_scalar or std_is_scalar or self.mean.dims == self.std.dims

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            # scalar mean/std
            dims = None
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            # per-entry mean/std along a single axis
            dims = (AxisId(kwargs.axis),)
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(kwargs.mean, dims=dims),
            std=xr.DataArray(kwargs.std, dims=dims),
        )

    def get_descr(self):
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            # the along-axis description only supports a single axis
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)
# any pre- or postprocessing description from model spec v0.4 or v0.5
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# union of all operator implementations that `get_proc` may return
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the `Processing` operator for a processing description.

    Args:
        proc_descr: pre-/postprocessing description (model spec v0.4 or v0.5).
        tensor_descr: description of the tensor the processing applies to;
            used to determine the member id (and, for v0.4 fixed
            zero-mean/unit-variance, the axes).

    Returns:
        The matching operator, constructed via its `from_proc_descr`.

    Raises:
        TypeError: if a v0.4 processing description is paired with a non-v0.4
            tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        # v0.4 "fixed" mode carries precomputed mean/std values; this branch
        # must stay before the generic zero-mean/unit-variance branch below
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        # convert to the equivalent v0.5 description to reuse its constructor
        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    else:
        assert_never(proc_descr)