Coverage for src / bioimageio / core / _ops_stardist.py: 55%
87 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass
3from typing import (
4 Any,
5 Collection,
6 Generic,
7 List,
8 Optional,
9 Tuple,
10 Union,
11)
13import numpy as np
14from numpy.typing import NDArray
15from typing_extensions import Self, TypeVar, cast
17from bioimageio.spec.model import v0_5
19from ._op_base import SamplewiseOperator
20from .axis import AxisId
21from .common import MemberId
22from .sample import Sample
23from .stat_measures import (
24 Measure,
25)
26from .tensor import Tensor
# Shorthand for a pair of ints; used to spell the 2D/3D constraint tuples below.
_IntPair = Tuple[int, int]

# Per-axis integer tuple (e.g. the prediction grid): 2 entries for 2D, 3 for 3D.
NdTuple = TypeVar("NdTuple", _IntPair, Tuple[int, int, int])

# One integer pair per spatial axis: 2 pairs for 2D, 3 pairs for 3D.
NdBorder = TypeVar(
    "NdBorder",
    Tuple[_IntPair, _IntPair],
    Tuple[_IntPair, _IntPair, _IntPair],
)
@dataclass
class _StardistPostprocessingBase(SamplewiseOperator, Generic[NdTuple, NdBorder], ABC):
    """Common driver for StarDist instance-segmentation post-processing.

    Reads the tensor ``prob_dist_input_id`` from a sample. Channel 0 is taken as
    object probability and the remaining channels as radial distances.
    :meth:`_impl` is called once per batch entry. The stacked label images are
    stored under ``instance_labels_output_id`` with axes
    ``(batch, *spatial, channel)``.
    """

    prob_dist_input_id: MemberId
    instance_labels_output_id: MemberId

    grid: NdTuple
    """Grid size of network predictions."""

    prob_threshold: float
    """Object probability threshold for non-maximum suppression."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    b: Union[int, NdBorder]
    """Border region in which object probability is set to zero."""

    n_rays: int
    """Number of radial lines (rays) cast from the center of an object to its boundary."""

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator requires no precomputed measures."""
        return set()

    def __call__(self, sample: Sample) -> None:
        """Compute instance labels from the probability/distance member of ``sample``.

        The result is written to ``sample.members[self.instance_labels_output_id]``.

        Raises:
            ValueError: if a spatial axis expected for this grid dimensionality
                is missing from the input tensor.
        """
        batch_axis = AxisId("batch")
        channel_axis = AxisId("channel")
        prob_dist = sample.members[self.prob_dist_input_id]

        assert channel_axis in prob_dist.dims, (
            "expected 'channel' axis in stardist probability/distance input"
        )
        # The grid length (2 or 3) selects which spatial axes are expected.
        allowed_spatial = tuple(
            map(AxisId, ("y", "x") if len(self.grid) == 2 else ("z", "y", "x"))
        )
        assert all(
            a in allowed_spatial or a in (batch_axis, channel_axis)
            for a in prob_dist.dims
        ), (
            f"expected prob_dist to have only 'batch', 'channel', and spatial axes {allowed_spatial}, but got {prob_dist.dims}"
        )

        # Check presence explicitly. Looking up a missing axis in `tagged_shape`
        # below would otherwise fail with an uninformative KeyError.
        missing_spatial = [a for a in allowed_spatial if a not in prob_dist.dims]
        if missing_spatial:
            raise ValueError(
                f"expected {len(self.grid)} spatial dimensions {allowed_spatial} in prob_dist tensor, but got {prob_dist.dims} (missing {missing_spatial})"
            )

        # Predictions are given on the network's prediction grid. Scale each
        # spatial size by its grid factor to get the label-image shape.
        spatial_shape = cast(
            NdTuple,
            tuple(
                prob_dist.tagged_shape[a] * g
                for a, g in zip(allowed_spatial, self.grid)
            ),
        )

        # Fix the axis order so each batch slice is (*spatial, channel).
        # NOTE(review): if 'batch' is absent, this relies on `Tensor.transpose`
        # inserting it as a singleton axis. Confirm.
        prob_dist = prob_dist.transpose((batch_axis, *allowed_spatial, channel_axis))
        labels: List[NDArray[Any]] = []
        for batch_idx in range(prob_dist.sizes[batch_axis]):
            prob = prob_dist[{batch_axis: batch_idx, channel_axis: 0}].to_numpy()
            dist = prob_dist[
                {batch_axis: batch_idx, channel_axis: slice(1, None)}
            ].to_numpy()

            labels_i = self._impl(prob, dist, spatial_shape)
            assert labels_i.shape == spatial_shape, (
                f"expected label image shape {spatial_shape}, but got {labels_i.shape}"
            )
            labels.append(labels_i)

        # Append a singleton channel axis to match the declared output dims.
        instance_labels = Tensor(
            np.stack(labels)[..., None],
            dims=(batch_axis, *allowed_spatial, channel_axis),
        )
        sample.members[self.instance_labels_output_id] = instance_labels

    @abstractmethod
    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: NdTuple
    ) -> NDArray[np.int32]:
        """Compute the instance label image for a single batch entry.

        Args:
            prob: channel 0 of the input for this batch entry (spatial axes only).
            dist: channels 1.. of the input for this batch entry, channel axis last.
            spatial_shape: the exact shape the returned label image must have.
        """
        raise NotImplementedError
@dataclass
class StardistPostprocessing2D(
    _StardistPostprocessingBase[
        Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]
    ]
):
    """StarDist post-processing for 2D (y, x) star-convex polygon predictions."""

    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: Tuple[int, int]
    ) -> NDArray[np.int32]:
        """Run polygon NMS on one batch entry and render the kept polygons.

        Args:
            prob: object probability for one batch entry.
            dist: radial distances for one batch entry, channel axis last.
            spatial_shape: shape of the label image to render.

        Returns:
            The label image produced by ``stardist.polygons_to_label``.
        """
        # Imported inside the method, so stardist is only loaded when this runs.
        from stardist import (
            non_maximum_suppression,  # pyright: ignore[reportUnknownVariableType]
            polygons_to_label,  # pyright: ignore[reportUnknownVariableType]
        )

        kept_points, kept_probs, kept_dists = non_maximum_suppression(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            b=self.b,  # pyright: ignore[reportArgumentType]
            nms_thresh=self.nms_threshold,
            prob_thresh=self.prob_threshold,
            grid=self.grid,
        )

        label_image = polygons_to_label(
            kept_dists, kept_points, shape=spatial_shape, prob=kept_probs
        )
        return label_image

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 2D operator from a v0.5 StarDist post-processing description.

        ``member_id`` is used for both input and output, so the label image
        replaces the probability/distance member in the sample.

        Raises:
            TypeError: if ``descr.kwargs`` is not the 2D kwargs variant.
        """
        kw = descr.kwargs
        if isinstance(kw, v0_5.StardistPostprocessingKwargs2D):
            return cls(
                prob_dist_input_id=member_id,
                instance_labels_output_id=member_id,
                grid=kw.grid,
                prob_threshold=kw.prob_threshold,
                nms_threshold=kw.nms_threshold,
                b=kw.b,
                n_rays=kw.n_rays,
            )

        raise TypeError(
            f"expected v0_5.StardistPostprocessingKwargs2D for 2D stardist post-processing, but got {type(kw)}"
        )
@dataclass
class StardistPostprocessing3D(
    _StardistPostprocessingBase[
        Tuple[int, int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
    ]
):
    """StarDist post-processing for 3D (z, y, x) star-convex polyhedron predictions."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

    def _impl(
        self,
        prob: NDArray[Any],
        dist: NDArray[Any],
        spatial_shape: Tuple[int, int, int],
    ) -> NDArray[np.int32]:
        """Run polyhedron NMS on one batch entry, render and relabel the result.

        Args:
            prob: object probability for one batch entry.
            dist: radial distances for one batch entry, channel axis last.
            spatial_shape: shape of the label image to render.

        Returns:
            An int32 label image whose object ids are made sequential by
            ``relabel_sequential``. A negative ``overlap_label`` is preserved
            in the overlap regions.
        """
        from stardist import (
            Rays_GoldenSpiral,
            non_maximum_suppression_3d,  # pyright: ignore[reportUnknownVariableType]
            polyhedron_to_label,  # pyright: ignore[reportUnknownVariableType]
        )
        from stardist.matching import (
            relabel_sequential,  # pyright: ignore[reportUnknownVariableType]
        )

        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=self.anisotropy)

        points, probi, disti = non_maximum_suppression_3d(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            rays,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        labels = polyhedron_to_label(  # pyright: ignore[reportUnknownVariableType]
            disti,
            points,
            rays=rays,
            prob=probi,
            shape=spatial_shape,
            overlap_label=self.overlap_label,
        )

        overlap_label = self.overlap_label
        if overlap_label is not None and overlap_label < 0:
            # relabel_sequential rejects negative label values. Treat the
            # overlap region as background while relabeling, then restore it.
            overlap_mask = labels == overlap_label
            labels, _, _ = relabel_sequential(np.where(overlap_mask, 0, labels))
            labels[overlap_mask] = overlap_label
        else:
            # A non-negative overlap label is renumbered like any other label.
            labels, _, _ = relabel_sequential(labels)

        assert isinstance(labels, np.ndarray) and labels.dtype == np.int32
        return labels

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 3D operator from a v0.5 StarDist post-processing description.

        ``member_id`` is used for both input and output, so the label image
        replaces the probability/distance member in the sample.

        Raises:
            TypeError: if ``descr.kwargs`` is not the 3D kwargs variant.
        """
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs3D for 3D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            n_rays=kwargs.n_rays,
            anisotropy=kwargs.anisotropy,
            b=kwargs.b,
            overlap_label=kwargs.overlap_label,
        )