Coverage for src / bioimageio / core / proc_ops.py: 79%
416 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import collections.abc
2from dataclasses import InitVar, dataclass, field
3from functools import partial
4from typing import (
5 Any,
6 Callable,
7 Collection,
8 Literal,
9 Mapping,
10 Optional,
11 Sequence,
12 Set,
13 Tuple,
14 Union,
15)
17import numpy as np
18import scipy
19import xarray as xr
20from numpy.typing import NDArray
21from typing_extensions import Self, assert_never
23from bioimageio.spec.model import v0_4, v0_5
24from bioimageio.spec.model.v0_5 import (
25 _convert_proc, # pyright: ignore[reportPrivateUsage]
26)
28from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator
29from ._ops_cellpose import CellposeFlowDynamics
30from ._ops_stardist import StardistPostprocessing2D as StardistPostprocessing2D
31from ._ops_stardist import StardistPostprocessing3D as StardistPostprocessing3D
32from .axis import AxisId
33from .common import DTypeStr, MemberId
34from .digest_spec import get_member_id, import_callable
35from .sample import Sample, SampleBlock
36from .stat_calculators import StatsCalculator
37from .stat_measures import (
38 DatasetMean,
39 DatasetQuantile,
40 DatasetStd,
41 MeanMeasure,
42 Measure,
43 MeasureValue,
44 SampleMean,
45 SampleQuantile,
46 SampleStd,
47 Stat,
48 StdMeasure,
49)
50from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert an axes specification into a tuple of axis ids.

    Args:
        axes: Either a string, which is split into one `AxisId` per character,
            or any other iterable, which is returned as a tuple unchanged.
        mode: Only used for string `axes`; with "per_dataset" the batch axis id
            (`v0_5.BATCH_AXIS_ID`) is put in front of the converted axes.

    Returns:
        The axis ids as a tuple.
    """
    if not isinstance(axes, str):
        # not a string: nothing to convert, only freeze into a tuple
        return tuple(axes)

    if mode == "per_sample":
        leading: Tuple[AxisId, ...] = ()
    elif mode == "per_dataset":
        leading = (v0_5.BATCH_AXIS_ID,)
    else:
        assert_never(mode)

    return leading + tuple(AxisId(letter) for letter in axes)
@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Add already known dataset measure values to the statistics of a sample."""

    dataset_stats: Mapping[Measure, MeasureValue]

    def __post_init__(self):
        # drop every measure whose scope is not "dataset"
        dataset_only = {}
        for measure, value in self.dataset_stats.items():
            if measure.scope == "dataset":
                dataset_only[measure] = value

        self.dataset_stats = dataset_only

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Update `sample.stat` in place with the stored dataset measures."""
        sample.stat.update(self.dataset_stats)
89# @dataclass
90# class UpdateStats(Operator):
91# """Calculates sample and/or dataset measures"""
93# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
94# """sample and dataset `measures` to be calculated by this operator. Initial/fixed
95# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
96# """
97# keep_updating_dataset_stats: Optional[bool] = None
98# """indicates if operator calls should keep updating dataset statistics or not
100# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
101# given) no further updates to dataset statistics is conducted, otherwise (w.o.
102# initial measure values) dataset statistics are updated by each processed sample.
103# """
104# _keep_updating_dataset_stats: bool = field(init=False)
105# _stats_calculator: StatsCalculator = field(init=False)
107# @property
108# def required_measures(self) -> Collection[Measure]:
109# return set()
111# def __post_init__(self):
112# self._stats_calculator = StatsCalculator(self.measures)
113# if self.keep_updating_dataset_stats is None:
114# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
115# else:
116# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
118# def __call__(self, sample_block: SampleBlockWithOrigin) -> None:
119# if self._keep_updating_dataset_stats:
120# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
121# else:
122# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(SamplewiseOperator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def __post_init__(self):
        # keep updating if explicitly requested or if the calculator reports
        # that it has no dataset measures
        if self.keep_updating_initial_dataset_stats:
            self._keep_updating_dataset_stats = True
        else:
            self._keep_updating_dataset_stats = (
                not self.stats_calculator.has_dataset_measures
            )

    def __call__(self, sample: Sample) -> None:
        """Compute the measures for `sample` and merge them into `sample.stat`."""
        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            computed = calculator.update_and_get_all(sample)
        else:
            computed = calculator.skip_update_and_get_all(sample)

        sample.stat.update(computed)
@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'.

    `threshold` is either a single number or a sequence of numbers that is
    applied along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Return the element-wise comparison `x > threshold`."""
        if self.axis is None or isinstance(self.threshold, (int, float)):
            return x > self.threshold

        # Label the thresholds with `axis` so they are compared along that axis
        # rather than broadcast by position (same approach as `ScaleLinear`).
        thresholds = xr.DataArray(np.atleast_1d(self.threshold), dims=(self.axis,))
        return x > thresholds

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create the operator from a binarize description for tensor `member_id`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
@dataclass
class Clip(SimpleOperator):
    """Limit tensor values to the range given by `min` and/or `max`.

    Each bound is either a fixed number or a quantile measure whose value is
    looked up in the sample statistics when the operator is applied.

    Raises:
        ValueError: if neither bound is given, or if the lower bound is not
            strictly smaller than the upper bound.
        NotImplementedError: if both bounds are quantiles over different axes.
    """

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # Integer bounds are validated too; `_apply` treats int and float
        # scalars alike, so the ordering check must not be limited to floats.
        if (
            isinstance(self.min, (int, float))
            and isinstance(self.max, (int, float))
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        """The quantile measures (if any) used as clipping bounds."""
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Clip `x`; quantile bounds are resolved through `stat`."""
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        # only call clip if at least one scalar bound is left to apply
        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create the operator from a clip description for tensor `member_id`.

        For v0_5 descriptions a fixed `min`/`max` takes precedence over the
        corresponding percentile; percentiles are turned into quantile measures.
        """
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )
@dataclass
class EnsureDtype(SimpleOperator):
    """Cast a tensor to the configured `dtype`."""

    dtype: DTypeStr

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Return `x` converted with `astype(self.dtype)`."""
        return x.astype(self.dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    def get_descr(self):
        """Return the v0_5 description of this operator."""
        dtype_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=dtype_kwargs)

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create the operator from its description for tensor `member_id`."""
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)
@dataclass
class ScaleLinear(SimpleOperator):
    """Apply the linear transformation `x * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create the operator from a scale-linear description.

        Raises:
            NotImplementedError: for v0_4 kwargs that specify `axes`.
        """
        kwargs = descr.kwargs

        # determine the axis (if any) that gain and offset vary along
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis parameters: label them with the axis for broadcasting
            return cls(
                input=member_id,
                output=member_id,
                gain=xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis),
                offset=xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis),
            )

        # no axis: gain and offset must each boil down to a single number
        scalar_types = (float, int)

        assert isinstance(kwargs.gain, scalar_types) or len(kwargs.gain) == 1, (
            kwargs.gain
        )
        if isinstance(kwargs.gain, scalar_types):
            gain = kwargs.gain
        else:
            gain = kwargs.gain[0]

        assert isinstance(kwargs.offset, scalar_types) or len(kwargs.offset) == 1
        if isinstance(kwargs.offset, scalar_types):
            offset = kwargs.offset
        else:
            offset = kwargs.offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Scale a tensor to the mean and standard deviation of a reference tensor.

    Computes `(x - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        """Mean and std measures of both the input and the reference tensor."""
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        # without an explicit reference the input tensor is its own reference
        ref_tensor = self.reference_tensor or self.input

        # NOTE(review): `axes=None` selects sample statistics here, whereas
        # `_get_axes` reports dataset mode for `axes=None` -- confirm intent.
        use_dataset_stats = axes is not None and AxisId("batch") in axes
        if use_dataset_stats:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create the operator from its description for tensor `member_id`."""
        kwargs = descr.kwargs
        # the dataset-mode flag is not needed; `__post_init__` decides by axes
        _, axes = _get_axes(descr.kwargs)
        reference = MemberId(str(kwargs.reference_tensor))

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=reference,
            axes=axes,
            eps=kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the axes from processing kwargs.

    Returns:
        A pair `(dataset_mode, axes)`: `axes` is `None` if `kwargs.axes` is
        `None`, otherwise a tuple of axis ids. `dataset_mode` is `True` if
        `kwargs.axes` is `None` or the batch axis is among the axes.
    """
    axes_spec = kwargs.axes
    if axes_spec is None:
        return True, None

    if isinstance(axes_spec, str):
        # string axes: convert per character, taking `mode` into account
        converted = _convert_axis_ids(axes_spec, kwargs["mode"])
        return AxisId("b") in converted, converted

    if isinstance(axes_spec, collections.abc.Sequence):
        converted = tuple(axes_spec)
        return AxisId("batch") in converted, converted

    assert_never(axes_spec)
@dataclass
class ScaleRange(SimpleOperator):
    """Scale a tensor with `(x - lower) / (upper - lower + eps)`.

    `lower` and `upper` are quantile measures resolved from the sample
    statistics. A bound that is not given defaults to the 0.0 (lower) or 1.0
    (upper) dataset quantile.
    """

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        if lower_quantile is None:
            # take the member id from the upper quantile if there is one
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        """The lower and upper quantile measures."""
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create the operator from a scale-range description.

        Percentiles are converted to quantiles in [0, 1]; the quantiles refer to
        `kwargs.reference_tensor` if given, otherwise to `member_id`.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            # forward the described epsilon instead of silently using the default
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Return the v0_5 description of this operator."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        # an empty set (not an empty dict literal), like all other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create the operator from a sigmoid description for tensor `member_id`."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Return the v0_5 description of this operator."""
        return v0_5.SigmoidDescr()
@dataclass
class Softmax(SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Apply `scipy.special.softmax` along `self.axis`."""
        dims = x.dims
        # positional index of the configured axis
        axis_position = dims.index(self.axis)
        activated = scipy.special.softmax(x.data, axis=axis_position)
        return Tensor.from_xarray(xr.DataArray(activated, dims=dims))

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        """Create the operator from a softmax description for tensor `member_id`."""
        assert isinstance(descr, v0_5.SoftmaxDescr)
        softmax_axis = descr.kwargs.axis
        return cls(input=member_id, output=member_id, axis=softmax_axis)

    def get_descr(self):
        """Return the v0_5 description of this operator."""
        softmax_kwargs = v0_5.SoftmaxKwargs(axis=self.axis)
        return v0_5.SoftmaxDescr(kwargs=softmax_kwargs)
@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance.

    Computes `(x - mean) / (std + eps)` with `mean` and `std` resolved from the
    sample statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        """The mean and std measures used for normalization."""
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create the operator from its description for tensor `member_id`.

        Dataset measures are used in dataset mode (see `_get_axes`), sample
        measures otherwise.
        """
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # forward the described epsilon instead of silently using the default
            eps=descr.kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        """Return the v0_5 description of this operator."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    Computes `(x - mean) / (std + eps)` with fixed `mean` and `std`.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # two arrays must share their dims; scalars combine with anything
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create the operator from its description for tensor `member_id`.

        Mean and std are always stored as `xr.DataArray`; for kwargs without an
        axis these arrays are created with `dims=None`.
        """
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    @staticmethod
    def _unwrap_0d(value: Union[float, xr.DataArray]) -> Union[float, xr.DataArray]:
        """Return a zero-dimensional `xr.DataArray` as float, anything else unchanged."""
        if isinstance(value, xr.DataArray) and value.ndim == 0:
            return float(value)

        return value

    def get_descr(self):
        """Return the v0_5 description of this operator.

        Zero-dimensional arrays (as created by `from_proc_descr` for scalar
        kwargs) are described like plain scalars.
        """
        mean = self._unwrap_0d(self.mean)
        std = self._unwrap_0d(self.std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats rather than zero-dimensional array elements
                mean=[float(m) for m in mean],
                std=[float(s) for s in std],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)
@dataclass
class CustomProcessing(SimpleOperator):
    """Execute a user-supplied custom processing callable.

    Two styles are supported — callable class and factory function::

        # Callable class style
        class my_factory:
            def __init__(self, threshold=0.5):
                self.threshold = threshold
            def __call__(self, *arrays):
                return (arrays[0] > self.threshold).astype(np.uint8)

        # Factory function style
        def my_factory(threshold=0.5):
            def run(*arrays):
                return (arrays[0] > threshold).astype(np.uint8)
            return run

    Runtime protocol: ``custom_callable = my_factory(**kwargs)`` once at construction;
    ``result = custom_callable(tensor)`` once per sample.

    Note: The custom callable may not change the shape of the input tensor.
    """

    custom_factory: Callable[..., Callable[[NDArray[Any]], NDArray[Any]]]

    kwargs: Mapping[str, Any]
    """Keyword arguments forwarded to the custom factory."""

    # Initialised in __post_init__
    _custom_callable: Any = field(init=False, repr=False)

    def __post_init__(self) -> None:
        # build the processing callable once from the factory and its kwargs
        build = self.custom_factory
        self._custom_callable = build(**self.kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Run the custom callable on the numpy data of `x`, keeping its dims."""
        processed = self._custom_callable(x.to_numpy())
        return Tensor.from_numpy(processed, dims=x.dims)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not need any precomputed measures."""
        return set()

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.CustomProcessingDescr,
        member_id: MemberId,
    ) -> Self:
        """Create the operator by importing the factory referenced by `descr`."""
        factory = import_callable(descr)
        factory_kwargs = dict(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            custom_factory=factory,
            kwargs=factory_kwargs,
        )
# Any pre- or postprocessing description of the v0_4 or v0_5 model format.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]


# Any processing operator that is part of this type alias; `get_proc` is
# annotated to return one of these.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    CellposeFlowDynamics,
    CustomProcessing,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the processing operator described by `proc_descr`.

    Args:
        proc_descr: The pre- or postprocessing description.
        tensor_descr: Description of the tensor the processing applies to; its
            member id is used as the operator's input and output.

    Returns:
        The operator created by the matching `from_proc_descr` classmethod.

    Raises:
        TypeError: for a v0_4 "fixed" zero-mean-unit-variance description
            combined with a tensor description that is not v0_4.
        ValueError: for a stardist postprocessing description whose kwargs are
            neither the 2D nor the 3D variant.
    """
    member_id = get_member_id(tensor_descr)

    # Every branch returns or raises, so the checks are written as guard
    # clauses; their order is significant and must be kept.
    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.CellposeFlowDynamicsDescr):
        return CellposeFlowDynamics.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.CustomProcessingDescr):
        return CustomProcessing.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)

    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    # v0_4 "fixed" mode is converted to the v0_5 fixed description; this check
    # has to come before the generic zero-mean-unit-variance check below.
    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)

    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        stardist_kwargs = proc_descr.kwargs
        if isinstance(stardist_kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id)

        if isinstance(stardist_kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id)

        raise ValueError(
            f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
        )

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)