Coverage for src / bioimageio / core / proc_ops.py: 73%
508 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
1import collections.abc
2from abc import ABC, abstractmethod
3from dataclasses import InitVar, dataclass, field
4from functools import partial
5from typing import (
6 Any,
7 Collection,
8 Generic,
9 List,
10 Literal,
11 Mapping,
12 Optional,
13 Sequence,
14 Set,
15 Tuple,
16 Union,
17)
19import numpy as np
20import scipy
21import xarray as xr
22from numpy.typing import NDArray
23from typing_extensions import Self, TypeVar, assert_never, cast
25from bioimageio.core.digest_spec import get_member_id
26from bioimageio.spec.model import v0_4, v0_5
27from bioimageio.spec.model.v0_5 import (
28 _convert_proc, # pyright: ignore[reportPrivateUsage]
29)
31from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator
32from .axis import AxisId
33from .common import DTypeStr, MemberId
34from .sample import Sample, SampleBlock
35from .stat_calculators import StatsCalculator
36from .stat_measures import (
37 DatasetMean,
38 DatasetMeasure,
39 DatasetQuantile,
40 DatasetStd,
41 MeanMeasure,
42 Measure,
43 MeasureValue,
44 SampleMean,
45 SampleQuantile,
46 SampleStd,
47 Stat,
48 StdMeasure,
49)
50from .tensor import Tensor
def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert a v0.4 axes string (e.g. "czyx") into a tuple of `AxisId`s.

    A non-string `axes` is assumed to already hold axis ids and is returned
    unchanged as a tuple. In "per_dataset" mode the batch axis is prepended,
    since dataset statistics also aggregate over the batch dimension.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        prefix: List[AxisId] = []
    elif mode == "per_dataset":
        prefix = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(prefix + [AxisId(letter) for letter in axes])
@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Insert precomputed dataset statistics into a sample's `stat` mapping."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Collection[Measure]:
        # nothing to compute: all values are already known
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # copy every known measure value into the sample's statistics
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value
83# @dataclass
84# class UpdateStats(Operator):
85# """Calculates sample and/or dataset measures"""
87# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
88# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed
89# dataset measure values may be given, see `keep_updating_dataset_stats` for details.
90# """
91# keep_updating_dataset_stats: Optional[bool] = None
92# """indicates if operator calls should keep updating dataset statistics or not
94# default (None): if `measures` is a `Mapping` (i.e. initial measure values are
95# given) no further updates to dataset statistics is conducted, otherwise (w.o.
96# initial measure values) dataset statistics are updated by each processed sample.
97# """
98# _keep_updating_dataset_stats: bool = field(init=False)
99# _stats_calculator: StatsCalculator = field(init=False)
101# @property
102# def required_measures(self) -> Collection[Measure]:
103# return set()
105# def __post_init__(self):
106# self._stats_calculator = StatsCalculator(self.measures)
107# if self.keep_updating_dataset_stats is None:
108# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
109# else:
110# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
112# def __call__(self, sample_block: SampleBlockWithOrigin) -> None:
113# if self._keep_updating_dataset_stats:
114# sample.stat.update(self._stats_calculator.update_and_get_all(sample))
115# else:
116# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
@dataclass
class UpdateStats(SamplewiseOperator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __post_init__(self):
        # keep updating unless fixed dataset measures were supplied and the
        # user did not explicitly ask for continued updates
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Sample) -> None:
        if self._keep_updating_dataset_stats:
            get_stats = self.stats_calculator.update_and_get_all
        else:
            get_stats = self.stats_calculator.skip_update_and_get_all

        sample.stat.update(get_stats(sample))
@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise comparison; a sequence of thresholds broadcasts along `axis`
        return x > self.threshold

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # thresholding is elementwise, so the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create a `Binarize` operator from a spec description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            # single scalar threshold applied to the whole tensor
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            # one threshold per entry along the given axis
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        assert_never(kwargs)
@dataclass
class Clip(SimpleOperator):
    """Clip tensor values to a fixed and/or quantile-based range.

    At least one of `min`/`max` must be given. Each bound is either a plain
    number or a (sample/dataset) quantile measure that is resolved from the
    sample statistics when the operator is applied.
    """

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # fix: also validate integer bounds; `isinstance(..., float)` let int
        # thresholds bypass this min < max sanity check entirely
        if (
            isinstance(self.min, (int, float))
            and isinstance(self.max, (int, float))
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        # only quantile-valued bounds require computed statistics
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # resolve the lower bound: scalar values are passed to `clip`,
        # tensor-valued quantiles are applied via `where` (clip only
        # supports scalar bounds)
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        # resolve the upper bound analogously
        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is elementwise; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create a `Clip` operator from a spec description.

        v0.5 descriptions may give percentile-based bounds; these map to
        dataset quantiles when batch-wide statistics are required and to
        sample quantiles otherwise.
        """
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )
@dataclass
class EnsureDtype(SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # dtype casting never changes the shape
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()
@dataclass
class ScaleLinear(SimpleOperator):
    """Affine transformation: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleLinear` operator from a spec description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis gain/offset: wrap as data arrays broadcastable along `axis`
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
            return cls(input=member_id, output=member_id, gain=gain, offset=offset)

        # scalar case: a length-1 sequence is unwrapped to its single element
        raw_gain = kwargs.gain
        assert isinstance(raw_gain, (float, int)) or len(raw_gain) == 1, raw_gain
        gain = raw_gain if isinstance(raw_gain, (float, int)) else raw_gain[0]

        raw_offset = kwargs.offset
        assert isinstance(raw_offset, (float, int)) or len(raw_offset) == 1
        offset = raw_offset if isinstance(raw_offset, (float, int)) else raw_offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Scale the tensor to match mean and variance of a reference tensor."""

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        # aggregate per sample unless the batch axis is reduced over, in
        # which case dataset-wide statistics are required
        if axes is not None and AxisId("batch") in axes:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # standardize with this tensor's statistics, then re-scale and shift
        # to match the reference tensor's statistics
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract `(dataset_mode, axes)` from processing kwargs.

    Returns a flag indicating whether statistics must be computed over the
    whole dataset (batch axis included, or no axes restriction given) and
    the axes as a tuple of `AxisId`s (or `None` meaning "all axes").
    """
    if kwargs.axes is None:
        # no axes given: reduce over everything -> dataset statistics
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4-style axes string, e.g. "czyx"; "b" is the batch axis short id
        # NOTE(review): item access `kwargs["mode"]` presumably resolves like
        # attribute access on these spec kwargs nodes -- confirm against
        # bioimageio.spec; only v0_4 kwargs carry a `mode` field
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return AxisId("b") in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        # v0.5 style: already a sequence of axis ids
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)
@dataclass
class ScaleRange(SimpleOperator):
    """Normalize the tensor based on lower/upper quantile values:
    `output = (input - lower) / (upper - lower + eps)`.
    """

    # init-only bounds; normalized into `lower`/`upper` in `__post_init__`
    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        # fill defaults: q=0.0 (minimum) / q=1.0 (maximum) dataset quantiles;
        # `lower` must be resolved first because the `upper` default reuses
        # its member_id
        if lower_quantile is None:
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        # both bounds must describe the same tensor and axes, and be ordered
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # scaling is elementwise; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create a `ScaleRange` operator from a spec description."""
        kwargs = descr.kwargs
        # quantiles may be computed on a different (reference) tensor
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Serialize back to a v0.5 `ScaleRangeDescr`."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: `{}` is an empty *dict* literal; return an empty set for
        # consistency with every other operator's `required_measures`
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # sigmoid is elementwise; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()
@dataclass
class Softmax(SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # normalize along the configured axis with scipy's numerically
        # stable softmax, then wrap the result back into a Tensor
        normalized = scipy.special.softmax(x.data, axis=x.dims.index(self.axis))
        return Tensor.from_xarray(xr.DataArray(normalized, dims=x.dims))

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # softmax is applied along an existing axis; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
# Constrained TypeVars distinguishing the 2D and 3D stardist variants:
# `NdTuple` is a per-axis spatial tuple (grid/shape), `NdBorder` the
# matching per-axis (before, after) border pairs.
NdTuple = TypeVar("NdTuple", Tuple[int, int], Tuple[int, int, int])
NdBorder = TypeVar(
    "NdBorder",
    Tuple[Tuple[int, int], Tuple[int, int]],
    Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
)
@dataclass
class _StardistPostprocessingBase(SamplewiseOperator, Generic[NdTuple, NdBorder], ABC):
    """Shared logic for 2D/3D stardist postprocessing.

    Reads the probability/distance tensor from the sample, runs the
    dimension-specific `_impl` per batch element, and stores the resulting
    instance label image back into the sample.
    """

    prob_dist_input_id: MemberId
    instance_labels_output_id: MemberId

    grid: NdTuple
    """Grid size of network predictions."""

    prob_threshold: float
    """Object probability threshold for non-maximum suppression."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    b: Union[int, NdBorder]
    """Border region in which object probability is set to zero."""

    @property
    def required_measures(self) -> Collection[Measure]:
        # postprocessing needs no precomputed statistics
        return set()

    def __call__(self, sample: Sample) -> None:
        prob_dist = sample.members[self.prob_dist_input_id]

        assert AxisId("channel") in prob_dist.dims, (
            "expected 'channel' axis in stardist probability/distance input"
        )
        # spatial axes expected for 2D vs 3D, derived from the grid length
        allowed_spatial = tuple(
            map(AxisId, ("y", "x") if len(self.grid) == 2 else ("z", "y", "x"))
        )
        assert all(
            a in allowed_spatial or a in (AxisId("batch"), AxisId("channel"))
            for a in prob_dist.dims
        ), (
            f"expected prob_dist to have only 'batch', 'channel', and spatial axes {allowed_spatial}, but got {prob_dist.dims}"
        )

        # output label image shape: predictions are on a subsampled grid,
        # so each spatial extent is scaled back up by its grid factor
        spatial_shape = tuple(
            prob_dist.tagged_shape[a] * g for a, g in zip(allowed_spatial, self.grid)
        )
        if len(spatial_shape) != len(self.grid):
            raise ValueError(
                f"expected {len(self.grid)} spatial dimensions in prob_dist tensor, but got {len(spatial_shape)}"
            )
        else:
            spatial_shape = cast(NdTuple, spatial_shape)

        # fix axis order so channel-based indexing below is well-defined
        prob_dist = prob_dist.transpose(
            (AxisId("batch"), *allowed_spatial, AxisId("channel"))
        )
        labels: List[NDArray[Any]] = []
        for batch_idx in range(prob_dist.sizes[AxisId("batch")]):
            # channel 0 holds object probabilities, remaining channels the
            # star-convex distances
            prob = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): 0}
            ].to_numpy()
            dist = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): slice(1, None)}
            ].to_numpy()

            labels_i = self._impl(prob, dist, spatial_shape)
            assert labels_i.shape == spatial_shape, (
                f"expected label image shape {spatial_shape}, but got {labels_i.shape}"
            )
            labels.append(labels_i)

        # re-stack per-batch label images and add a singleton channel axis
        instance_labels = Tensor(
            np.stack(labels)[..., None],
            dims=(AxisId("batch"), *allowed_spatial, AxisId("channel")),
        )
        sample.members[self.instance_labels_output_id] = instance_labels

    @abstractmethod
    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: NdTuple
    ) -> NDArray[np.int32]:
        """Turn probability/distance maps into an instance label image."""
        raise NotImplementedError
@dataclass
class StardistPostprocessing2D(
    _StardistPostprocessingBase[
        Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]
    ]
):
    """Stardist non-maximum suppression and polygon rendering for 2D predictions."""

    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: Tuple[int, int]
    ) -> NDArray[np.int32]:
        # imported lazily: `stardist` is an optional dependency
        from stardist import (
            non_maximum_suppression,  # pyright: ignore[reportUnknownVariableType]
            polygons_to_label,  # pyright: ignore[reportUnknownVariableType]
        )

        points, probi, disti = non_maximum_suppression(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        # render the surviving star-convex polygons into a label image
        return polygons_to_label(disti, points, prob=probi, shape=spatial_shape)

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Create the 2D stardist operator; raises `TypeError` for non-2D kwargs."""
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs2D for 2D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            b=kwargs.b,
        )
@dataclass
class StardistPostprocessing3D(
    _StardistPostprocessingBase[
        Tuple[int, int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
    ]
):
    """Stardist non-maximum suppression and polyhedron rendering for 3D predictions."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

    def _impl(
        self,
        prob: NDArray[Any],
        dist: NDArray[Any],
        spatial_shape: Tuple[int, int, int],
    ) -> NDArray[np.int32]:
        # imported lazily: `stardist` is an optional dependency
        from stardist import (
            Rays_GoldenSpiral,
            non_maximum_suppression_3d,  # pyright: ignore[reportUnknownVariableType]
            polyhedron_to_label,  # pyright: ignore[reportUnknownVariableType]
        )
        from stardist.matching import (
            relabel_sequential,  # pyright: ignore[reportUnknownVariableType]
        )

        # ray directions for the star-convex polyhedra, scaled by anisotropy
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=self.anisotropy)

        points, probi, disti = non_maximum_suppression_3d(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            rays,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        # render the surviving polyhedra into a 3D label image
        labels = polyhedron_to_label(  # pyright: ignore[reportUnknownVariableType]
            disti,
            points,
            rays=rays,
            prob=probi,
            shape=spatial_shape,
            overlap_label=self.overlap_label,
        )

        # relabel so instance ids are consecutive
        labels, _, _ = relabel_sequential(labels)
        assert isinstance(labels, np.ndarray) and labels.dtype == np.int32
        return labels

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Create the 3D stardist operator; raises `TypeError` for non-3D kwargs."""
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs3D for 3D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            n_rays=kwargs.n_rays,
            anisotropy=kwargs.anisotropy,
            b=kwargs.b,
            overlap_label=kwargs.overlap_label,
        )
@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # mean and std must be computed over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # normalization is elementwise; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create a `ZeroMeanUnitVariance` operator from a spec description."""
        dataset_mode, axes = _get_axes(descr.kwargs)

        # dataset-wide statistics whenever the batch axis participates
        if dataset_mode:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # eps stabilizes the division for (near-)constant tensors
        return (x - stat[self.mean]) / (stat[self.std] + self.eps)

    def get_descr(self):
        kwargs = v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        return v0_5.ZeroMeanUnitVarianceDescr(kwargs=kwargs)
@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # when both mean and std are arrays their dims must line up
        both_arrays = not isinstance(self.mean, (int, float)) and not isinstance(
            self.std, (int, float)
        )
        assert not both_arrays or self.mean.dims == self.std.dims

    @property
    def required_measures(self) -> Collection[Measure]:
        # all values are precomputed; nothing to measure
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # normalization is elementwise; shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create the operator from a v0.5 fixed zero-mean/unit-variance description."""
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(kwargs.axis),)
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(kwargs.mean, dims=dims),
            std=xr.DataArray(kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Serialize back to a v0.5 description (scalar or along-axis variant)."""
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            return v0_5.FixedZeroMeanUnitVarianceDescr(
                kwargs=v0_5.FixedZeroMeanUnitVarianceKwargs(
                    mean=self.mean, std=self.std
                )
            )

        assert isinstance(self.std, xr.DataArray)
        assert len(self.mean.dims) == 1
        return v0_5.FixedZeroMeanUnitVarianceDescr(
            kwargs=v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)
@dataclass
class CustomPostprocessing(SamplewiseOperator):
    """Execute a user-supplied custom postprocessing callable.

    The callable is loaded from a Python source file packaged with the model.
    The source file's SHA-256 hash is verified before loading.

    Two styles are supported — callable class and factory function::

        # Callable class style
        class my_postprocess:
            def __init__(self, threshold=0.5):
                self.threshold = threshold
            def __call__(self, *arrays):
                return (arrays[0] > self.threshold).astype(np.uint8)

        # Factory function style
        def my_postprocess(threshold=0.5):
            def run(*arrays):
                return (arrays[0] > threshold).astype(np.uint8)
            return run

    Runtime protocol: ``op = callable(**kwargs)`` once at construction;
    ``result = op(*tensors)`` once per sample.
    """

    output_id: MemberId
    """The model output tensor that will be replaced with the op result."""

    input_ids: Sequence[MemberId]
    """All model output tensor ids, passed to the op in rdf.yaml declaration order."""

    callable_name: str
    """Name of the class or factory function defined in ``source_code``."""

    source_code: bytes
    """Python source code of the op file."""

    kwargs: Mapping[str, Any]
    """Keyword arguments forwarded to the callable."""

    # Initialised in __post_init__
    _op: Any = field(init=False, repr=False)

    def __post_init__(self) -> None:
        import importlib.util
        import os
        import sys
        import tempfile

        # Write source to a temp file so importlib can load it properly
        with tempfile.NamedTemporaryFile(
            suffix=".py",
            prefix=f"_bioimageio_custom_{self.callable_name}_",
            delete=False,
        ) as tmp:
            _ = tmp.write(self.source_code)
            tmp_path = tmp.name

        try:
            spec = importlib.util.spec_from_file_location(
                f"_bioimageio_custom_op_{self.callable_name}", tmp_path
            )
            if spec is None or spec.loader is None:
                raise ImportError(
                    f"Cannot create module spec from {tmp_path!r}"
                )
            module = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = module
            # NOTE: this executes model-provided code; the source hash is
            # verified upstream (see `from_proc_descr`), but treat with care.
            spec.loader.exec_module(module)
        finally:
            # fix: the temp file was created with delete=False and previously
            # leaked; remove it once loading finished (best effort — tracebacks
            # into the user code will no longer show source lines)
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

        callable_obj = getattr(module, self.callable_name, None)
        if callable_obj is None:
            raise AttributeError(
                f"No attribute {self.callable_name!r} found in custom op source"
            )
        # instantiate once; the instance is reused for every sample
        self._op = callable_obj(**self.kwargs)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __call__(self, sample: Sample) -> None:
        # pass all available output tensors (in declaration order) as numpy arrays
        arrays: List[NDArray[Any]] = [
            sample.members[mid].data.values
            for mid in self.input_ids
            if mid in sample.members
        ]
        result_array: NDArray[Any] = self._op(*arrays)
        # the result must match the replaced output tensor's axes
        result_xr = xr.DataArray(
            result_array, dims=sample.members[self.output_id].dims
        )
        sample.members[self.output_id] = Tensor.from_xarray(result_xr)

    @classmethod
    def from_proc_descr(
        cls,
        descr: Any,  # v0_5.CustomPostprocessingDescr (guarded for older spec versions)
        tensor_descr: v0_5.OutputTensorDescr,
        all_output_ids: Sequence[MemberId],
    ) -> "CustomPostprocessing":
        """Create the operator, downloading and hash-verifying the op source."""
        from bioimageio.spec._internal.io import get_reader

        output_id = get_member_id(tensor_descr)
        # get_reader verifies the declared sha256 of the source file
        reader = get_reader(descr.source, sha256=descr.sha256)
        source_code: bytes = reader.read()

        return cls(
            output_id=output_id,
            input_ids=list(all_output_ids),
            callable_name=descr.callable,
            source_code=source_code,
            kwargs=dict(descr.kwargs),
        )
# any pre- or postprocessing description of the supported spec versions
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# union of all operator implementations that `get_proc` can return
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    CustomPostprocessing,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the processing operator for a given processing description.

    NOTE: the order of the `isinstance` branches is significant; in
    particular the v0.4 `mode == "fixed"` zero-mean/unit-variance special
    case must be handled before the generic `ZeroMeanUnitVarianceDescr`
    branch below.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        # v0.4 "fixed" mode carries precomputed values; convert the whole
        # description to its v0.5 equivalent first
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        # dispatch on the kwargs flavor to pick the 2D or 3D implementation
        if isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id)
        elif isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id)
        else:
            raise ValueError(
                f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
            )
    else:
        assert_never(proc_descr)