Coverage for src / bioimageio / core / proc_ops.py: 71%

463 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1import collections.abc 

2from abc import ABC, abstractmethod 

3from dataclasses import InitVar, dataclass, field 

4from functools import partial 

5from typing import ( 

6 Any, 

7 Collection, 

8 Generic, 

9 List, 

10 Literal, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Set, 

15 Tuple, 

16 Union, 

17) 

18 

19import numpy as np 

20import scipy 

21import xarray as xr 

22from numpy.typing import NDArray 

23from typing_extensions import Self, TypeVar, assert_never, cast 

24 

25from bioimageio.core.digest_spec import get_member_id 

26from bioimageio.spec.model import v0_4, v0_5 

27from bioimageio.spec.model.v0_5 import ( 

28 _convert_proc, # pyright: ignore[reportPrivateUsage] 

29) 

30 

31from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator 

32from .axis import AxisId 

33from .common import DTypeStr, MemberId 

34from .sample import Sample, SampleBlock 

35from .stat_calculators import StatsCalculator 

36from .stat_measures import ( 

37 DatasetMean, 

38 DatasetMeasure, 

39 DatasetQuantile, 

40 DatasetStd, 

41 MeanMeasure, 

42 Measure, 

43 MeasureValue, 

44 SampleMean, 

45 SampleQuantile, 

46 SampleStd, 

47 Stat, 

48 StdMeasure, 

49) 

50from .tensor import Tensor 

51 

52 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate v0_4 axes (string or sequence) into a tuple of `AxisId`s.

    In "per_dataset" mode the batch axis id is prepended to the converted
    string axes; a non-string `axes` is returned as-is (as a tuple).
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        prefix: List[AxisId] = []
    elif mode == "per_dataset":
        prefix = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(prefix + [AxisId(a) for a in axes])

69 

70 

@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Inject precomputed dataset statistics into each processed sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]
    # fixed dataset measure values added on every call

    @property
    def required_measures(self) -> Collection[Measure]:
        """Nothing needs to be computed; the values are already known."""
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value

81 

82 

83# @dataclass 

84# class UpdateStats(Operator): 

85# """Calculates sample and/or dataset measures""" 

86 

87# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

88# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed 

89# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

90# """ 

91# keep_updating_dataset_stats: Optional[bool] = None 

92# """indicates if operator calls should keep updating dataset statistics or not 

93 

94# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

95# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

96# initial measure values) dataset statistics are updated by each processed sample. 

97# """ 

98# _keep_updating_dataset_stats: bool = field(init=False) 

99# _stats_calculator: StatsCalculator = field(init=False) 

100 

101# @property 

102# def required_measures(self) -> Collection[Measure]: 

103# return set() 

104 

105# def __post_init__(self): 

106# self._stats_calculator = StatsCalculator(self.measures) 

107# if self.keep_updating_dataset_stats is None: 

108# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

109# else: 

110# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

111 

112# def __call__(self, sample_block: SampleBlockWithOrigin) -> None:

113# if self._keep_updating_dataset_stats: 

114# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

115# else: 

116# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

117 

118 

@dataclass
class UpdateStats(SamplewiseOperator):
    """Compute sample and/or dataset measures and attach them to the sample."""

    stats_calculator: StatsCalculator
    """`StatsCalculator` used to compute the measures."""
    keep_updating_initial_dataset_stats: bool = False
    """whether dataset statistics provided initially should keep being refined;
    if the `stats_calculator` holds no initial dataset statistics, they are
    always updated with every new sample regardless of this flag.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __post_init__(self):
        # update dataset stats unless fixed initial values exist and updates
        # were not explicitly requested
        if self.keep_updating_initial_dataset_stats:
            self._keep_updating_dataset_stats = True
        else:
            self._keep_updating_dataset_stats = (
                not self.stats_calculator.has_dataset_measures
            )

    def __call__(self, sample: Sample) -> None:
        calc = self.stats_calculator
        if self._keep_updating_dataset_stats:
            computed = calc.update_and_get_all(sample)
        else:
            computed = calc.skip_update_and_get_all(sample)
        sample.stat.update(computed)

147 

148 

@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    @property
    def required_measures(self) -> Collection[Measure]:
        # thresholding needs no computed statistics
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise comparison against the configured threshold(s)
        return x > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create a `Binarize` acting in place on tensor `member_id`."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        assert_never(kwargs)

185 

186 

@dataclass
class Clip(SimpleOperator):
    """Clip tensor values to a fixed and/or quantile-derived [min, max] range."""

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        # at least one bound must be given
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        if (
            isinstance(self.min, float)
            and isinstance(self.max, float)
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        # quantile bounds must share their axes and be properly ordered
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        # only quantile bounds require statistics to be computed beforehand
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Clip `x`; quantile bounds are resolved from `stat` first."""
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        # any remaining scalar bounds are applied in a single clip call
        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create a `Clip` from a (pre/post)processing description."""
        if isinstance(descr, v0_5.ClipDescr):
            # v0_5 allows fixed values and/or percentile-based bounds
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            # a fixed min takes precedence over a min percentile
            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            # a fixed max takes precedence over a max percentile
            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            # v0_4 only supports fixed min/max values
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )

309 

310 

@dataclass
class EnsureDtype(SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @property
    def required_measures(self) -> Collection[Measure]:
        # casting needs no statistics
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # dtype conversion does not affect the shape
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create an `EnsureDtype` acting in place on tensor `member_id`."""
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        """Serialize back to a v0_5 `EnsureDtypeDescr`."""
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

333 

334 

@dataclass
class ScaleLinear(SimpleOperator):
    """Affine transformation: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        # gain/offset are fixed; no statistics needed
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleLinear` from a (pre/post)processing description.

        Raises:
            NotImplementedError: for v0_4 kwargs with `axes` given.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis gain/offset: wrap as 1d DataArrays along `axis`
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            # without an axis, gain/offset must be scalar (or length-1 sequences)
            assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
                kwargs.gain
            )
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

393 

394 

@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Scale the tensor to match the mean and variance of a reference tensor."""

    axes: Optional[Sequence[AxisId]] = None
    # axes to compute statistics over; None means across all axes
    reference_tensor: Optional[MemberId] = None
    # tensor providing the target mean/std; defaults to the input tensor
    eps: float = 1e-6
    # numerical stability term added to both standard deviations
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        # without a batch axis the statistics are per-sample, otherwise dataset-wide
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # standardize x, then rescale to the reference tensor's statistics
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleMeanVariance` from a (pre/post)processing description."""
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )

452 

453 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract `(dataset_mode, axes)` from processing kwargs.

    `dataset_mode` indicates that statistics should be aggregated over the
    whole dataset (axes unspecified, or the batch axis included).
    """
    if kwargs.axes is None:
        # no axes given: aggregate over everything, including the batch axis
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0_4-style axes string; NOTE(review): `kwargs["mode"]` assumes the
        # kwargs object supports item access and carries a "mode" entry --
        # only the v0_4 variants should reach this branch; confirm.
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        # NOTE(review): membership is tested for AxisId("b"), while
        # `_convert_axis_ids` prepends `v0_5.BATCH_AXIS_ID` in per_dataset
        # mode -- verify these two ids denote the same batch axis.
        return AxisId("b") in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        # v0_5-style sequence of axis ids; batch axis is spelled "batch" here
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)

475 

476 

@dataclass
class ScaleRange(SimpleOperator):
    """Normalize to a quantile range: `(input - lower) / (upper - lower + eps)`."""

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        # default a missing bound to the extreme quantile (q=0.0 / q=1.0)
        # on the same tensor as the given bound (or the input tensor)
        if lower_quantile is None:
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        # both bounds must refer to the same tensor/axes and be ordered
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create a `ScaleRange` from a (pre/post)processing description."""
        kwargs = descr.kwargs
        # quantiles may be computed on a different (reference) tensor
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Serialize back to a v0_5 `ScaleRangeDescr`."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

566 

567 

@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: previously returned `{}`, which is an empty *dict*; return an
        # empty set for consistency with the other operators and the
        # set-based Collection[Measure] convention used throughout this module
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create a `Sigmoid` acting in place on tensor `member_id`."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Serialize back to a v0_5 `SigmoidDescr`."""
        return v0_5.SigmoidDescr()

593 

594 

@dataclass
class Softmax(SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # resolve the positional index of the softmax axis, apply scipy's
        # softmax, and wrap the result back into a Tensor
        pos = x.dims.index(self.axis)
        normalized = scipy.special.softmax(x.data, axis=pos)
        return Tensor.from_xarray(xr.DataArray(normalized, dims=x.dims))

    @property
    def required_measures(self) -> Collection[Measure]:
        # no statistics required
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # softmax is shape-preserving
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        """Create a `Softmax` acting in place on tensor `member_id`."""
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        """Serialize back to a v0_5 `SoftmaxDescr`."""
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))

623 

624 

# TypeVars binding the spatial dimensionality (2D vs 3D) of the stardist
# post-processing operators below: the spatial shape/grid tuple ...
NdTuple = TypeVar("NdTuple", Tuple[int, int], Tuple[int, int, int])
# ... and the matching per-axis (int, int) border specification
NdBorder = TypeVar(
    "NdBorder",
    Tuple[Tuple[int, int], Tuple[int, int]],
    Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
)

631 

632 

@dataclass
class _StardistPostprocessingBase(SamplewiseOperator, Generic[NdTuple, NdBorder], ABC):
    """Shared stardist post-processing: convert probability/distance
    predictions into an instance label image, one batch element at a time."""

    prob_dist_input_id: MemberId
    # sample member holding the stacked probability/distance predictions
    instance_labels_output_id: MemberId
    # sample member to write the resulting instance labels to

    grid: NdTuple
    """Grid size of network predictions."""

    prob_threshold: float
    """Object probability threshold for non-maximum suppression."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    b: Union[int, NdBorder]
    """Border region in which object probability is set to zero."""

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __call__(self, sample: Sample) -> None:
        prob_dist = sample.members[self.prob_dist_input_id]

        assert AxisId("channel") in prob_dist.dims, (
            "expected 'channel' axis in stardist probability/distance input"
        )
        # spatial axes expected for the configured dimensionality (2D or 3D)
        allowed_spatial = tuple(
            map(AxisId, ("y", "x") if len(self.grid) == 2 else ("z", "y", "x"))
        )
        assert all(
            a in allowed_spatial or a in (AxisId("batch"), AxisId("channel"))
            for a in prob_dist.dims
        ), (
            f"expected prob_dist to have only 'batch', 'channel', and spatial axes {allowed_spatial}, but got {prob_dist.dims}"
        )

        # predictions live on a subsampled grid; scale back to full resolution
        spatial_shape = tuple(
            prob_dist.tagged_shape[a] * g for a, g in zip(allowed_spatial, self.grid)
        )
        if len(spatial_shape) != len(self.grid):
            raise ValueError(
                f"expected {len(self.grid)} spatial dimensions in prob_dist tensor, but got {len(spatial_shape)}"
            )
        else:
            spatial_shape = cast(NdTuple, spatial_shape)

        # fix the axis order to (batch, *spatial, channel) before indexing
        prob_dist = prob_dist.transpose(
            (AxisId("batch"), *allowed_spatial, AxisId("channel"))
        )
        labels: List[NDArray[Any]] = []
        for batch_idx in range(prob_dist.sizes[AxisId("batch")]):
            # channel 0 holds the object probability, the remaining channels
            # the per-ray distances
            prob = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): 0}
            ].to_numpy()
            dist = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): slice(1, None)}
            ].to_numpy()

            labels_i = self._impl(prob, dist, spatial_shape)
            assert labels_i.shape == spatial_shape, (
                f"expected label image shape {spatial_shape}, but got {labels_i.shape}"
            )
            labels.append(labels_i)

        # restack into (batch, *spatial, channel) with a singleton channel axis
        instance_labels = Tensor(
            np.stack(labels)[..., None],
            dims=(AxisId("batch"), *allowed_spatial, AxisId("channel")),
        )
        sample.members[self.instance_labels_output_id] = instance_labels

    @abstractmethod
    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: NdTuple
    ) -> NDArray[np.int32]:
        """Compute the instance label image for a single batch element."""
        raise NotImplementedError

709 

710 

@dataclass
class StardistPostprocessing2D(
    _StardistPostprocessingBase[
        Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]
    ]
):
    """2D stardist post-processing (star-convex polygons)."""

    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: Tuple[int, int]
    ) -> NDArray[np.int32]:
        # imported lazily: stardist is an optional dependency
        from stardist import (
            non_maximum_suppression,  # pyright: ignore[reportUnknownVariableType]
            polygons_to_label,  # pyright: ignore[reportUnknownVariableType]
        )

        points, probi, disti = non_maximum_suppression(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        # render the surviving polygons into a full-resolution label image
        return polygons_to_label(disti, points, prob=probi, shape=spatial_shape)

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Create from a v0_5 description; requires 2D kwargs.

        Raises:
            TypeError: if `descr.kwargs` is not the 2D kwargs variant.
        """
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs2D for 2D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            b=kwargs.b,
        )

754 

755 

@dataclass
class StardistPostprocessing3D(
    _StardistPostprocessingBase[
        Tuple[int, int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
    ]
):
    """3D stardist post-processing (star-convex polyhedra)."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

    def _impl(
        self,
        prob: NDArray[Any],
        dist: NDArray[Any],
        spatial_shape: Tuple[int, int, int],
    ) -> NDArray[np.int32]:
        # imported lazily: stardist is an optional dependency
        from stardist import (
            Rays_GoldenSpiral,
            non_maximum_suppression_3d,  # pyright: ignore[reportUnknownVariableType]
            polyhedron_to_label,  # pyright: ignore[reportUnknownVariableType]
        )
        from stardist.matching import (
            relabel_sequential,  # pyright: ignore[reportUnknownVariableType]
        )

        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=self.anisotropy)

        points, probi, disti = non_maximum_suppression_3d(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            rays,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        # render the surviving polyhedra into a full-resolution label image
        labels = polyhedron_to_label(  # pyright: ignore[reportUnknownVariableType]
            disti,
            points,
            rays=rays,
            prob=probi,
            shape=spatial_shape,
            overlap_label=self.overlap_label,
        )

        # relabel objects sequentially (the forward/inverse maps also
        # returned by relabel_sequential are discarded here)
        labels, _, _ = relabel_sequential(labels)
        assert isinstance(labels, np.ndarray) and labels.dtype == np.int32
        return labels

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Create from a v0_5 description; requires 3D kwargs.

        Raises:
            TypeError: if `descr.kwargs` is not the 3D kwargs variant.
        """
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs3D for 3D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            n_rays=kwargs.n_rays,
            anisotropy=kwargs.anisotropy,
            b=kwargs.b,
            overlap_label=kwargs.overlap_label,
        )

832 

833 

@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    # mean measure to subtract (sample- or dataset-wide)
    std: StdMeasure
    # std measure to divide by (sample- or dataset-wide)

    eps: float = 1e-6
    # numerical stability term added to the standard deviation

    def __post_init__(self):
        # mean and std must be computed over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create a `ZeroMeanUnitVariance` from a (pre/post)processing description."""
        dataset_mode, axes = _get_axes(descr.kwargs)

        # dataset-wide vs per-sample statistics, depending on the axes given
        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        """Serialize back to a v0_5 `ZeroMeanUnitVarianceDescr`."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

886 

887 

@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    # precomputed mean (scalar or per-axis)
    std: Union[float, xr.DataArray]
    # precomputed standard deviation (scalar or per-axis)

    eps: float = 1e-6
    # numerical stability term added to the standard deviation

    def __post_init__(self):
        # if both mean and std are arrays they must share their dimensions
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        # values are precomputed; no statistics needed
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # elementwise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create from a v0_5 description (scalar or along-axis variant)."""
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Serialize back to a v0_5 description (scalar or along-axis)."""
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            # the along-axis variant requires exactly one dimension
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)

950 

951 

# any (pre/post)processing description of the supported model spec versions
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]


# union of all operator implementations `get_proc` may return
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]

976 

977 

def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the `Processing` operator described by `proc_descr`.

    Args:
        proc_descr: (pre/post)processing description from the model spec.
        tensor_descr: description of the tensor the processing applies to;
            used to derive the member id (and, for v0_4 "fixed" mode, the axes).

    Raises:
        TypeError: if a v0_4 fixed-mode processing description is paired with
            a non-v0_4 tensor description.
        ValueError: for stardist post-processing kwargs of unknown dimensionality.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    # NOTE: this v0_4 "fixed" mode check must stay BEFORE the generic
    # ZeroMeanUnitVariance branch below, which would otherwise match first
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        # convert the v0_4 fixed-mode description to its v0_5 equivalent
        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        # dispatch on kwargs variant: 2D vs 3D
        if isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id)
        elif isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id)
        else:
            raise ValueError(
                f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
            )
    else:
        assert_never(proc_descr)