Coverage for src / bioimageio / core / proc_ops.py: 71%
463 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 22:06 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 22:06 +0000
1import collections.abc
2from abc import ABC, abstractmethod
3from dataclasses import InitVar, dataclass, field
4from functools import partial
5from typing import (
6 Any,
7 Collection,
8 Generic,
9 List,
10 Literal,
11 Mapping,
12 Optional,
13 Sequence,
14 Set,
15 Tuple,
16 Union,
17)
19import numpy as np
20import scipy
21import xarray as xr
22from numpy.typing import NDArray
23from typing_extensions import Self, TypeVar, assert_never, cast
25from bioimageio.core.digest_spec import get_member_id
26from bioimageio.spec.model import v0_4, v0_5
27from bioimageio.spec.model.v0_5 import (
28 _convert_proc, # pyright: ignore[reportPrivateUsage]
29)
31from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator
32from .axis import AxisId
33from .common import DTypeStr, MemberId
34from .sample import Sample, SampleBlock
35from .stat_calculators import StatsCalculator
36from .stat_measures import (
37 DatasetMean,
38 DatasetMeasure,
39 DatasetQuantile,
40 DatasetStd,
41 MeanMeasure,
42 Measure,
43 MeasureValue,
44 SampleMean,
45 SampleQuantile,
46 SampleStd,
47 Stat,
48 StdMeasure,
49)
50from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert a v0_4 axes specification to a tuple of `AxisId`s.

    Non-string inputs are returned as-is (tupled). For a v0_4 letter string,
    a leading batch axis is prepended in "per_dataset" mode.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        prefix: List[AxisId] = []
    elif mode == "per_dataset":
        # dataset-wide statistics additionally aggregate over the batch axis
        prefix = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(prefix + [AxisId(a) for a in axes])
@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Inject precomputed dataset statistics into each processed sample."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]
    # known dataset measure values; copied into every sample's `stat`

    @property
    def required_measures(self) -> Collection[Measure]:
        """Nothing to compute: all values are already known."""
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value
83# @dataclass
84# class UpdateStats(Operator):
85# """Calculates sample and/or dataset measures"""
87# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
88# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed
89# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
90# """
91# keep_updating_dataset_stats: Optional[bool] = None
92# """indicates if operator calls should keep updating dataset statistics or not
94# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
95# given) no further updates to dataset statistics is conducted, otherwise (w.o.
96# initial measure values) dataset statistics are updated by each processed sample.
97# """
98# _keep_updating_dataset_stats: bool = field(init=False)
99# _stats_calculator: StatsCalculator = field(init=False)
101# @property
102# def required_measures(self) -> Collection[Measure]:
103# return set()
105# def __post_init__(self):
106# self._stats_calculator = StatsCalculator(self.measures)
107# if self.keep_updating_dataset_stats is None:
108# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
109# else:
110# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
# def __call__(self, sample: SampleBlockWithOrigin) -> None:
113# if self._keep_updating_dataset_stats:
114# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
115# else:
116# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(SamplewiseOperator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""

    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """

    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        """No prerequisite measures; this operator produces them itself."""
        return set()

    def __post_init__(self):
        # keep updating unless fixed initial dataset measures were provided
        # (and the caller did not explicitly request continued updates)
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Sample) -> None:
        calc = self.stats_calculator
        if self._keep_updating_dataset_stats:
            computed = calc.update_and_get_all(sample)
        else:
            computed = calc.skip_update_and_get_all(sample)
        sample.stat.update(computed)
@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise comparison; broadcasting handles sequence thresholds
        return x > self.threshold

    @property
    def required_measures(self) -> Collection[Measure]:
        """Thresholds are fixed; no statistics required."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Build a `Binarize` operator from a v0_4/v0_5 description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            # single scalar threshold
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            # one threshold per entry along `axis`
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        assert_never(kwargs)
@dataclass
class Clip(SimpleOperator):
    """Clip tensor values to a fixed and/or quantile-based range."""

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""

    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        # at least one bound must be given
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # fixed scalar bounds must be ordered
        if (
            isinstance(self.min, float)
            and isinstance(self.max, float)
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        # quantile bounds must share axes and have ordered q values
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        # only quantile bounds need to be computed by a stats calculator
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # Resolve each bound: scalar bounds are handed to `clip`; tensor-valued
        # quantile results are applied immediately via `where`, because `clip`
        # only supports scalar bounds. Note that `x` may be rebound here.
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create a `Clip` operator from a v0_4/v0_5 clip description."""
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            # a fixed value takes precedence over a percentile-based bound
            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )
@dataclass
class EnsureDtype(SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Build the operator from a v0_5 ensure-dtype description."""
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        """Serialize back to a v0_5 description."""
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # dtype casting does not alter the shape
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

    @property
    def required_measures(self) -> Collection[Measure]:
        """No statistics required for a plain dtype cast."""
        return set()
@dataclass
class ScaleLinear(SimpleOperator):
    """Affine rescaling: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # xarray broadcasting aligns per-axis gain/offset with the tensor
        return x * self.gain + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        """Gain and offset are fixed parameters; nothing to compute."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Build a `ScaleLinear` operator from a v0_4/v0_5 description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis parameters: wrap as 1d data arrays along `axis`
            return cls(
                input=member_id,
                output=member_id,
                gain=xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis),
                offset=xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis),
            )

        # scalar parameters (possibly given as single-element sequences)
        assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
            kwargs.gain
        )
        scalar_gain = (
            kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
        )
        assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
        scalar_offset = (
            kwargs.offset
            if isinstance(kwargs.offset, (float, int))
            else kwargs.offset[0]
        )
        return cls(
            input=member_id, output=member_id, gain=scalar_gain, offset=scalar_offset
        )
@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Scale the tensor to match the mean and variance of a reference tensor."""

    axes: Optional[Sequence[AxisId]] = None  # axes to compute statistics over
    reference_tensor: Optional[MemberId] = None  # defaults to `input` when None
    eps: float = 1e-6  # added to both std values for numerical stability
    # measures resolved in __post_init__ (sample- or dataset-wide):
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        # no batch axis among `axes` -> per-sample statistics;
        # with a batch axis the statistics aggregate over the dataset
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # standardize x, then rescale/shift to the reference statistics
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Build the operator from a v0_4/v0_5 scale-mean-variance description."""
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Return `(dataset_mode, axes)` for a processing kwargs object.

    `dataset_mode` is True when statistics should aggregate over the whole
    dataset (no axes given, or a batch axis is included).
    """
    axes_arg = kwargs.axes
    if axes_arg is None:
        # no axes restriction: statistics aggregate over the whole dataset
        return True, None

    if isinstance(axes_arg, str):
        # v0_4 letter-string axes; NOTE(review): item access `kwargs["mode"]`
        # assumes v0_4 kwargs support mapping-style lookup — confirm with spec
        converted = _convert_axis_ids(axes_arg, kwargs["mode"])
        return AxisId("b") in converted, converted

    if isinstance(axes_arg, collections.abc.Sequence):
        converted = tuple(axes_arg)
        return AxisId("batch") in converted, converted

    assert_never(axes_arg)
@dataclass
class ScaleRange(SimpleOperator):
    """Scale the tensor to the range spanned by a lower and upper quantile."""

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    # resolved from the InitVars in __post_init__ (defaults: q=0.0 / q=1.0)
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6  # numerical stability term for the denominator

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        if lower_quantile is None:
            # default lower bound: dataset minimum (q=0) on the same member
            # the upper bound refers to (or `input` if neither was given)
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            # default upper bound: dataset maximum (q=1) on the lower's member
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        # both bounds must refer to the same tensor, be ordered, and share axes
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Build a `ScaleRange` operator from a v0_4/v0_5 description."""
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # map [lower, upper] to approximately [0, 1]
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Serialize back to a v0_5 description."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise logistic function; dims are preserved
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        """No statistics required.

        Fix: previously returned `{}` (an empty *dict*); every other operator
        returns an empty *set*, so return `set()` for consistency with the
        declared `Collection[Measure]` contract.
        """
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Build a `Sigmoid` operator from a v0_4/v0_5 description."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Serialize back to a v0_5 description."""
        return v0_5.SigmoidDescr()
@dataclass
class Softmax(SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # normalize along the configured axis via scipy's softmax
        normalized = scipy.special.softmax(x.data, axis=x.dims.index(self.axis))
        return Tensor.from_xarray(xr.DataArray(normalized, dims=x.dims))

    @property
    def required_measures(self) -> Collection[Measure]:
        """No statistics required."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise along one axis: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        """Build a `Softmax` operator from a v0_5 description."""
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        """Serialize back to a v0_5 description."""
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
# Constrained TypeVars parameterizing stardist post-processing over
# dimensionality: 2-tuples for 2D models, 3-tuples for 3D models.
NdTuple = TypeVar("NdTuple", Tuple[int, int], Tuple[int, int, int])
# Border specification with one inner (int, int) pair per spatial dimension;
# presumably (lower, upper) border widths — TODO confirm against stardist's
# `b` argument.
NdBorder = TypeVar(
    "NdBorder",
    Tuple[Tuple[int, int], Tuple[int, int]],
    Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
)
@dataclass
class _StardistPostprocessingBase(SamplewiseOperator, Generic[NdTuple, NdBorder], ABC):
    """Shared stardist post-processing: turn a probability/distance prediction
    tensor into an instance label image, one batch element at a time.

    Subclasses implement `_impl` for the 2D/3D specifics.
    """

    prob_dist_input_id: MemberId
    # sample member holding the stacked probability + distance predictions

    instance_labels_output_id: MemberId
    # sample member the resulting instance label tensor is written to

    grid: NdTuple
    """Grid size of network predictions."""

    prob_threshold: float
    """Object probability threshold for non-maximum suppression."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    b: Union[int, NdBorder]
    """Border region in which object probability is set to zero."""

    @property
    def required_measures(self) -> Collection[Measure]:
        # post-processing needs no sample/dataset statistics
        return set()

    def __call__(self, sample: Sample) -> None:
        """Compute instance labels from `prob_dist_input_id` and store them
        under `instance_labels_output_id` in `sample.members`."""
        prob_dist = sample.members[self.prob_dist_input_id]
        assert AxisId("channel") in prob_dist.dims, (
            "expected 'channel' axis in stardist probability/distance input"
        )
        # spatial axis names depend on dimensionality (2D: y,x / 3D: z,y,x)
        allowed_spatial = tuple(
            map(AxisId, ("y", "x") if len(self.grid) == 2 else ("z", "y", "x"))
        )
        assert all(
            a in allowed_spatial or a in (AxisId("batch"), AxisId("channel"))
            for a in prob_dist.dims
        ), (
            f"expected prob_dist to have only 'batch', 'channel', and spatial axes {allowed_spatial}, but got {prob_dist.dims}"
        )

        # predictions live on a (possibly) subsampled grid; multiply by the
        # grid factors to recover the full-resolution spatial shape
        spatial_shape = tuple(
            prob_dist.tagged_shape[a] * g for a, g in zip(allowed_spatial, self.grid)
        )
        if len(spatial_shape) != len(self.grid):
            raise ValueError(
                f"expected {len(self.grid)} spatial dimensions in prob_dist tensor, but got {len(spatial_shape)}"
            )
        else:
            spatial_shape = cast(NdTuple, spatial_shape)

        # fix axis order to batch, *spatial, channel for the per-batch slicing
        prob_dist = prob_dist.transpose(
            (AxisId("batch"), *allowed_spatial, AxisId("channel"))
        )
        labels: List[NDArray[Any]] = []
        for batch_idx in range(prob_dist.sizes[AxisId("batch")]):
            # channel 0 holds object probabilities; remaining channels hold
            # the ray distances
            prob = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): 0}
            ].to_numpy()
            dist = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): slice(1, None)}
            ].to_numpy()

            labels_i = self._impl(prob, dist, spatial_shape)
            assert labels_i.shape == spatial_shape, (
                f"expected label image shape {spatial_shape}, but got {labels_i.shape}"
            )
            labels.append(labels_i)

        # restack the batch and append a singleton channel axis
        instance_labels = Tensor(
            np.stack(labels)[..., None],
            dims=(AxisId("batch"), *allowed_spatial, AxisId("channel")),
        )
        sample.members[self.instance_labels_output_id] = instance_labels

    @abstractmethod
    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: NdTuple
    ) -> NDArray[np.int32]:
        """Compute the instance label image for a single batch element."""
        raise NotImplementedError
@dataclass
class StardistPostprocessing2D(
    _StardistPostprocessingBase[
        Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]
    ]
):
    """Stardist post-processing for 2D models: non-maximum suppression followed
    by rendering the surviving star-convex polygons into a label image."""

    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: Tuple[int, int]
    ) -> NDArray[np.int32]:
        # imported lazily: stardist is an optional dependency
        from stardist import (
            non_maximum_suppression,  # pyright: ignore[reportUnknownVariableType]
            polygons_to_label,  # pyright: ignore[reportUnknownVariableType]
        )

        points, probi, disti = non_maximum_suppression(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        # render the surviving polygons into a full-resolution label image
        return polygons_to_label(disti, points, prob=probi, shape=spatial_shape)

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 2D operator; rejects non-2D kwargs with a `TypeError`."""
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs2D for 2D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            b=kwargs.b,
        )
@dataclass
class StardistPostprocessing3D(
    _StardistPostprocessingBase[
        Tuple[int, int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
    ]
):
    """Stardist post-processing for 3D models: ray-based non-maximum
    suppression followed by rendering star-convex polyhedra into a label
    image."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

    def _impl(
        self,
        prob: NDArray[Any],
        dist: NDArray[Any],
        spatial_shape: Tuple[int, int, int],
    ) -> NDArray[np.int32]:
        # imported lazily: stardist is an optional dependency
        from stardist import (
            Rays_GoldenSpiral,
            non_maximum_suppression_3d,  # pyright: ignore[reportUnknownVariableType]
            polyhedron_to_label,  # pyright: ignore[reportUnknownVariableType]
        )
        from stardist.matching import (
            relabel_sequential,  # pyright: ignore[reportUnknownVariableType]
        )

        # ray directions for the star-convex polyhedra
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=self.anisotropy)

        points, probi, disti = non_maximum_suppression_3d(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            rays,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        labels = polyhedron_to_label(  # pyright: ignore[reportUnknownVariableType]
            disti,
            points,
            rays=rays,
            prob=probi,
            shape=spatial_shape,
            overlap_label=self.overlap_label,
        )

        # relabel to a consecutive label range (per `relabel_sequential`)
        labels, _, _ = relabel_sequential(labels)
        assert isinstance(labels, np.ndarray) and labels.dtype == np.int32
        return labels

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 3D operator; rejects non-3D kwargs with a `TypeError`."""
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs3D for 3D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            n_rays=kwargs.n_rays,
            anisotropy=kwargs.anisotropy,
            b=kwargs.b,
            overlap_label=kwargs.overlap_label,
        )
@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # mean and std must be computed over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise normalization: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Build the operator, selecting sample- vs dataset-wide statistics."""
        dataset_mode, axes = _get_axes(descr.kwargs)
        Mean, Std = (
            (DatasetMean, DatasetStd) if dataset_mode else (SampleMean, SampleStd)
        )
        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean_value = stat[self.mean]
        std_value = stat[self.std]
        # eps guards against division by zero for (near-)constant tensors
        return (x - mean_value) / (std_value + self.eps)

    def get_descr(self):
        """Serialize back to a v0_5 description."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # when both mean and std are data arrays they must share dimensions
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        """Values are fixed; nothing to compute."""
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise normalization: shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Build the operator from a v0_5 fixed-ZMUV description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None  # scalar mean/std
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(kwargs.axis),)  # one value per entry along `axis`
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(kwargs.mean, dims=dims),
            std=xr.DataArray(kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Serialize back to a v0_5 description (scalar or along-axis)."""
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # eps guards against division by zero for std values of zero
        return (x - self.mean) / (self.std + self.eps)
# Any pre-/post-processing description from model spec v0_4 or v0_5.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# Every concrete processing operator `get_proc` may return.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the processing operator described by `proc_descr`,
    bound to the member id derived from `tensor_descr`.

    Raises `TypeError` when a v0_4 fixed-mode zero-mean/unit-variance
    description is paired with a non-v0_4 tensor description, and
    `ValueError` for an unrecognized stardist kwargs variant.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    # NOTE: this fixed-mode check must stay BEFORE the generic
    # ZeroMeanUnitVariance branch below, which would otherwise match
    # v0_4 fixed-mode descriptions as well.
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        # convert to the equivalent v0_5 description and reuse its constructor
        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        # dispatch on the kwargs variant (2D vs 3D)
        if isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id)
        elif isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id)
        else:
            raise ValueError(
                f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
            )
    else:
        assert_never(proc_descr)