Coverage for src / bioimageio / core / proc_ops.py: 79%

416 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import collections.abc 

2from dataclasses import InitVar, dataclass, field 

3from functools import partial 

4from typing import ( 

5 Any, 

6 Callable, 

7 Collection, 

8 Literal, 

9 Mapping, 

10 Optional, 

11 Sequence, 

12 Set, 

13 Tuple, 

14 Union, 

15) 

16 

17import numpy as np 

18import scipy 

19import xarray as xr 

20from numpy.typing import NDArray 

21from typing_extensions import Self, assert_never 

22 

23from bioimageio.spec.model import v0_4, v0_5 

24from bioimageio.spec.model.v0_5 import ( 

25 _convert_proc, # pyright: ignore[reportPrivateUsage] 

26) 

27 

28from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator 

29from ._ops_cellpose import CellposeFlowDynamics 

30from ._ops_stardist import StardistPostprocessing2D as StardistPostprocessing2D 

31from ._ops_stardist import StardistPostprocessing3D as StardistPostprocessing3D 

32from .axis import AxisId 

33from .common import DTypeStr, MemberId 

34from .digest_spec import get_member_id, import_callable 

35from .sample import Sample, SampleBlock 

36from .stat_calculators import StatsCalculator 

37from .stat_measures import ( 

38 DatasetMean, 

39 DatasetQuantile, 

40 DatasetStd, 

41 MeanMeasure, 

42 Measure, 

43 MeasureValue, 

44 SampleMean, 

45 SampleQuantile, 

46 SampleStd, 

47 Stat, 

48 StdMeasure, 

49) 

50from .tensor import Tensor 

51 

52 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate v0.4-style reduction axes into a tuple of `AxisId`s.

    A string of axis letters is split into one `AxisId` per character.
    In "per_dataset" mode `v0_5.BATCH_AXIS_ID` is prepended, so the
    reduction also spans the batch axis.
    Non-string input is returned as a tuple.
    """
    if isinstance(axes, str):
        if mode == "per_sample":
            leading: Tuple[AxisId, ...] = ()
        elif mode == "per_dataset":
            leading = (v0_5.BATCH_AXIS_ID,)
        else:
            assert_never(mode)

        return leading + tuple(AxisId(letter) for letter in axes)

    return tuple(axes)

69 

70 

@dataclass
class AddKnownDatasetStats(BlockwiseOperator):
    """Add precomputed dataset statistics to every processed sample (block).

    Only entries of `dataset_stats` whose measure has `scope == "dataset"`
    are kept; all other entries are discarded at construction time.
    """

    dataset_stats: Mapping[Measure, MeasureValue]

    def __post_init__(self):
        # drop every measure that is not scoped to the whole dataset
        kept = {}
        for measure, value in self.dataset_stats.items():
            if measure.scope == "dataset":
                kept[measure] = value
        self.dataset_stats = kept

    @property
    def required_measures(self) -> Collection[Measure]:
        # this operator only provides statistics; it depends on none
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        stats_to_add = self.dataset_stats
        sample.stat.update(stats_to_add)

87 

88 

89# @dataclass 

90# class UpdateStats(Operator): 

91# """Calculates sample and/or dataset measures""" 

92 

93# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

94# """sample and dataset `measures` to be calculated by this operator. Initial/fixed 

95# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

96# """ 

97# keep_updating_dataset_stats: Optional[bool] = None 

98# """indicates if operator calls should keep updating dataset statistics or not 

99 

100# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

101# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

102# initial measure values) dataset statistics are updated by each processed sample. 

103# """ 

104# _keep_updating_dataset_stats: bool = field(init=False) 

105# _stats_calculator: StatsCalculator = field(init=False) 

106 

107# @property 

108# def required_measures(self) -> Collection[Measure]: 

109# return set() 

110 

111# def __post_init__(self): 

112# self._stats_calculator = StatsCalculator(self.measures) 

113# if self.keep_updating_dataset_stats is None: 

114# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

115# else: 

116# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

117 

118# def __call__(self, sample_block: SampleBlockWithOrigin> None: 

119# if self._keep_updating_dataset_stats: 

120# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

121# else: 

122# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

123 

124 

@dataclass
class UpdateStats(SamplewiseOperator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def __post_init__(self):
        keep_updating = self.keep_updating_initial_dataset_stats
        if not keep_updating:
            # Without initial dataset stats there is nothing to preserve,
            # so every sample keeps contributing to the dataset statistics.
            keep_updating = not self.stats_calculator.has_dataset_measures
        self._keep_updating_dataset_stats = keep_updating

    def __call__(self, sample: Sample) -> None:
        calculator = self.stats_calculator
        compute_stats = (
            calculator.update_and_get_all
            if self._keep_updating_dataset_stats
            else calculator.skip_update_and_get_all
        )
        sample.stat.update(compute_stats(sample))

153 

154 

@dataclass
class Binarize(SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is set, `threshold` holds one value per entry along that axis,
    and every slice along `axis` is compared against its own threshold.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        if self.axis is None:
            threshold = self.threshold
        else:
            # Label the per-axis thresholds with their dimension name so they
            # broadcast along `axis` (as `ScaleLinear` does for gain/offset).
            # Without this, numpy-style positional broadcasting would pair them
            # with the trailing dimension instead.
            threshold = xr.DataArray(
                np.atleast_1d(self.threshold), dims=(self.axis,)
            )
        return x > threshold

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create a `Binarize` operator acting in place on `member_id`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)

191 

192 

@dataclass
class Clip(SimpleOperator):
    """Clip tensor values to the range [`min`, `max`].

    Each bound may be a fixed number or a quantile measure
    (`SampleQuantile`/`DatasetQuantile`) that is looked up in the sample's
    statistics when the operator is applied. At least one bound is required.
    """

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # Integers are valid numeric bounds too, so they must not bypass the
        # ordering check (previously only `float` was recognised here).
        if (
            isinstance(self.min, (int, float))
            and isinstance(self.max, (int, float))
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        # only call clip if at least one scalar bound remains
        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create a `Clip` operator acting in place on `member_id`.

        For v0.5 descriptions, a missing absolute bound falls back to the
        corresponding percentile, converted to a quantile measure.
        """
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )

315 

316 

@dataclass
class EnsureDtype(SimpleOperator):
    """Cast the tensor to `dtype`; the shape stays the same."""

    dtype: DTypeStr

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        target_dtype = self.dtype
        return x.astype(target_dtype)

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        requested_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=requested_dtype)

    def get_descr(self):
        descr_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=descr_kwargs)

339 

340 

@dataclass
class ScaleLinear(SimpleOperator):
    """Affine transform: `output = tensor * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        scaled = x * self.gain
        return scaled + self.offset

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleLinear` operator acting in place on `member_id`.

        With an axis, gain and offset become 1D `xr.DataArray`s labelled with
        that axis; otherwise they must be scalars or length-1 sequences.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if not axis:
            raw_gain = kwargs.gain
            assert isinstance(raw_gain, (float, int)) or len(raw_gain) == 1, raw_gain
            gain = raw_gain if isinstance(raw_gain, (float, int)) else raw_gain[0]

            raw_offset = kwargs.offset
            assert isinstance(raw_offset, (float, int)) or len(raw_offset) == 1
            offset = (
                raw_offset if isinstance(raw_offset, (float, int)) else raw_offset[0]
            )
        else:
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

399 

400 

@dataclass
class ScaleMeanVariance(SimpleOperator):
    """Match the mean/std of the input to those of a reference tensor.

    `output = (x - mean) / (std + eps) * (ref_std + eps) + ref_mean`, where
    the reference defaults to the input tensor itself.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input

        # Dataset-wide statistics only when the batch axis is explicitly
        # reduced. NOTE(review): `axes is None` selects per-sample measures
        # here, whereas `_get_axes` reports dataset mode for `axes is None`;
        # confirm which is intended.
        reduces_batch = axes is not None and AxisId("batch") in axes
        if reduces_batch:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean, std = stat[self.mean], stat[self.std] + self.eps
        ref_mean, ref_std = stat[self.ref_mean], stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleMeanVariance` operator acting in place on `member_id`."""
        kwargs = descr.kwargs
        # dataset mode is re-derived from `axes` in `__post_init__`
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )

458 

459 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes from normalization-style kwargs.

    Returns:
        `(dataset_mode, axes)`. `dataset_mode` is True if `kwargs.axes` is
        None or if the resolved axes contain the batch axis.
        `axes` is the resolved tuple of `AxisId`s, or None if `kwargs.axes`
        is None.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4 axes string: `_convert_axis_ids` prepends `v0_5.BATCH_AXIS_ID`
        # in "per_dataset" mode, so that is the id to look for below (the
        # previous check for `AxisId("b")` could never match).
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
    else:
        assert_never(kwargs.axes)

    return v0_5.BATCH_AXIS_ID in axes, axes

481 

482 

@dataclass
class ScaleRange(SimpleOperator):
    """Rescale with quantiles: `(x - lower) / (upper - lower + eps)`.

    A missing lower/upper quantile defaults to the `DatasetQuantile` with
    q=0.0/q=1.0 of the same member, over the same axes as the given bound.
    """

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        if lower_quantile is None:
            if upper_quantile is None:
                tid, axes = self.input, None
            else:
                # mirror the given bound so the consistency checks below hold
                tid, axes = upper_quantile.member_id, upper_quantile.axes
            self.lower = DatasetQuantile(q=0.0, axes=axes, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            # mirror the lower bound's member and axes (previously the axes
            # were dropped, which made the axes assertion below fail)
            self.upper = DatasetQuantile(
                q=1.0, axes=self.lower.axes, member_id=self.lower.member_id
            )
        else:
            self.upper = upper_quantile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create a `ScaleRange` operator acting in place on `member_id`.

        Quantiles are computed on `reference_tensor` if given, otherwise on
        `member_id`; `eps` is taken from the description.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            # forward eps (was silently replaced by the default before;
            # `get_descr` round-trips it and `ScaleMeanVariance` passes it too)
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

572 

573 

@dataclass
class Sigmoid(SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # `set()`, not `{}`: the latter is an empty dict, which breaks set
        # operations (e.g. `set() | {}` raises TypeError) and is inconsistent
        # with every other operator in this module.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create a `Sigmoid` operator acting in place on `member_id`."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()

599 

600 

@dataclass
class Softmax(SimpleOperator):
    """Softmax activation along `axis` (default: the channel axis)."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        dims = x.dims
        probabilities = scipy.special.softmax(x.data, axis=dims.index(self.axis))
        return Tensor.from_xarray(xr.DataArray(probabilities, dims=dims))

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        """Create a `Softmax` operator acting in place on `member_id`."""
        assert isinstance(descr, v0_5.SoftmaxDescr)
        softmax_axis = descr.kwargs.axis
        return cls(input=member_id, output=member_id, axis=softmax_axis)

    def get_descr(self):
        descr_kwargs = v0_5.SoftmaxKwargs(axis=self.axis)
        return v0_5.SoftmaxDescr(kwargs=descr_kwargs)

629 

630 

@dataclass
class ZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance.

    `output = (x - mean) / (std + eps)`, where `mean` and `std` are measures
    looked up in the sample's statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create a `ZeroMeanUnitVariance` operator acting in place on `member_id`.

        Dataset or sample measures are chosen according to `_get_axes`, and
        `eps` is taken from the description.
        """
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # forward eps; it used to be dropped in favour of the default,
            # although `get_descr` writes `self.eps` back out
            eps=descr.kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

683 

684 

@dataclass
class FixedZeroMeanUnitVariance(SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    `output = (x - mean) / (std + eps)`. `mean`/`std` are scalars or
    `xr.DataArray`s; scalar kwargs are stored by `from_proc_descr` as 0-d
    DataArrays.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create a `FixedZeroMeanUnitVariance` operator acting in place on `member_id`."""
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        mean, std = self.mean, self.std
        # `from_proc_descr` wraps scalar kwargs in 0-d DataArrays; treat those
        # as scalars so the round trip works (they used to hit the 1-D assert).
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats: iterating a DataArray yields 0-d DataArrays
                mean=[float(v) for v in mean.values],
                std=[float(v) for v in std.values],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)

747 

748 

@dataclass
class CustomProcessing(SimpleOperator):
    """Run a user-supplied processing callable on each sample.

    The callable is obtained by calling `custom_factory(**kwargs)` once, at
    construction. Two factory styles fit this protocol::

        # Callable class style
        class my_factory:
            def __init__(self, threshold=0.5):
                self.threshold = threshold
            def __call__(self, *arrays):
                return (arrays[0] > self.threshold).astype(np.uint8)

        # Factory function style
        def my_factory(threshold=0.5):
            def run(*arrays):
                return (arrays[0] > threshold).astype(np.uint8)
            return run

    For every sample, the resulting callable receives the tensor as a numpy
    array. Its result is wrapped back into a `Tensor` with the input's dims.

    Note: The custom callable may not change the shape of the input tensor.
    """

    custom_factory: Callable[..., Callable[[NDArray[Any]], NDArray[Any]]]

    kwargs: Mapping[str, Any]
    """Keyword arguments forwarded to the custom factory."""

    # set once in __post_init__ from custom_factory(**kwargs)
    _custom_callable: Any = field(init=False, repr=False)

    def __post_init__(self) -> None:
        factory_kwargs = self.kwargs
        self._custom_callable = self.custom_factory(**factory_kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        array_in = x.to_numpy()
        array_out = self._custom_callable(array_in)
        return Tensor.from_numpy(array_out, dims=x.dims)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.CustomProcessingDescr,
        member_id: MemberId,
    ) -> Self:
        """Import the factory named by `descr` and wrap it, acting on `member_id`."""
        return cls(
            input=member_id,
            output=member_id,
            custom_factory=import_callable(descr),
            kwargs=dict(descr.kwargs),
        )

811 

812 

# All pre-/postprocessing description types (spec v0.4 and v0.5) that
# `get_proc` accepts as its `proc_descr` argument.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]


# All operator types defined or imported by this module. `get_proc` returns
# one of these; `AddKnownDatasetStats` and `UpdateStats` are not produced by
# any `get_proc` branch and are constructed directly by callers.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    CellposeFlowDynamics,
    CustomProcessing,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    StardistPostprocessing2D,
    StardistPostprocessing3D,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]

839 

840 

def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Build the operator implementing `proc_descr` for the tensor `tensor_descr`.

    The returned operator reads from and writes to the member id derived from
    `tensor_descr`.

    Raises:
        TypeError: a v0.4 "fixed" zero-mean-unit-variance description is
            combined with a non-v0.4 tensor description.
        ValueError: stardist postprocessing kwargs are neither the 2D nor the
            3D variant.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.CellposeFlowDynamicsDescr):
        return CellposeFlowDynamics.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.CustomProcessingDescr):
        return CustomProcessing.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    # v0.4 "fixed" mode has no operator of its own: it is converted to the
    # equivalent v0.5 description. This must stay ahead of the generic
    # zero-mean-unit-variance branch at the end, which would otherwise match.
    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        is_v0_4_tensor = isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        )
        if not is_v0_4_tensor:
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.StardistPostprocessingDescr):
        stardist_kwargs = proc_descr.kwargs
        if isinstance(stardist_kwargs, v0_5.StardistPostprocessingKwargs2D):
            return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id)
        if isinstance(stardist_kwargs, v0_5.StardistPostprocessingKwargs3D):
            return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id)
        raise ValueError(
            f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}"
        )

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)