Coverage for bioimageio/core/proc_ops.py: 81%
332 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import collections.abc
2from abc import ABC, abstractmethod
3from dataclasses import InitVar, dataclass, field
4from typing import (
5 Collection,
6 Literal,
7 Mapping,
8 Optional,
9 Sequence,
10 Set,
11 Tuple,
12 Union,
13)
15import numpy as np
16import xarray as xr
17from typing_extensions import Self, assert_never
19from bioimageio.core.digest_spec import get_member_id
20from bioimageio.spec.model import v0_4, v0_5
21from bioimageio.spec.model.v0_5 import (
22 _convert_proc, # pyright: ignore [reportPrivateUsage]
23)
25from ._op_base import BlockedOperator, Operator
26from .axis import AxisId, PerAxis
27from .block import Block
28from .common import DTypeStr, MemberId
29from .sample import Sample, SampleBlock, SampleBlockWithOrigin
30from .stat_calculators import StatsCalculator
31from .stat_measures import (
32 DatasetMean,
33 DatasetMeasure,
34 DatasetPercentile,
35 DatasetStd,
36 MeanMeasure,
37 Measure,
38 MeasureValue,
39 SampleMean,
40 SampleQuantile,
41 SampleStd,
42 Stat,
43 StdMeasure,
44)
45from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate v0.4 axis letters into a tuple of `AxisId`s.

    A non-string `axes` is returned as a tuple without further conversion.
    For a string, every character becomes one `AxisId`. In "per_dataset"
    mode, `v0_5.BATCH_AXIS_ID` is placed in front of the converted letters.
    """
    if isinstance(axes, str):
        if mode == "per_sample":
            leading: Tuple[AxisId, ...] = ()
        elif mode == "per_dataset":
            leading = (v0_5.BATCH_AXIS_ID,)
        else:
            assert_never(mode)

        return leading + tuple(AxisId(letter) for letter in axes)

    return tuple(axes)
@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators that map one sample member onto another.

    Subclasses provide the tensor transformation (`_apply`) and the resulting
    shape (`get_output_shape`). `__call__` reads `input` from a `Sample` or
    `SampleBlock` and stores the result under `output`.
    """

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # by default a simple operator does not depend on any statistics
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # samples without the input member are left untouched
        if self.input not in sample.members:
            return

        result = self._apply(sample.members[self.input], sample.stat)

        if self.output in sample.members:
            # replacing an existing member must preserve its (tagged) shape
            assert (
                sample.members[self.output].tagged_shape == result.tagged_shape
            )

        if isinstance(sample, Sample):
            sample.members[self.output] = result
        elif isinstance(sample, SampleBlock):
            # the output block reuses the block geometry of the input block
            source_block = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=result,
                inner_slice=source_block.inner_slice,
                halo=source_block.halo,
                block_index=source_block.block_index,
                blocks_in_sample=source_block.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Copy a fixed set of dataset statistics into each processed sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # the values are provided up front; nothing needs to be computed
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known_items = self.dataset_stats.items()
        sample.stat.update(known_items)
121# @dataclass
122# class UpdateStats(Operator):
123# """Calculates sample and/or dataset measures"""
125# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
126# """sample and dataset `measures` to be calculated by this operator. Initial/fixed
127# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
128# """
129# keep_updating_dataset_stats: Optional[bool] = None
130# """indicates if operator calls should keep updating dataset statistics or not
132# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
133# given) no further updates to dataset statistics is conducted, otherwise (w.o.
134# initial measure values) dataset statistics are updated by each processed sample.
135# """
136# _keep_updating_dataset_stats: bool = field(init=False)
137# _stats_calculator: StatsCalculator = field(init=False)
139# @property
140# def required_measures(self) -> Set[Measure]:
141# return set()
143# def __post_init__(self):
144# self._stats_calculator = StatsCalculator(self.measures)
145# if self.keep_updating_dataset_stats is None:
146# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
147# else:
148# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
150# def __call__(self, sample_block: SampleBlockWithOrigin> None:
151# if self._keep_updating_dataset_stats:
152# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
153# else:
154# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(Operator):
    """Compute sample and/or dataset measures and add them to `sample.stat`."""

    stats_calculator: StatsCalculator
    """the `StatsCalculator` that computes the measures"""
    keep_updating_initial_dataset_stats: bool = False
    """whether initial dataset statistics given to `stats_calculator` should still
    be updated by later samples; if `stats_calculator` has no initial dataset
    statistics, dataset statistics are updated with every sample regardless.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        keep = self.keep_updating_initial_dataset_stats
        self._keep_updating_dataset_stats = (
            keep or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if not isinstance(sample, SampleBlockWithOrigin):
            whole_sample = sample
        elif sample.block_index == 0:
            # a blocked sample is measured once, as a whole, on its first block
            whole_sample = sample.origin
        else:
            return

        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            measured = calculator.update_and_get_all(whole_sample)
        else:
            measured = calculator.skip_update_and_get_all(whole_sample)

        sample.stat.update(measured)
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is given, `threshold` may hold one value per entry along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        threshold = self.threshold
        if self.axis is not None and not isinstance(threshold, (int, float)):
            # Label per-axis thresholds with their axis so the comparison is
            # aligned by dimension name. A bare sequence would otherwise rely on
            # positional (numpy-style) broadcasting against the trailing dim.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise comparison: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Build a `Binarize` operator acting in place on `member_id`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
@dataclass
class Clip(_SimpleOperator):
    """Clip values to the range [`min`, `max`]; either bound may be omitted."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        lower, upper = self.min, self.max
        assert not (lower is None and upper is None), "missing min or max value"
        assert (
            lower is None or upper is None or lower < upper
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping does not change the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        clip_kwargs = descr.kwargs
        return cls(
            input=member_id,
            output=member_id,
            min=clip_kwargs.min,
            max=clip_kwargs.max,
        )
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        ensure_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=ensure_kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # a dtype cast keeps the shape
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)
@dataclass
class ScaleLinear(_SimpleOperator):
    """'output = input * gain + offset'."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        scaled = input * self.gain
        return scaled + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # affine transform applied element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Build a `ScaleLinear` operator acting in place on `member_id`.

        Raises:
            NotImplementedError: for v0.4 kwargs that specify `axes`.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if not axis:
            # without an axis gain/offset must be a scalar or a 1-element sequence
            raw_gain = kwargs.gain
            gain_is_scalar = isinstance(raw_gain, (float, int))
            assert gain_is_scalar or len(raw_gain) == 1, raw_gain
            gain = raw_gain if gain_is_scalar else raw_gain[0]

            raw_offset = kwargs.offset
            offset_is_scalar = isinstance(raw_offset, (float, int))
            assert offset_is_scalar or len(raw_offset) == 1
            offset = raw_offset if offset_is_scalar else raw_offset[0]
        else:
            # per-axis values are labeled with their axis for xarray alignment
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Give the input the mean and standard deviation of a reference tensor.

    'output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean'
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    # measures below are derived in __post_init__ from `axes`/`reference_tensor`
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        reference = self.reference_tensor or self.input

        # dataset measures are used only if the batch axis is among `axes`
        uses_dataset_stats = axes is not None and AxisId("batch") in axes
        if uses_dataset_stats:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=reference, axes=axes)
        self.ref_std = Std(member_id=reference, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        eps = self.eps
        mean = stat[self.mean]
        std = stat[self.std] + eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + eps
        centered = input - mean
        return centered / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        mv_kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(mv_kwargs.reference_tensor)),
            axes=axes,
            eps=mv_kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the axes from normalization kwargs.

    Returns:
        `(dataset_mode, axes)`: `axes` is None when the kwargs give no axes,
        otherwise a tuple of `AxisId`s (v0.4 axis letters are converted via
        `_convert_axis_ids`). `dataset_mode` is True when `axes` is None or
        when the batch axis is part of `axes`.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4 axis letters; "per_dataset" mode makes `_convert_axis_ids`
        # prepend `v0_5.BATCH_AXIS_ID`, so that is the id to look for
        # (v0.4 axis letters never contain a "b").
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)
@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale the input by a lower and an upper percentile.

    'output = (input - lower) / (upper - lower + eps)'

    An omitted percentile defaults to a `DatasetPercentile` with q=0.0 (lower)
    or q=1.0 (upper) of the same member as the other one.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        # both percentiles must describe the same tensor and axes
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ) -> Self:
        """Build a `ScaleRange` operator acting in place on `member_id`.

        Percentiles are taken from `reference_tensor` if given, otherwise from
        `member_id` itself.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # honor the described eps instead of silently using the default;
            # `get_descr` serializes it, so this keeps the round trip consistent
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty *set* like every other operator (`{}` would be a dict)
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    'output = (input - mean) / (std + eps)' with `mean` and `std` read from
    the sample's statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Build a `ZeroMeanUnitVariance` operator acting in place on `member_id`."""
        kwargs = descr.kwargs
        dataset_mode, axes = _get_axes(kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # use the described eps instead of silently falling back to the default
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    'output = (input - mean) / (std + eps)'
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Build the operator from a v0.5 description.

        Note: `mean`/`std` are always wrapped in `xr.DataArray`s; for the
        scalar kwargs these are 0-d arrays.
        """
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    @staticmethod
    def _is_scalar(value: Union[float, xr.DataArray]) -> bool:
        """True for plain numbers and for 0-d `xr.DataArray`s."""
        return isinstance(value, (int, float)) or value.ndim == 0

    def get_descr(self):
        """Describe this operator as a v0.5 `FixedZeroMeanUnitVarianceDescr`."""
        # `from_proc_descr` stores scalar mean/std as 0-d DataArrays, so those
        # must map back to the scalar kwargs rather than the along-axis ones.
        if self._is_scalar(self.mean):
            assert self._is_scalar(self.std)
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(
                mean=float(self.mean), std=float(self.std)
            )
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                # plain floats; iterating a DataArray would yield 0-d DataArrays
                mean=[float(m) for m in self.mean.values],
                std=[float(s) for s in self.std.values],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)
# Any pre- or postprocessing description (v0.4 or v0.5) accepted by `get_proc`.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# All operator classes defined in this module. `get_proc` is annotated to
# return this type (it instantiates the description-based subset of them).
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the operator described by `proc_descr`.

    The member id derived from `tensor_descr` is used as both input and
    output of the returned operator.

    Raises:
        TypeError: if a v0.4 zero-mean-unit-variance description in "fixed"
            mode is paired with a non-v0.4 tensor description.
    """
    member_id = get_member_id(tensor_descr)

    # v0.4 "fixed" zero-mean-unit-variance is converted to the v0.5
    # FixedZeroMeanUnitVarianceDescr, which requires the v0.4 tensor axes.
    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)