Coverage for src/bioimageio/core/proc_ops.py: 76%
416 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import collections.abc
2from dataclasses import InitVar, dataclass, field
3from functools import partial
4from typing import (
5 Any,
6 Callable,
7 Collection,
8 Literal,
9 Mapping,
10 Optional,
11 Sequence,
12 Set,
13 Tuple,
14 Union,
15)
17import numpy as np
18import scipy
19import xarray as xr
20from numpy.typing import NDArray
21from typing_extensions import Self, assert_never
23from bioimageio.spec.model import v0_4, v0_5
24from bioimageio.spec.model.v0_5 import (
25 _convert_proc, # pyright: ignore[reportPrivateUsage]
26)
28from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator
29from ._ops_cellpose import CellposeFlowDynamics
30from ._ops_stardist import StardistPostprocessing2D as StardistPostprocessing2D
31from ._ops_stardist import StardistPostprocessing3D as StardistPostprocessing3D
32from .axis import AxisId
33from .common import DTypeStr, MemberId
34from .digest_spec import get_member_id, import_callable
35from .sample import Sample, SampleBlock
36from .stat_calculators import StatsCalculator
37from .stat_measures import (
38 DatasetMean,
39 DatasetQuantile,
40 DatasetStd,
41 MeanMeasure,
42 Measure,
43 MeasureValue,
44 SampleMean,
45 SampleQuantile,
46 SampleStd,
47 Stat,
48 StdMeasure,
49)
50from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate v0.4 processing axes into a tuple of `AxisId`s.

    A string such as "xy" is split into one `AxisId` per character. In
    "per_dataset" mode the batch axis (`v0_5.BATCH_AXIS_ID`) is put in front.
    Any non-string input is simply converted to a tuple; `mode` is not
    consulted in that case.
    """
    if isinstance(axes, str):
        if mode == "per_sample":
            prefix: Tuple[AxisId, ...] = ()
        elif mode == "per_dataset":
            prefix = (AxisId(v0_5.BATCH_AXIS_ID),)
        else:
            assert_never(mode)

        return prefix + tuple(AxisId(letter) for letter in axes)

    return tuple(axes)
@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Write precomputed dataset statistics into `sample.stat` of every sample (block)."""

    dataset_stats: Mapping[Measure, MeasureValue]

    def __post_init__(self):
        # only measures whose scope is "dataset" are kept; everything else is dropped
        filtered = {}
        for measure, value in self.dataset_stats.items():
            if measure.scope == "dataset":
                filtered[measure] = value
        self.dataset_stats = filtered

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        sample.stat.update(self.dataset_stats)
89# @dataclass
90# class UpdateStats(Operator):
91# """Calculates sample and/or dataset measures"""
93# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
94# """sample and dataset `measures` to be calculated by this operator. Initial/fixed
95# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
96# """
97# keep_updating_dataset_stats: Optional[bool] = None
98# """indicates if operator calls should keep updating dataset statistics or not
100# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
101# given) no further updates to dataset statistics is conducted, otherwise (w.o.
102# initial measure values) dataset statistics are updated by each processed sample.
103# """
104# _keep_updating_dataset_stats: bool = field(init=False)
105# _stats_calculator: StatsCalculator = field(init=False)
107# @property
108# def required_measures(self) -> Collection[Measure]:
109# return set()
111# def __post_init__(self):
112# self._stats_calculator = StatsCalculator(self.measures)
113# if self.keep_updating_dataset_stats is None:
114# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
115# else:
116# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
118# def __call__(self, sample_block: SampleBlockWithOrigin> None:
119# if self._keep_updating_dataset_stats:
120# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
121# else:
122# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(SamplewiseOperator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __post_init__(self):
        # Updating is forced whenever the calculator reports no dataset measures
        # (`has_dataset_measures` is falsy); otherwise the user flag decides.
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Sample) -> None:
        calculator = self.stats_calculator
        computed = (
            calculator.update_and_get_all(sample)
            if self._keep_updating_dataset_stats
            else calculator.skip_update_and_get_all(sample)
        )
        sample.stat.update(computed)
@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'.

    `threshold` is either a single number applied to every element, or (when
    `axis` is set, as created from `v0_5.BinarizeAlongAxisKwargs`) one value
    per entry along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        threshold = self.threshold
        if self.axis is not None and not isinstance(threshold, (int, float)):
            per_axis = np.asarray(threshold)
            if per_axis.size == 1:
                # a single value behaves like a scalar threshold
                threshold = per_axis.item()
            else:
                # Label the thresholds with `axis` so xarray broadcasts them by
                # dimension name. A plain list would be matched against the
                # trailing numpy dimension, whichever axis that happens to be.
                threshold = xr.DataArray(per_axis, dims=(self.axis,))

        return x > threshold

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
@dataclass
class Clip(SimpleOperator):
    """Clip tensor values to [`min`, `max`].

    Each bound is either a fixed number or a quantile measure
    (`SampleQuantile` / `DatasetQuantile`) whose value is read from the sample
    statistics when the operator is applied. At least one bound is required.

    Raises:
        ValueError: both bounds are missing, or fixed/quantile bounds are not
            strictly increasing.
        NotImplementedError: min and max quantiles use different axes.
    """

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # Validate fixed numeric bounds. Ints are checked too, so that
        # e.g. `Clip(min=5, max=1)` is rejected instead of silently accepted.
        if (
            isinstance(self.min, (int, float))
            and isinstance(self.max, (int, float))
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        # only call clip if at least one scalar bound remains to be applied
        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            # explicit values take precedence over percentiles
            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )
@dataclass
class EnsureDtype(SimpleOperator):
    """Cast the tensor to `dtype` (shape is left unchanged)."""

    dtype: DTypeStr

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        descr_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=descr_kwargs)
@dataclass
class ScaleLinear(SimpleOperator):
    """Affine transform: `output = tensor * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        scaled = x * self.gain
        return scaled + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if not axis:
            # No axis: gain/offset must reduce to single numbers (a one-element
            # list is unwrapped).
            raw_gain = kwargs.gain
            gain_is_number = isinstance(raw_gain, (float, int))
            assert gain_is_number or len(raw_gain) == 1, raw_gain
            gain = raw_gain if gain_is_number else raw_gain[0]

            raw_offset = kwargs.offset
            offset_is_number = isinstance(raw_offset, (float, int))
            assert offset_is_number or len(raw_offset) == 1
            offset = raw_offset if offset_is_number else raw_offset[0]
        else:
            # per-entry values along `axis`, broadcast by dimension name
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Match mean and std of the input to those of a reference tensor.

    `output = (x - mean) / (std + eps) * (ref_std + eps) + ref_mean`
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = tuple(self.axes) if self.axes is not None else None
        # without a (truthy) reference tensor, the input is its own reference
        ref_tensor = self.reference_tensor if self.reference_tensor else self.input

        # Dataset measures are used only if the batch axis is explicitly reduced over.
        # NOTE(review): `axes=None` selects per-sample measures here, whereas
        # `_get_axes` treats `axes=None` as dataset mode — confirm which is intended.
        if axes is not None and AxisId("batch") in axes:
            mean_cls, std_cls = DatasetMean, DatasetStd
        else:
            mean_cls, std_cls = SampleMean, SampleStd

        self.mean = mean_cls(member_id=self.input, axes=axes)
        self.std = std_cls(member_id=self.input, axes=axes)
        self.ref_mean = mean_cls(member_id=ref_tensor, axes=axes)
        self.ref_std = std_cls(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        normalized = (x - mean) / std
        return normalized * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        proc_kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(proc_kwargs.reference_tensor)),
            axes=axes,
            eps=proc_kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes of a processing step.

    Returns:
        `(dataset_mode, axes)`. `axes` is None or a tuple of `AxisId`s.
        `dataset_mode` is True when `axes` is None or contains the batch axis.
    """
    batch_axis = AxisId(v0_5.BATCH_AXIS_ID)
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4 single-letter axes; `_convert_axis_ids` prepends the batch axis
        # (`v0_5.BATCH_AXIS_ID`) in "per_dataset" mode, so that is what must be
        # tested for here. The v0.4 letters themselves never include it.
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
    else:
        assert_never(kwargs.axes)

    return batch_axis in axes, axes
@dataclass
class ScaleRange(SimpleOperator):
    """Rescale using two quantiles: `output = (x - lower) / (upper - lower + eps)`.

    Missing quantiles default to `DatasetQuantile(q=0.0)` (lower) and
    `DatasetQuantile(q=1.0)` (upper) of the same tensor.
    """

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        if lower_quantile is None:
            # anchor the default lower bound on the upper quantile's tensor, if given
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            # honor the description's eps (it is also written back by `get_descr`)
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=None
                if self.lower.axes is None
                else [v0_5.AxisId(a) for a in self.lower.axes],
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # An empty *set*, like every other operator here. The previous `{}`
        # literal is an empty dict, which breaks set operations such as `a | b`.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()
@dataclass
class Softmax(SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # scipy expects a positional axis, so look up where the named axis sits
        position = x.dims.index(self.axis)
        activated = scipy.special.softmax(x.data, axis=position)
        return Tensor.from_xarray(xr.DataArray(activated, dims=x.dims))

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        softmax_kwargs = v0_5.SoftmaxKwargs(axis=v0_5.AxisId(self.axis))
        return v0_5.SoftmaxDescr(kwargs=softmax_kwargs)
@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance.

    `output = (x - mean) / (std + eps)`, with `mean` and `std` read from the
    sample statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        dataset_mode, axes = _get_axes(kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # honor the description's eps instead of silently using the default
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        axes = self.mean.axes
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(
                # convert to spec axis ids, as `ScaleRange.get_descr` does
                axes=None if axes is None else [v0_5.AxisId(a) for a in axes],
                eps=self.eps,
            )
        )
@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    `output = (x - mean) / (std + eps)`; `mean`/`std` are numbers or
    DataArrays (broadcast by dimension name).
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        mean, std = self.mean, self.std
        # `from_proc_descr` stores scalar kwargs as 0-d DataArrays; unwrap them so
        # they round-trip to `FixedZeroMeanUnitVarianceKwargs`.
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats; `list(DataArray)` would yield 0-d DataArrays
                mean=[float(v) for v in mean.values],
                std=[float(v) for v in std.values],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)
@dataclass
class CustomProcessing(SimpleOperator):
    """Run a user-supplied custom processing callable on each sample.

    The factory may be a callable class or a factory function::

        # Callable class style
        class my_factory:
            def __init__(self, threshold=0.5):
                self.threshold = threshold
            def __call__(self, *arrays):
                return (arrays[0] > self.threshold).astype(np.uint8)

        # Factory function style
        def my_factory(threshold=0.5):
            def run(*arrays):
                return (arrays[0] > threshold).astype(np.uint8)
            return run

    Protocol: ``custom_callable = my_factory(**kwargs)`` is evaluated once when the
    operator is constructed; ``result = custom_callable(array)`` is then evaluated
    for every sample on the tensor's numpy data.

    Note: The custom callable may not change the shape of the input tensor; its
    result is re-wrapped with the input tensor's dims.
    """

    custom_factory: Callable[..., Callable[[NDArray[Any]], NDArray[Any]]]

    kwargs: Mapping[str, Any]
    """Keyword arguments forwarded to the custom factory."""

    # set by __post_init__ from `custom_factory(**kwargs)`
    _custom_callable: Any = field(init=False, repr=False)

    def __post_init__(self) -> None:
        factory_kwargs = self.kwargs
        self._custom_callable = self.custom_factory(**factory_kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        processed = self._custom_callable(x.to_numpy())
        return Tensor.from_numpy(processed, dims=x.dims)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.CustomProcessingDescr,
        member_id: MemberId,
    ) -> Self:
        return cls(
            input=member_id,
            output=member_id,
            custom_factory=import_callable(descr),
            kwargs=dict(descr.kwargs),
        )
# Any pre- or postprocessing description of spec versions 0.4 and 0.5;
# this is the input type accepted by `get_proc`.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]


# Union of the operator types defined in or imported into this module; `get_proc`
# returns one of these. AddKnownDatasetStats and UpdateStats are included even
# though `get_proc` never constructs them.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    CellposeFlowDynamics,
    CustomProcessing,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Build the operator that implements `proc_descr` for the tensor `tensor_descr`.

    The operator reads from and writes to the tensor's member id.

    Raises:
        TypeError: a v0.4 "fixed" zero-mean-unit-variance description is paired
            with a non-v0.4 tensor description.
        ValueError: stardist postprocessing kwargs are neither 2D nor 3D.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.CellposeFlowDynamicsDescr):
        return CellposeFlowDynamics.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.CustomProcessingDescr):
        return CustomProcessing.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    # Must be checked before the generic ZeroMeanUnitVariance branch below:
    # v0.4 "fixed" mode is converted to a v0.5 FixedZeroMeanUnitVarianceDescr.
    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        is_v0_4_tensor = isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        )
        if not is_v0_4_tensor:
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        stardist_kwargs = proc_descr.kwargs
        if isinstance(stardist_kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id)
        if isinstance(stardist_kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id)
        raise ValueError(
            f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
        )

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)