Coverage for src / bioimageio / core / _ops_stardist.py: 55%

87 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from abc import ABC, abstractmethod 

2from dataclasses import dataclass 

3from typing import ( 

4 Any, 

5 Collection, 

6 Generic, 

7 List, 

8 Optional, 

9 Tuple, 

10 Union, 

11) 

12 

13import numpy as np 

14from numpy.typing import NDArray 

15from typing_extensions import Self, TypeVar, cast 

16 

17from bioimageio.spec.model import v0_5 

18 

19from ._op_base import SamplewiseOperator 

20from .axis import AxisId 

21from .common import MemberId 

22from .sample import Sample 

23from .stat_measures import ( 

24 Measure, 

25) 

26from .tensor import Tensor 

27 

# Spatial tuple type shared by the 2D ("y", "x") and 3D ("z", "y", "x")
# operators, e.g. the prediction `grid`.
NdTuple = TypeVar(
    "NdTuple",
    Tuple[int, int],
    Tuple[int, int, int],
)

# One pair of ints per spatial axis; used for the border argument `b`
# (presumably per-axis (start, end) widths — confirm against stardist docs).
NdBorder = TypeVar(
    "NdBorder",
    Tuple[Tuple[int, int], Tuple[int, int]],
    Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
)

34 

35 

@dataclass
class _StardistPostprocessingBase(SamplewiseOperator, Generic[NdTuple, NdBorder], ABC):
    """Common driver for StarDist instance-segmentation post-processing.

    Reads the combined probability/distance tensor stored under
    `prob_dist_input_id` and splits it per batch entry into the object
    probability (channel 0) and the ray distances (channels 1 and up). The
    dimension-specific `_impl` renders each entry into an instance label image.
    The stacked label images are written to `instance_labels_output_id`.
    """

    prob_dist_input_id: MemberId
    instance_labels_output_id: MemberId

    grid: NdTuple
    """Grid size of network predictions."""

    prob_threshold: float
    """Object probability threshold for non-maximum suppression."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    b: Union[int, NdBorder]
    """Border region in which object probability is set to zero."""

    n_rays: int
    """Number of radial lines (rays) cast from the center of an object to its boundary."""

    @property
    def required_measures(self) -> Collection[Measure]:
        """No dataset statistics are needed by this operator."""
        return set()

    def __call__(self, sample: Sample) -> None:
        """Compute instance labels for `sample` in place.

        The input tensor may only have 'batch', 'channel' and the spatial axes
        matching `grid` ('y', 'x' for a 2D grid; 'z', 'y', 'x' for a 3D grid).
        The output tensor has axes ('batch', *spatial, 'channel') with a
        singleton channel. Its spatial shape is the input's spatial shape
        multiplied element-wise by `grid`.

        Raises:
            ValueError: if `grid` is neither 2D nor 3D, or if the input tensor
                lacks the 'channel' axis or a required spatial axis, or has any
                other unexpected axis.
        """
        prob_dist = sample.members[self.prob_dist_input_id]

        n_spatial = len(self.grid)
        if n_spatial not in (2, 3):
            raise ValueError(
                f"expected a 2D or 3D grid for stardist post-processing, but got {self.grid}"
            )

        batch_axis = AxisId("batch")
        channel_axis = AxisId("channel")
        allowed_spatial = tuple(
            map(AxisId, ("y", "x") if n_spatial == 2 else ("z", "y", "x"))
        )

        # Validate with explicit exceptions rather than `assert`, which is
        # stripped when Python runs with -O.
        if channel_axis not in prob_dist.dims:
            raise ValueError(
                "expected 'channel' axis in stardist probability/distance input"
            )
        if not all(
            a in allowed_spatial or a in (batch_axis, channel_axis)
            for a in prob_dist.dims
        ):
            raise ValueError(
                f"expected prob_dist to have only 'batch', 'channel', and spatial axes {allowed_spatial}, but got {prob_dist.dims}"
            )
        # Without this check a missing spatial axis surfaces as a bare KeyError
        # from `tagged_shape` below.
        missing_spatial = [a for a in allowed_spatial if a not in prob_dist.dims]
        if missing_spatial:
            raise ValueError(
                f"expected {n_spatial} spatial dimensions {allowed_spatial} in prob_dist tensor, but axes {missing_spatial} are missing (got {prob_dist.dims})"
            )

        # Predictions live on the `grid`-subsampled lattice; labels are
        # rendered at the full spatial resolution.
        spatial_shape = cast(
            NdTuple,
            tuple(
                prob_dist.tagged_shape[a] * g
                for a, g in zip(allowed_spatial, self.grid)
            ),
        )

        out_dims = (batch_axis, *allowed_spatial, channel_axis)
        # NOTE(review): a missing 'batch' axis is presumably added by
        # `Tensor.transpose`, since its size is read via `sizes` below — confirm.
        prob_dist = prob_dist.transpose(out_dims)
        labels: List[NDArray[Any]] = []
        for batch_idx in range(prob_dist.sizes[batch_axis]):
            # Channel 0 holds the object probability. The remaining channels
            # hold the per-ray distances, kept as the trailing axis.
            prob = prob_dist[{batch_axis: batch_idx, channel_axis: 0}].to_numpy()
            dist = prob_dist[
                {batch_axis: batch_idx, channel_axis: slice(1, None)}
            ].to_numpy()

            labels_i = self._impl(prob, dist, spatial_shape)
            # Internal invariant of the `_impl` implementations, not input validation.
            assert labels_i.shape == spatial_shape, (
                f"expected label image shape {spatial_shape}, but got {labels_i.shape}"
            )
            labels.append(labels_i)

        instance_labels = Tensor(np.stack(labels)[..., None], dims=out_dims)
        sample.members[self.instance_labels_output_id] = instance_labels

    @abstractmethod
    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: NdTuple
    ) -> NDArray[np.int32]:
        """Render the instance label image for a single batch entry.

        Args:
            prob: object probability on the prediction grid.
            dist: ray distances on the prediction grid, rays along the last axis.
            spatial_shape: full-resolution spatial shape of the label image.

        Returns:
            Label image whose shape must equal `spatial_shape`.
        """
        raise NotImplementedError

115 

116 

@dataclass
class StardistPostprocessing2D(
    _StardistPostprocessingBase[
        Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]
    ]
):
    """StarDist post-processing for 2D ('y', 'x') star-convex polygons."""

    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: Tuple[int, int]
    ) -> NDArray[np.int32]:
        """Suppress overlapping polygon candidates and render the survivors.

        Args:
            prob: object probability on the prediction grid.
            dist: ray distances on the prediction grid, rays along the last axis.
            spatial_shape: (y, x) shape of the rendered label image.

        Returns:
            Label image of shape `spatial_shape` from `stardist.polygons_to_label`.
        """
        # Deferred import: stardist is only touched once this operator runs.
        from stardist import (
            non_maximum_suppression,  # pyright: ignore[reportUnknownVariableType]
            polygons_to_label,  # pyright: ignore[reportUnknownVariableType]
        )

        centers, center_probs, center_dists = non_maximum_suppression(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            b=self.b,  # pyright: ignore[reportArgumentType]
            grid=self.grid,
            nms_thresh=self.nms_threshold,
            prob_thresh=self.prob_threshold,
        )

        label_image = polygons_to_label(  # pyright: ignore[reportUnknownVariableType]
            center_dists, centers, shape=spatial_shape, prob=center_probs
        )
        return label_image

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 2D operator from a v0.5 StarDist post-processing description.

        The same `member_id` is used as input and output. The
        probability/distance member is therefore overwritten by the instance
        labels.

        Raises:
            TypeError: if `descr.kwargs` is not `StardistPostprocessingKwargs2D`.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.StardistPostprocessingKwargs2D):
            return cls(
                prob_dist_input_id=member_id,
                instance_labels_output_id=member_id,
                grid=kwargs.grid,
                prob_threshold=kwargs.prob_threshold,
                nms_threshold=kwargs.nms_threshold,
                b=kwargs.b,
                n_rays=kwargs.n_rays,
            )

        raise TypeError(
            f"expected v0_5.StardistPostprocessingKwargs2D for 2D stardist post-processing, but got {type(descr.kwargs)}"
        )

161 

162 

@dataclass
class StardistPostprocessing3D(
    _StardistPostprocessingBase[
        Tuple[int, int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
    ]
):
    """StarDist post-processing for 3D ('z', 'y', 'x') star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

    def _impl(
        self,
        prob: NDArray[Any],
        dist: NDArray[Any],
        spatial_shape: Tuple[int, int, int],
    ) -> NDArray[np.int32]:
        """Run 3D polyhedra NMS on one batch entry and render a label image.

        Args:
            prob: object probability on the prediction grid.
            dist: ray distances on the prediction grid, rays along the last axis.
            spatial_shape: (z, y, x) shape of the rendered label image.

        Returns:
            int32 label image of shape `spatial_shape`. Object ids are
            relabeled sequentially. If `overlap_label` is set (and non-zero),
            overlap regions keep exactly that value.
        """
        from stardist import (
            Rays_GoldenSpiral,
            non_maximum_suppression_3d,  # pyright: ignore[reportUnknownVariableType]
            polyhedron_to_label,  # pyright: ignore[reportUnknownVariableType]
        )
        from stardist.matching import (
            relabel_sequential,  # pyright: ignore[reportUnknownVariableType]
        )

        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=self.anisotropy)

        points, probi, disti = non_maximum_suppression_3d(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            rays,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        rendered = polyhedron_to_label(  # pyright: ignore[reportUnknownVariableType]
            disti,
            points,
            rays=rays,
            prob=probi,
            shape=spatial_shape,
            overlap_label=self.overlap_label,
            verbose=False,  # keep library code quiet (no stdout warnings)
        )
        # Normalize the dtype instead of asserting on it. The declared return
        # type is int32, and a signed dtype is needed for a negative `overlap_label`.
        labels = np.asarray(rendered).astype(np.int32, copy=False)

        # `relabel_sequential` rejects negative values and would renumber a
        # positive `overlap_label` like any object id. Exclude the overlap
        # pixels from relabeling and restore them afterwards.
        overlap_mask: Optional[NDArray[np.bool_]] = None
        if self.overlap_label is not None and self.overlap_label != 0:
            overlap_mask = labels == self.overlap_label
            labels[overlap_mask] = 0

        relabeled, _, _ = relabel_sequential(labels)  # pyright: ignore[reportUnknownVariableType]
        labels = np.asarray(relabeled).astype(np.int32, copy=False)

        if overlap_mask is not None:
            labels[overlap_mask] = self.overlap_label
        return labels

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 3D operator from a v0.5 StarDist post-processing description.

        The same `member_id` is used as input and output. The
        probability/distance member is therefore overwritten by the instance
        labels.

        Raises:
            TypeError: if `descr.kwargs` is not `StardistPostprocessingKwargs3D`.
        """
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs3D for 3D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            n_rays=kwargs.n_rays,
            anisotropy=kwargs.anisotropy,
            b=kwargs.b,
            overlap_label=kwargs.overlap_label,
        )