Coverage for src / bioimageio / core / proc_ops.py: 73%

508 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1import collections.abc 

2from abc import ABC, abstractmethod 

3from dataclasses import InitVar, dataclass, field 

4from functools import partial 

5from typing import ( 

6 Any, 

7 Collection, 

8 Generic, 

9 List, 

10 Literal, 

11 Mapping, 

12 Optional, 

13 Sequence, 

14 Set, 

15 Tuple, 

16 Union, 

17) 

18 

19import numpy as np 

20import scipy 

21import xarray as xr 

22from numpy.typing import NDArray 

23from typing_extensions import Self, TypeVar, assert_never, cast 

24 

25from bioimageio.core.digest_spec import get_member_id 

26from bioimageio.spec.model import v0_4, v0_5 

27from bioimageio.spec.model.v0_5 import ( 

28 _convert_proc, # pyright: ignore[reportPrivateUsage] 

29) 

30 

31from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator 

32from .axis import AxisId 

33from .common import DTypeStr, MemberId 

34from .sample import Sample, SampleBlock 

35from .stat_calculators import StatsCalculator 

36from .stat_measures import ( 

37 DatasetMean, 

38 DatasetMeasure, 

39 DatasetQuantile, 

40 DatasetStd, 

41 MeanMeasure, 

42 Measure, 

43 MeasureValue, 

44 SampleMean, 

45 SampleQuantile, 

46 SampleStd, 

47 Stat, 

48 StdMeasure, 

49) 

50from .tensor import Tensor 

51 

52 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert v0.4 axis letters to a tuple of `AxisId`s.

    Non-string inputs are assumed to already hold axis ids and are returned
    unchanged as a tuple.  For letter strings, `per_dataset` mode prepends
    the batch axis id before the per-letter conversion.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        prefix: List[AxisId] = []
    elif mode == "per_dataset":
        prefix = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(prefix) + tuple(AxisId(a) for a in axes)

69 

70 

@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Inject precomputed dataset statistics into each processed sample."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Collection[Measure]:
        # nothing needs to be computed; the stats are already known
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # merge the known dataset measures into the sample's stat mapping
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value

81 

82 

83# @dataclass 

84# class UpdateStats(Operator): 

85# """Calculates sample and/or dataset measures""" 

86 

87# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

88# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed 

89# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

90# """ 

91# keep_updating_dataset_stats: Optional[bool] = None 

92# """indicates if operator calls should keep updating dataset statistics or not 

93 

94# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

95# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

96# initial measure values) dataset statistics are updated by each processed sample. 

97# """ 

98# _keep_updating_dataset_stats: bool = field(init=False) 

99# _stats_calculator: StatsCalculator = field(init=False) 

100 

101# @property 

102# def required_measures(self) -> Collection[Measure]: 

103# return set() 

104 

105# def __post_init__(self): 

106# self._stats_calculator = StatsCalculator(self.measures) 

107# if self.keep_updating_dataset_stats is None: 

108# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

109# else: 

110# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

111 

112# def __call__(self, sample_block: SampleBlockWithOrigin> None: 

113# if self._keep_updating_dataset_stats: 

114# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

115# else: 

116# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

117 

118 

@dataclass
class UpdateStats(SamplewiseOperator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __post_init__(self):
        # without initial dataset measures there is nothing fixed to
        # preserve, so updating is forced on in that case
        if self.stats_calculator.has_dataset_measures:
            self._keep_updating_dataset_stats = (
                self.keep_updating_initial_dataset_stats
            )
        else:
            self._keep_updating_dataset_stats = True

    def __call__(self, sample: Sample) -> None:
        calc = self.stats_calculator
        get_stats = (
            calc.update_and_get_all
            if self._keep_updating_dataset_stats
            else calc.skip_update_and_get_all
        )
        sample.stat.update(get_stats(sample))

147 

148 

@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise comparison; broadcasting covers per-axis thresholds
        return x > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # thresholding is elementwise and keeps the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        assert_never(kwargs)

185 

186 

@dataclass
class Clip(SimpleOperator):
    """Clip tensor values to a [min, max] range.

    Bounds may be constants or sample/dataset quantile measures that are
    resolved from the sample statistics when the operator is applied.

    Raises:
        ValueError: if neither bound is given, or constant/quantile bounds
            are not ordered min < max.
        NotImplementedError: if min and max quantiles use different axes.
    """

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # fix: also validate integer bounds; previously only `float`
        # instances were checked, so e.g. min=2, max=1 silently passed
        if (
            isinstance(self.min, (int, float))
            and isinstance(self.max, (int, float))
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        # only quantile bounds need to be computed from statistics
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values; emulate with where()
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values; emulate with where()
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is elementwise and keeps the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Build the operator from a (pre/post)processing description.

        v0_5 descriptions may express bounds as percentiles, which are turned
        into quantile measures; v0_4 only supports constant bounds.
        """
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )

309 

310 

@dataclass
class EnsureDtype(SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # plain cast; values are converted, shape is untouched
        return x.astype(self.dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

333 

334 

@dataclass
class ScaleLinear(SimpleOperator):
    """Affine rescaling of the tensor: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # xarray broadcasting aligns a per-axis gain/offset with x
        return x * self.gain + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Build the operator from a (pre/post)processing description.

        Raises:
            NotImplementedError: for v0_4 kwargs with an `axes` restriction.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis gain/offset: wrap as 1d DataArrays along that axis
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            # without an axis only a scalar (or length-1) gain/offset is valid
            assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
                kwargs.gain
            )
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

393 

394 

@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Match the input's mean/variance to those of a reference tensor."""

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref = self.reference_tensor or self.input
        # statistics over the batch axis require dataset-level measures
        if axes is not None and AxisId("batch") in axes:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref, axes=axes)
        self.ref_std = Std(member_id=ref, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        std = stat[self.std] + self.eps
        ref_std = stat[self.ref_std] + self.eps
        # standardize, then re-scale/shift to the reference statistics
        return (x - stat[self.mean]) / std * ref_std + stat[self.ref_mean]

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        _, axes = _get_axes(descr.kwargs)
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(descr.kwargs.reference_tensor)),
            axes=axes,
            eps=descr.kwargs.eps,
        )

452 

453 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract `(dataset_mode, axes)` from processing kwargs.

    Returns:
        dataset_mode: True if statistics span the whole dataset, i.e. the
            batch axis is included (or no axes are restricted at all).
        axes: axes to compute statistics over, or None for all axes.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0_4 kwargs give axes as a letter string; the "mode" kwarg is
        # retrieved by item access -- assumes v0_4 kwargs nodes support
        # __getitem__ (TODO confirm against bioimageio.spec)
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        # fix: _convert_axis_ids inserts v0_5.BATCH_AXIS_ID in "per_dataset"
        # mode, and v0_4 letter axes (czyx) never contain it -- the previous
        # `AxisId("b") in axes` check could never be True, so dataset mode
        # was always mis-detected as sample mode here
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)

475 

476 

@dataclass
class ScaleRange(SimpleOperator):
    """Normalize the tensor based on lower/upper quantiles:
    `output = (input - lower) / (upper - lower + eps)`.
    """

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        # default missing bounds to the extreme dataset quantiles (q=0/q=1),
        # tied to the same tensor as the bound that was given
        if lower_quantile is None:
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        # NOTE(review): internal invariants only; `assert` is stripped under -O
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        # quantiles may refer to a different (reference) tensor's statistics
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        # report percentiles (q * 100) to round-trip from_proc_descr
        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

566 

567 

@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # elementwise logistic; the input's dims are preserved
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: `{}` is an empty *dict*, not a set; return an empty set like
        # every other operator in this module
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # the activation is elementwise and keeps the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()

593 

594 

@dataclass
class Softmax(SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # normalize along the configured axis, preserving dim order
        normalized = scipy.special.softmax(x.data, axis=x.dims.index(self.axis))
        return Tensor.from_xarray(xr.DataArray(normalized, dims=x.dims))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))

623 

624 

# 2D or 3D grid/shape tuple used by the stardist post-processing operators
NdTuple = TypeVar("NdTuple", Tuple[int, int], Tuple[int, int, int])
# matching per-axis (begin, end) border specification for 2D or 3D
NdBorder = TypeVar(
    "NdBorder",
    Tuple[Tuple[int, int], Tuple[int, int]],
    Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
)

631 

632 

@dataclass
class _StardistPostprocessingBase(SamplewiseOperator, Generic[NdTuple, NdBorder], ABC):
    """Shared scaffolding for 2D/3D stardist post-processing.

    Splits the network's probability/distance output per batch item, runs the
    dimensionality-specific `_impl` (non-maximum suppression + label image
    rendering), and stores the resulting instance labels on the sample.
    """

    prob_dist_input_id: MemberId
    instance_labels_output_id: MemberId

    grid: NdTuple
    """Grid size of network predictions."""

    prob_threshold: float
    """Object probability threshold for non-maximum suppression."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    b: Union[int, NdBorder]
    """Border region in which object probability is set to zero."""

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __call__(self, sample: Sample) -> None:
        prob_dist = sample.members[self.prob_dist_input_id]

        assert AxisId("channel") in prob_dist.dims, (
            "expected 'channel' axis in stardist probability/distance input"
        )
        # 2D uses (y, x); 3D uses (z, y, x)
        allowed_spatial = tuple(
            map(AxisId, ("y", "x") if len(self.grid) == 2 else ("z", "y", "x"))
        )
        assert all(
            a in allowed_spatial or a in (AxisId("batch"), AxisId("channel"))
            for a in prob_dist.dims
        ), (
            f"expected prob_dist to have only 'batch', 'channel', and spatial axes {allowed_spatial}, but got {prob_dist.dims}"
        )

        # predictions are on a subsampled grid; scale back to the image size.
        # NOTE(review): tagged_shape[a] raises KeyError for a missing spatial
        # axis before the length check below can fire -- confirm intent
        spatial_shape = tuple(
            prob_dist.tagged_shape[a] * g for a, g in zip(allowed_spatial, self.grid)
        )
        if len(spatial_shape) != len(self.grid):
            raise ValueError(
                f"expected {len(self.grid)} spatial dimensions in prob_dist tensor, but got {len(spatial_shape)}"
            )
        else:
            spatial_shape = cast(NdTuple, spatial_shape)

        prob_dist = prob_dist.transpose(
            (AxisId("batch"), *allowed_spatial, AxisId("channel"))
        )
        labels: List[NDArray[Any]] = []
        for batch_idx in range(prob_dist.sizes[AxisId("batch")]):
            # channel 0 holds the object probability ...
            prob = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): 0}
            ].to_numpy()
            # ... the remaining channels hold the distances
            dist = prob_dist[
                {AxisId("batch"): batch_idx, AxisId("channel"): slice(1, None)}
            ].to_numpy()

            labels_i = self._impl(prob, dist, spatial_shape)
            assert labels_i.shape == spatial_shape, (
                f"expected label image shape {spatial_shape}, but got {labels_i.shape}"
            )
            labels.append(labels_i)

        # re-assemble: batch first, singleton channel axis appended last
        instance_labels = Tensor(
            np.stack(labels)[..., None],
            dims=(AxisId("batch"), *allowed_spatial, AxisId("channel")),
        )
        sample.members[self.instance_labels_output_id] = instance_labels

    @abstractmethod
    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: NdTuple
    ) -> NDArray[np.int32]:
        """Dimensionality-specific NMS + label image rendering."""
        raise NotImplementedError

710 

@dataclass
class StardistPostprocessing2D(
    _StardistPostprocessingBase[
        Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]
    ]
):
    """2D stardist post-processing: NMS over star-convex polygons."""

    def _impl(
        self, prob: NDArray[Any], dist: NDArray[Any], spatial_shape: Tuple[int, int]
    ) -> NDArray[np.int32]:
        # stardist is an optional dependency; imported lazily on first use
        from stardist import (
            non_maximum_suppression,  # pyright: ignore[reportUnknownVariableType]
            polygons_to_label,  # pyright: ignore[reportUnknownVariableType]
        )

        points, probi, disti = non_maximum_suppression(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        # render the surviving polygons into an instance label image
        return polygons_to_label(disti, points, prob=probi, shape=spatial_shape)

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 2D operator from its description.

        Raises:
            TypeError: if the description carries non-2D kwargs.
        """
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs2D for 2D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            b=kwargs.b,
        )

755 

@dataclass
class StardistPostprocessing3D(
    _StardistPostprocessingBase[
        Tuple[int, int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]
    ]
):
    """3D stardist post-processing: NMS over star-convex polyhedra."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

    def _impl(
        self,
        prob: NDArray[Any],
        dist: NDArray[Any],
        spatial_shape: Tuple[int, int, int],
    ) -> NDArray[np.int32]:
        # stardist is an optional dependency; imported lazily on first use
        from stardist import (
            Rays_GoldenSpiral,
            non_maximum_suppression_3d,  # pyright: ignore[reportUnknownVariableType]
            polyhedron_to_label,  # pyright: ignore[reportUnknownVariableType]
        )
        from stardist.matching import (
            relabel_sequential,  # pyright: ignore[reportUnknownVariableType]
        )

        # ray directions must match how the network's distances were predicted
        rays = Rays_GoldenSpiral(self.n_rays, anisotropy=self.anisotropy)

        points, probi, disti = non_maximum_suppression_3d(  # pyright: ignore[reportUnknownVariableType]
            dist,
            prob,
            rays,
            grid=self.grid,
            prob_thresh=self.prob_threshold,
            nms_thresh=self.nms_threshold,
            b=self.b,  # pyright: ignore[reportArgumentType]
        )

        labels = polyhedron_to_label(  # pyright: ignore[reportUnknownVariableType]
            disti,
            points,
            rays=rays,
            prob=probi,
            shape=spatial_shape,
            overlap_label=self.overlap_label,
        )

        # compact label ids to 1..n after rendering
        labels, _, _ = relabel_sequential(labels)
        # NOTE(review): assumes relabel_sequential returns int32 — stripped
        # under -O; confirm against the stardist API
        assert isinstance(labels, np.ndarray) and labels.dtype == np.int32
        return labels

    @classmethod
    def from_proc_descr(
        cls, descr: v0_5.StardistPostprocessingDescr, member_id: MemberId
    ) -> Self:
        """Build the 3D operator from its description.

        Raises:
            TypeError: if the description carries non-3D kwargs.
        """
        if not isinstance(descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            raise TypeError(
                f"expected v0_5.StardistPostprocessingKwargs3D for 3D stardist post-processing, but got {type(descr.kwargs)}"
            )

        kwargs = descr.kwargs
        return cls(
            prob_dist_input_id=member_id,
            instance_labels_output_id=member_id,
            grid=kwargs.grid,
            prob_threshold=kwargs.prob_threshold,
            nms_threshold=kwargs.nms_threshold,
            n_rays=kwargs.n_rays,
            anisotropy=kwargs.anisotropy,
            b=kwargs.b,
            overlap_label=kwargs.overlap_label,
        )

832 

833 

@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # standardize with the measured (sample or dataset) statistics
        centered = x - stat[self.mean]
        return centered / (stat[self.std] + self.eps)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)
        # dataset mode switches to measures aggregated over the whole dataset
        Mean = DatasetMean if dataset_mode else SampleMean
        Std = DatasetStd if dataset_mode else SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

887 

@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # per-axis mean/std must agree on their dims
        scalar_mean = isinstance(self.mean, (int, float))
        scalar_std = isinstance(self.std, (int, float))
        assert scalar_mean or scalar_std or self.mean.dims == self.std.dims

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(kwargs.axis),)
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(kwargs.mean, dims=dims),
            std=xr.DataArray(kwargs.std, dims=dims),
        )

    def get_descr(self):
        if isinstance(self.mean, (int, float)):
            # scalar variant
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            # per-axis variant: exactly one axis is expected
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

951 

@dataclass
class CustomPostprocessing(SamplewiseOperator):
    """Execute a user-supplied custom postprocessing callable.

    The callable is loaded from a Python source file packaged with the model.
    The source file's SHA-256 hash is checked when the source is resolved in
    `from_proc_descr` (via `get_reader`), before any code is executed.

    Two styles are supported — callable class and factory function::

        # Callable class style
        class my_postprocess:
            def __init__(self, threshold=0.5):
                self.threshold = threshold
            def __call__(self, *arrays):
                return (arrays[0] > self.threshold).astype(np.uint8)

        # Factory function style
        def my_postprocess(threshold=0.5):
            def run(*arrays):
                return (arrays[0] > threshold).astype(np.uint8)
            return run

    Runtime protocol: ``op = callable(**kwargs)`` once at construction;
    ``result = op(*tensors)`` once per sample.
    """

    output_id: MemberId
    """The model output tensor that will be replaced with the op result."""

    input_ids: Sequence[MemberId]
    """All model output tensor ids, passed to the op in rdf.yaml declaration order."""

    callable_name: str
    """Name of the class or factory function defined in ``source_code``."""

    source_code: bytes
    """Python source code of the op file."""

    kwargs: Mapping[str, Any]
    """Keyword arguments forwarded to the callable."""

    # Initialised in __post_init__
    _op: Any = field(init=False, repr=False)

    def __post_init__(self) -> None:
        # SECURITY NOTE: this executes model-supplied Python code in-process;
        # only load models from trusted sources.
        import importlib.util
        import os
        import sys
        import tempfile

        # Write source to a temp file so importlib can load it properly.
        # delete=False because the file must be closed before importlib can
        # read it (required on Windows); we remove it ourselves below.
        with tempfile.NamedTemporaryFile(
            suffix=".py",
            prefix=f"_bioimageio_custom_{self.callable_name}_",
            delete=False,
        ) as tmp:
            _ = tmp.write(self.source_code)
            tmp_path = tmp.name

        try:
            spec = importlib.util.spec_from_file_location(
                f"_bioimageio_custom_op_{self.callable_name}", tmp_path
            )
            if spec is None or spec.loader is None:
                raise ImportError(
                    f"Cannot create module spec from {tmp_path!r}"
                )
            module = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = module
            spec.loader.exec_module(module)
        finally:
            # fix: the temp file was previously leaked; remove it once the
            # module is loaded (or loading failed)
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

        callable_obj = getattr(module, self.callable_name, None)
        if callable_obj is None:
            raise AttributeError(
                f"No attribute {self.callable_name!r} found in custom op source"
            )
        self._op = callable_obj(**self.kwargs)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __call__(self, sample: Sample) -> None:
        # pass raw numpy arrays in rdf.yaml declaration order; tensors
        # missing from the sample are silently skipped
        arrays: List[NDArray[Any]] = [
            sample.members[mid].data.values
            for mid in self.input_ids
            if mid in sample.members
        ]
        result_array: NDArray[Any] = self._op(*arrays)
        # assumes the op's result has the same dims as the tensor it
        # replaces -- TODO confirm this contract in the spec
        result_xr = xr.DataArray(
            result_array, dims=sample.members[self.output_id].dims
        )
        sample.members[self.output_id] = Tensor.from_xarray(result_xr)

    @classmethod
    def from_proc_descr(
        cls,
        descr: Any,  # v0_5.CustomPostprocessingDescr (guarded for older spec versions)
        tensor_descr: v0_5.OutputTensorDescr,
        all_output_ids: Sequence[MemberId],
    ) -> "CustomPostprocessing":
        from bioimageio.spec._internal.io import get_reader

        output_id = get_member_id(tensor_descr)
        # get_reader verifies the declared sha256 of the source file here
        reader = get_reader(descr.source, sha256=descr.sha256)
        source_code: bytes = reader.read()

        return cls(
            output_id=output_id,
            input_ids=list(all_output_ids),
            callable_name=descr.callable,
            source_code=source_code,
            kwargs=dict(descr.kwargs),
        )

1064 

1065 

# Any pre- or postprocessing description from the v0.4 or v0.5 model spec
# that `get_proc` below knows how to turn into an operator.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

1072 

1073 

# Union of every concrete processing operator implemented in this module;
# this is the return type of `get_proc`. Members are listed roughly
# alphabetically (Softmax appearing after the Stardist entries is cosmetic
# only — Union member order has no semantic effect on dispatch).
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    CustomPostprocessing,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]

1091 

1092 

def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the processing operator for a single pre-/postprocessing descr.

    Dispatches on the concrete descr type via guard clauses. The fixed-mode
    check for v0.4 zero-mean/unit-variance must run before the generic
    zero-mean/unit-variance branch, since the generic branch would otherwise
    swallow it.
    """
    mid = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, mid)

    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, mid)

    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        # fixed-mode v0.4 descrs are converted to the v0.5 fixed variant
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, mid)

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, mid)

    if isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        # 2D vs. 3D is decided by the kwargs variant carried in the descr
        if isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, mid)
        if isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, mid)
        raise ValueError(
            f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
        )

    assert_never(proc_descr)