Coverage for src/bioimageio/core/proc_ops.py: 76%

416 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import collections.abc 

2from dataclasses import InitVar, dataclass, field 

3from functools import partial 

4from typing import ( 

5 Any, 

6 Callable, 

7 Collection, 

8 Literal, 

9 Mapping, 

10 Optional, 

11 Sequence, 

12 Set, 

13 Tuple, 

14 Union, 

15) 

16 

17import numpy as np 

18import scipy 

19import xarray as xr 

20from numpy.typing import NDArray 

21from typing_extensions import Self, assert_never 

22 

23from bioimageio.spec.model import v0_4, v0_5 

24from bioimageio.spec.model.v0_5 import ( 

25 _convert_proc, # pyright: ignore[reportPrivateUsage] 

26) 

27 

28from ._op_base import BlockwiseOperator, SamplewiseOperator, SimpleOperator 

29from ._ops_cellpose import CellposeFlowDynamics 

30from ._ops_stardist import StardistPostprocessing2D as StardistPostprocessing2D 

31from ._ops_stardist import StardistPostprocessing3D as StardistPostprocessing3D 

32from .axis import AxisId 

33from .common import DTypeStr, MemberId 

34from .digest_spec import get_member_id, import_callable 

35from .sample import Sample, SampleBlock 

36from .stat_calculators import StatsCalculator 

37from .stat_measures import ( 

38 DatasetMean, 

39 DatasetQuantile, 

40 DatasetStd, 

41 MeanMeasure, 

42 Measure, 

43 MeasureValue, 

44 SampleMean, 

45 SampleQuantile, 

46 SampleStd, 

47 Stat, 

48 StdMeasure, 

49) 

50from .tensor import Tensor 

51 

52 

53def _convert_axis_ids( 

54 axes: v0_4.AxesInCZYX, 

55 mode: Literal["per_sample", "per_dataset"], 

56) -> Tuple[AxisId, ...]: 

57 if not isinstance(axes, str): 

58 return tuple(axes) 

59 

60 if mode == "per_sample": 

61 ret = [] 

62 elif mode == "per_dataset": 

63 ret = [AxisId(v0_5.BATCH_AXIS_ID)] 

64 else: 

65 assert_never(mode) 

66 

67 ret.extend([AxisId(a) for a in axes]) 

68 return tuple(ret) 

69 

70 

71@dataclass 

72class AddKnownDatasetStats(BlockwiseOperator): 

73 dataset_stats: Mapping[Measure, MeasureValue] 

74 

75 def __post_init__(self): 

76 # keep only dataset measures 

77 self.dataset_stats = { 

78 k: v for k, v in self.dataset_stats.items() if k.scope == "dataset" 

79 } 

80 

81 @property 

82 def required_measures(self) -> Collection[Measure]: 

83 return set() 

84 

85 def __call__(self, sample: Union[Sample, SampleBlock]) -> None: 

86 sample.stat.update(self.dataset_stats) 

87 

88 

89# @dataclass 

90# class UpdateStats(Operator): 

91# """Calculates sample and/or dataset measures""" 

92 

93# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

94# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed 

95# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

96# """ 

97# keep_updating_dataset_stats: Optional[bool] = None 

98# """indicates if operator calls should keep updating dataset statistics or not 

99 

100# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

101# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

102# initial measure values) dataset statistics are updated by each processed sample. 

103# """ 

104# _keep_updating_dataset_stats: bool = field(init=False) 

105# _stats_calculator: StatsCalculator = field(init=False) 

106 

107# @property 

108# def required_measures(self) -> Collection[Measure]: 

109# return set() 

110 

111# def __post_init__(self): 

112# self._stats_calculator = StatsCalculator(self.measures) 

113# if self.keep_updating_dataset_stats is None: 

114# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

115# else: 

116# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

117 

118# def __call__(self, sample_block: SampleBlockWithOrigin> None: 

119# if self._keep_updating_dataset_stats: 

120# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

121# else: 

122# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

123 

124 

125@dataclass 

126class UpdateStats(SamplewiseOperator): 

127 """Calculates sample and/or dataset measures""" 

128 

129 stats_calculator: StatsCalculator 

130 """`StatsCalculator` to be used by this operator.""" 

131 keep_updating_initial_dataset_stats: bool = False 

132 """indicates if operator calls should keep updating initial dataset statistics or not; 

133 if the `stats_calculator` was not provided with any initial dataset statistics, 

134 these are always updated with every new sample. 

135 """ 

136 _keep_updating_dataset_stats: bool = field(init=False) 

137 

138 @property 

139 def required_measures(self) -> Collection[Measure]: 

140 return set() 

141 

142 def __post_init__(self): 

143 self._keep_updating_dataset_stats = ( 

144 self.keep_updating_initial_dataset_stats 

145 or not self.stats_calculator.has_dataset_measures 

146 ) 

147 

148 def __call__(self, sample: Sample) -> None: 

149 if self._keep_updating_dataset_stats: 

150 sample.stat.update(self.stats_calculator.update_and_get_all(sample)) 

151 else: 

152 sample.stat.update(self.stats_calculator.skip_update_and_get_all(sample)) 

153 

154 

155@dataclass 

156class Binarize(SimpleOperator): 

157 """'output = tensor > threshold'.""" 

158 

159 threshold: Union[float, Sequence[float]] 

160 axis: Optional[AxisId] = None 

161 

162 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

163 return x > self.threshold 

164 

165 @property 

166 def required_measures(self) -> Collection[Measure]: 

167 return set() 

168 

169 def get_output_shape( 

170 self, input_shape: Mapping[AxisId, int] 

171 ) -> Mapping[AxisId, int]: 

172 return input_shape 

173 

174 @classmethod 

175 def from_proc_descr( 

176 cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId 

177 ) -> Self: 

178 if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)): 

179 return cls( 

180 input=member_id, output=member_id, threshold=descr.kwargs.threshold 

181 ) 

182 elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs): 

183 return cls( 

184 input=member_id, 

185 output=member_id, 

186 threshold=descr.kwargs.threshold, 

187 axis=descr.kwargs.axis, 

188 ) 

189 else: 

190 assert_never(descr.kwargs) 

191 

192 

193@dataclass 

194class Clip(SimpleOperator): 

195 min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None 

196 """minimum value for clipping""" 

197 max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None 

198 """maximum value for clipping""" 

199 

200 def __post_init__(self): 

201 if self.min is None and self.max is None: 

202 raise ValueError("missing min or max value") 

203 

204 if ( 

205 isinstance(self.min, float) 

206 and isinstance(self.max, float) 

207 and self.min >= self.max 

208 ): 

209 raise ValueError(f"expected min < max, but {self.min} >= {self.max}") 

210 

211 if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance( 

212 self.max, (SampleQuantile, DatasetQuantile) 

213 ): 

214 if self.min.axes != self.max.axes: 

215 raise NotImplementedError( 

216 f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}" 

217 ) 

218 if self.min.q >= self.max.q: 

219 raise ValueError( 

220 f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}" 

221 ) 

222 

223 @property 

224 def required_measures(self): 

225 return { 

226 arg 

227 for arg in (self.min, self.max) 

228 if isinstance(arg, (SampleQuantile, DatasetQuantile)) 

229 } 

230 

231 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

232 if isinstance(self.min, (SampleQuantile, DatasetQuantile)): 

233 min_value = stat[self.min] 

234 if isinstance(min_value, (int, float)): 

235 # use clip for scalar value 

236 min_clip_arg = min_value 

237 else: 

238 # clip does not support non-scalar values 

239 x = Tensor.from_xarray( 

240 x.data.where(x.data >= min_value.data, min_value.data) 

241 ) 

242 min_clip_arg = None 

243 else: 

244 min_clip_arg = self.min 

245 

246 if isinstance(self.max, (SampleQuantile, DatasetQuantile)): 

247 max_value = stat[self.max] 

248 if isinstance(max_value, (int, float)): 

249 # use clip for scalar value 

250 max_clip_arg = max_value 

251 else: 

252 # clip does not support non-scalar values 

253 x = Tensor.from_xarray( 

254 x.data.where(x.data <= max_value.data, max_value.data) 

255 ) 

256 max_clip_arg = None 

257 else: 

258 max_clip_arg = self.max 

259 

260 if min_clip_arg is not None or max_clip_arg is not None: 

261 x = x.clip(min_clip_arg, max_clip_arg) 

262 

263 return x 

264 

265 def get_output_shape( 

266 self, input_shape: Mapping[AxisId, int] 

267 ) -> Mapping[AxisId, int]: 

268 return input_shape 

269 

270 @classmethod 

271 def from_proc_descr( 

272 cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId 

273 ) -> Self: 

274 if isinstance(descr, v0_5.ClipDescr): 

275 dataset_mode, axes = _get_axes(descr.kwargs) 

276 if dataset_mode: 

277 Quantile = DatasetQuantile 

278 else: 

279 Quantile = partial(SampleQuantile, method="inverted_cdf") 

280 

281 if descr.kwargs.min is not None: 

282 min_arg = descr.kwargs.min 

283 elif descr.kwargs.min_percentile is not None: 

284 min_arg = Quantile( 

285 q=descr.kwargs.min_percentile / 100, 

286 axes=axes, 

287 member_id=member_id, 

288 ) 

289 else: 

290 min_arg = None 

291 

292 if descr.kwargs.max is not None: 

293 max_arg = descr.kwargs.max 

294 elif descr.kwargs.max_percentile is not None: 

295 max_arg = Quantile( 

296 q=descr.kwargs.max_percentile / 100, 

297 axes=axes, 

298 member_id=member_id, 

299 ) 

300 else: 

301 max_arg = None 

302 

303 elif isinstance(descr, v0_4.ClipDescr): 

304 min_arg = descr.kwargs.min 

305 max_arg = descr.kwargs.max 

306 else: 

307 assert_never(descr) 

308 

309 return cls( 

310 input=member_id, 

311 output=member_id, 

312 min=min_arg, 

313 max=max_arg, 

314 ) 

315 

316 

317@dataclass 

318class EnsureDtype(SimpleOperator): 

319 dtype: DTypeStr 

320 

321 @classmethod 

322 def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId): 

323 return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype) 

324 

325 def get_descr(self): 

326 return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype)) 

327 

328 def get_output_shape( 

329 self, input_shape: Mapping[AxisId, int] 

330 ) -> Mapping[AxisId, int]: 

331 return input_shape 

332 

333 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

334 return x.astype(self.dtype) 

335 

336 @property 

337 def required_measures(self) -> Collection[Measure]: 

338 return set() 

339 

340 

341@dataclass 

342class ScaleLinear(SimpleOperator): 

343 gain: Union[float, xr.DataArray] = 1.0 

344 """multiplicative factor""" 

345 

346 offset: Union[float, xr.DataArray] = 0.0 

347 """additive term""" 

348 

349 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

350 return x * self.gain + self.offset 

351 

352 @property 

353 def required_measures(self) -> Collection[Measure]: 

354 return set() 

355 

356 def get_output_shape( 

357 self, input_shape: Mapping[AxisId, int] 

358 ) -> Mapping[AxisId, int]: 

359 return input_shape 

360 

361 @classmethod 

362 def from_proc_descr( 

363 cls, 

364 descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr], 

365 member_id: MemberId, 

366 ) -> Self: 

367 kwargs = descr.kwargs 

368 if isinstance(kwargs, v0_5.ScaleLinearKwargs): 

369 axis = None 

370 elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs): 

371 axis = kwargs.axis 

372 elif isinstance(kwargs, v0_4.ScaleLinearKwargs): 

373 if kwargs.axes is not None: 

374 raise NotImplementedError( 

375 "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5." 

376 ) 

377 axis = None 

378 else: 

379 assert_never(kwargs) 

380 

381 if axis: 

382 gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis) 

383 offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis) 

384 else: 

385 assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, ( 

386 kwargs.gain 

387 ) 

388 gain = ( 

389 kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] 

390 ) 

391 assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1 

392 offset = ( 

393 kwargs.offset 

394 if isinstance(kwargs.offset, (float, int)) 

395 else kwargs.offset[0] 

396 ) 

397 

398 return cls(input=member_id, output=member_id, gain=gain, offset=offset) 

399 

400 

401@dataclass 

402class ScaleMeanVariance(SimpleOperator): 

403 axes: Optional[Sequence[AxisId]] = None 

404 reference_tensor: Optional[MemberId] = None 

405 eps: float = 1e-6 

406 mean: Union[SampleMean, DatasetMean] = field(init=False) 

407 std: Union[SampleStd, DatasetStd] = field(init=False) 

408 ref_mean: Union[SampleMean, DatasetMean] = field(init=False) 

409 ref_std: Union[SampleStd, DatasetStd] = field(init=False) 

410 

411 @property 

412 def required_measures(self): 

413 return {self.mean, self.std, self.ref_mean, self.ref_std} 

414 

415 def __post_init__(self): 

416 axes = None if self.axes is None else tuple(self.axes) 

417 ref_tensor = self.reference_tensor or self.input 

418 if axes is None or AxisId("batch") not in axes: 

419 Mean = SampleMean 

420 Std = SampleStd 

421 else: 

422 Mean = DatasetMean 

423 Std = DatasetStd 

424 

425 self.mean = Mean(member_id=self.input, axes=axes) 

426 self.std = Std(member_id=self.input, axes=axes) 

427 self.ref_mean = Mean(member_id=ref_tensor, axes=axes) 

428 self.ref_std = Std(member_id=ref_tensor, axes=axes) 

429 

430 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

431 mean = stat[self.mean] 

432 std = stat[self.std] + self.eps 

433 ref_mean = stat[self.ref_mean] 

434 ref_std = stat[self.ref_std] + self.eps 

435 return (x - mean) / std * ref_std + ref_mean 

436 

437 def get_output_shape( 

438 self, input_shape: Mapping[AxisId, int] 

439 ) -> Mapping[AxisId, int]: 

440 return input_shape 

441 

442 @classmethod 

443 def from_proc_descr( 

444 cls, 

445 descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr], 

446 member_id: MemberId, 

447 ) -> Self: 

448 kwargs = descr.kwargs 

449 _, axes = _get_axes(descr.kwargs) 

450 

451 return cls( 

452 input=member_id, 

453 output=member_id, 

454 reference_tensor=MemberId(str(kwargs.reference_tensor)), 

455 axes=axes, 

456 eps=kwargs.eps, 

457 ) 

458 

459 

460def _get_axes( 

461 kwargs: Union[ 

462 v0_4.ZeroMeanUnitVarianceKwargs, 

463 v0_5.ZeroMeanUnitVarianceKwargs, 

464 v0_4.ScaleRangeKwargs, 

465 v0_5.ScaleRangeKwargs, 

466 v0_4.ScaleMeanVarianceKwargs, 

467 v0_5.ScaleMeanVarianceKwargs, 

468 v0_5.ClipKwargs, 

469 ], 

470) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]: 

471 if kwargs.axes is None: 

472 return True, None 

473 elif isinstance(kwargs.axes, str): 

474 axes = _convert_axis_ids(kwargs.axes, kwargs["mode"]) 

475 return AxisId("b") in axes, axes 

476 elif isinstance(kwargs.axes, collections.abc.Sequence): 

477 axes = tuple(kwargs.axes) 

478 return AxisId("batch") in axes, axes 

479 else: 

480 assert_never(kwargs.axes) 

481 

482 

483@dataclass 

484class ScaleRange(SimpleOperator): 

485 lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None 

486 upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None 

487 lower: Union[SampleQuantile, DatasetQuantile] = field(init=False) 

488 upper: Union[SampleQuantile, DatasetQuantile] = field(init=False) 

489 

490 eps: float = 1e-6 

491 

492 def __post_init__( 

493 self, 

494 lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]], 

495 upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]], 

496 ): 

497 if lower_quantile is None: 

498 tid = self.input if upper_quantile is None else upper_quantile.member_id 

499 self.lower = DatasetQuantile(q=0.0, member_id=tid) 

500 else: 

501 self.lower = lower_quantile 

502 

503 if upper_quantile is None: 

504 self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id) 

505 else: 

506 self.upper = upper_quantile 

507 

508 assert self.lower.member_id == self.upper.member_id 

509 assert self.lower.q < self.upper.q 

510 assert self.lower.axes == self.upper.axes 

511 

512 @property 

513 def required_measures(self): 

514 return {self.lower, self.upper} 

515 

516 def get_output_shape( 

517 self, input_shape: Mapping[AxisId, int] 

518 ) -> Mapping[AxisId, int]: 

519 return input_shape 

520 

521 @classmethod 

522 def from_proc_descr( 

523 cls, 

524 descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr], 

525 member_id: MemberId, 

526 ): 

527 kwargs = descr.kwargs 

528 ref_tensor = ( 

529 member_id 

530 if kwargs.reference_tensor is None 

531 else MemberId(str(kwargs.reference_tensor)) 

532 ) 

533 dataset_mode, axes = _get_axes(descr.kwargs) 

534 if dataset_mode: 

535 Quantile = DatasetQuantile 

536 else: 

537 Quantile = partial(SampleQuantile, method="linear") 

538 

539 return cls( 

540 input=member_id, 

541 output=member_id, 

542 lower_quantile=Quantile( 

543 q=kwargs.min_percentile / 100, 

544 axes=axes, 

545 member_id=ref_tensor, 

546 ), 

547 upper_quantile=Quantile( 

548 q=kwargs.max_percentile / 100, 

549 axes=axes, 

550 member_id=ref_tensor, 

551 ), 

552 ) 

553 

554 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

555 lower = stat[self.lower] 

556 upper = stat[self.upper] 

557 return (x - lower) / (upper - lower + self.eps) 

558 

559 def get_descr(self): 

560 assert self.lower.axes == self.upper.axes 

561 assert self.lower.member_id == self.upper.member_id 

562 

563 return v0_5.ScaleRangeDescr( 

564 kwargs=v0_5.ScaleRangeKwargs( 

565 axes=None 

566 if self.lower.axes is None 

567 else [v0_5.AxisId(a) for a in self.lower.axes], 

568 min_percentile=self.lower.q * 100, 

569 max_percentile=self.upper.q * 100, 

570 eps=self.eps, 

571 reference_tensor=self.lower.member_id, 

572 ) 

573 ) 

574 

575 

576@dataclass 

577class Sigmoid(SimpleOperator): 

578 """1 / (1 + e^(-input)).""" 

579 

580 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

581 return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims) 

582 

583 @property 

584 def required_measures(self) -> Collection[Measure]: 

585 return {} 

586 

587 def get_output_shape( 

588 self, input_shape: Mapping[AxisId, int] 

589 ) -> Mapping[AxisId, int]: 

590 return input_shape 

591 

592 @classmethod 

593 def from_proc_descr( 

594 cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId 

595 ) -> Self: 

596 assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)) 

597 return cls(input=member_id, output=member_id) 

598 

599 def get_descr(self): 

600 return v0_5.SigmoidDescr() 

601 

602 

603@dataclass 

604class Softmax(SimpleOperator): 

605 """Softmax activation function.""" 

606 

607 axis: AxisId = AxisId("channel") 

608 

609 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

610 axis_idx = x.dims.index(self.axis) 

611 result = scipy.special.softmax(x.data, axis=axis_idx) 

612 result_xr = xr.DataArray(result, dims=x.dims) 

613 return Tensor.from_xarray(result_xr) 

614 

615 @property 

616 def required_measures(self) -> Collection[Measure]: 

617 return set() 

618 

619 def get_output_shape( 

620 self, input_shape: Mapping[AxisId, int] 

621 ) -> Mapping[AxisId, int]: 

622 return input_shape 

623 

624 @classmethod 

625 def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self: 

626 assert isinstance(descr, v0_5.SoftmaxDescr) 

627 return cls(input=member_id, output=member_id, axis=descr.kwargs.axis) 

628 

629 def get_descr(self): 

630 return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=v0_5.AxisId(self.axis))) 

631 

632 

633@dataclass 

634class ZeroMeanUnitVariance(SimpleOperator): 

635 """normalize to zero mean, unit variance.""" 

636 

637 mean: MeanMeasure 

638 std: StdMeasure 

639 

640 eps: float = 1e-6 

641 

642 def __post_init__(self): 

643 assert self.mean.axes == self.std.axes 

644 

645 @property 

646 def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: 

647 return {self.mean, self.std} 

648 

649 def get_output_shape( 

650 self, input_shape: Mapping[AxisId, int] 

651 ) -> Mapping[AxisId, int]: 

652 return input_shape 

653 

654 @classmethod 

655 def from_proc_descr( 

656 cls, 

657 descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr], 

658 member_id: MemberId, 

659 ): 

660 dataset_mode, axes = _get_axes(descr.kwargs) 

661 

662 if dataset_mode: 

663 Mean = DatasetMean 

664 Std = DatasetStd 

665 else: 

666 Mean = SampleMean 

667 Std = SampleStd 

668 

669 return cls( 

670 input=member_id, 

671 output=member_id, 

672 mean=Mean(axes=axes, member_id=member_id), 

673 std=Std(axes=axes, member_id=member_id), 

674 ) 

675 

676 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

677 mean = stat[self.mean] 

678 std = stat[self.std] 

679 return (x - mean) / (std + self.eps) 

680 

681 def get_descr(self): 

682 return v0_5.ZeroMeanUnitVarianceDescr( 

683 kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps) 

684 ) 

685 

686 

687@dataclass 

688class FixedZeroMeanUnitVariance(SimpleOperator): 

689 """normalize to zero mean, unit variance with precomputed values.""" 

690 

691 mean: Union[float, xr.DataArray] 

692 std: Union[float, xr.DataArray] 

693 

694 eps: float = 1e-6 

695 

696 def __post_init__(self): 

697 assert ( 

698 isinstance(self.mean, (int, float)) 

699 or isinstance(self.std, (int, float)) 

700 or self.mean.dims == self.std.dims 

701 ) 

702 

703 @property 

704 def required_measures(self) -> Collection[Measure]: 

705 return set() 

706 

707 def get_output_shape( 

708 self, input_shape: Mapping[AxisId, int] 

709 ) -> Mapping[AxisId, int]: 

710 return input_shape 

711 

712 @classmethod 

713 def from_proc_descr( 

714 cls, 

715 descr: v0_5.FixedZeroMeanUnitVarianceDescr, 

716 member_id: MemberId, 

717 ) -> Self: 

718 if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs): 

719 dims = None 

720 elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs): 

721 dims = (AxisId(descr.kwargs.axis),) 

722 else: 

723 assert_never(descr.kwargs) 

724 

725 return cls( 

726 input=member_id, 

727 output=member_id, 

728 mean=xr.DataArray(descr.kwargs.mean, dims=dims), 

729 std=xr.DataArray(descr.kwargs.std, dims=dims), 

730 ) 

731 

732 def get_descr(self): 

733 if isinstance(self.mean, (int, float)): 

734 assert isinstance(self.std, (int, float)) 

735 kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std) 

736 else: 

737 assert isinstance(self.std, xr.DataArray) 

738 assert len(self.mean.dims) == 1 

739 kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs( 

740 axis=AxisId(str(self.mean.dims[0])), 

741 mean=list(self.mean), 

742 std=list(self.std), 

743 ) 

744 

745 return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs) 

746 

747 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

748 return (x - self.mean) / (self.std + self.eps) 

749 

750 

751@dataclass 

752class CustomProcessing(SimpleOperator): 

753 """Execute a user-supplied custom processing callable. 

754 

755 Two styles are supported — callable class and factory function:: 

756 

757 # Callable class style 

758 class my_factory: 

759 def __init__(self, threshold=0.5): 

760 self.threshold = threshold 

761 def __call__(self, *arrays): 

762 return (arrays[0] > self.threshold).astype(np.uint8) 

763 

764 # Factory function style 

765 def my_factory(threshold=0.5): 

766 def run(*arrays): 

767 return (arrays[0] > threshold).astype(np.uint8) 

768 return run 

769 

770 Runtime protocol: ``custom_callable = my_factory(**kwargs)`` once at construction; 

771 ``result = custom_callable(tensor)`` once per sample. 

772 

773 Note: The custom callable may not change the shape of the input tensor. 

774 """ 

775 

776 custom_factory: Callable[..., Callable[[NDArray[Any]], NDArray[Any]]] 

777 

778 kwargs: Mapping[str, Any] 

779 """Keyword arguments forwarded to the custom factory.""" 

780 

781 # Initialised in __post_init__ 

782 _custom_callable: Any = field(init=False, repr=False) 

783 

784 def __post_init__(self) -> None: 

785 self._custom_callable = self.custom_factory(**self.kwargs) 

786 

787 def _apply(self, x: Tensor, stat: Stat) -> Tensor: 

788 return Tensor.from_numpy(self._custom_callable(x.to_numpy()), dims=x.dims) 

789 

790 def get_output_shape( 

791 self, input_shape: Mapping[AxisId, int] 

792 ) -> Mapping[AxisId, int]: 

793 return input_shape 

794 

795 @property 

796 def required_measures(self) -> Collection[Measure]: 

797 return set() 

798 

799 @classmethod 

800 def from_proc_descr( 

801 cls, 

802 descr: v0_5.CustomProcessingDescr, 

803 member_id: MemberId, 

804 ) -> Self: 

805 factory = import_callable(descr) 

806 

807 return cls( 

808 input=member_id, 

809 output=member_id, 

810 custom_factory=factory, 

811 kwargs=dict(descr.kwargs), 

812 ) 

813 

814 

815ProcDescr = Union[ 

816 v0_4.PreprocessingDescr, 

817 v0_4.PostprocessingDescr, 

818 v0_5.PreprocessingDescr, 

819 v0_5.PostprocessingDescr, 

820] 

821 

822 

823Processing = Union[ 

824 AddKnownDatasetStats, 

825 Binarize, 

826 Clip, 

827 CellposeFlowDynamics, 

828 CustomProcessing, 

829 EnsureDtype, 

830 FixedZeroMeanUnitVariance, 

831 ScaleLinear, 

832 ScaleMeanVariance, 

833 ScaleRange, 

834 Sigmoid, 

835 StardistPostprocessing2D, 

836 StardistPostprocessing3D, 

837 Softmax, 

838 UpdateStats, 

839 ZeroMeanUnitVariance, 

840] 

841 

842 

843def get_proc( 

844 proc_descr: ProcDescr, 

845 tensor_descr: Union[ 

846 v0_4.InputTensorDescr, 

847 v0_4.OutputTensorDescr, 

848 v0_5.InputTensorDescr, 

849 v0_5.OutputTensorDescr, 

850 ], 

851) -> Processing: 

852 member_id = get_member_id(tensor_descr) 

853 

854 if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): 

855 return Binarize.from_proc_descr(proc_descr, member_id) 

856 elif isinstance(proc_descr, v0_5.CellposeFlowDynamicsDescr): 

857 return CellposeFlowDynamics.from_proc_descr(proc_descr, member_id) 

858 elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)): 

859 return Clip.from_proc_descr(proc_descr, member_id) 

860 elif isinstance(proc_descr, v0_5.CustomProcessingDescr): 

861 return CustomProcessing.from_proc_descr(proc_descr, member_id) 

862 elif isinstance(proc_descr, v0_5.EnsureDtypeDescr): 

863 return EnsureDtype.from_proc_descr(proc_descr, member_id) 

864 elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr): 

865 return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id) 

866 elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): 

867 return ScaleLinear.from_proc_descr(proc_descr, member_id) 

868 elif isinstance( 

869 proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) 

870 ): 

871 return ScaleMeanVariance.from_proc_descr(proc_descr, member_id) 

872 elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): 

873 return ScaleRange.from_proc_descr(proc_descr, member_id) 

874 elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): 

875 return Sigmoid.from_proc_descr(proc_descr, member_id) 

876 elif ( 

877 isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr) 

878 and proc_descr.kwargs.mode == "fixed" 

879 ): 

880 if not isinstance( 

881 tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr) 

882 ): 

883 raise TypeError( 

884 "Expected v0_4 tensor description for v0_4 processing description" 

885 ) 

886 

887 v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes) 

888 assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr) 

889 return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id) 

890 elif isinstance(proc_descr, v0_5.SoftmaxDescr): 

891 return Softmax.from_proc_descr(proc_descr, member_id) 

892 elif isinstance(proc_descr, v0_5.StardistPostprocessingDescr): 

893 if isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs2D): 

894 return StardistPostprocessing2D.from_proc_descr(proc_descr, member_id) 

895 elif isinstance(proc_descr.kwargs, v0_5.StardistPostprocessingKwargs3D): 

896 return StardistPostprocessing3D.from_proc_descr(proc_descr, member_id) 

897 else: 

898 raise ValueError( 

899 f"expected ndim 2 or 3 for stardist postprocessing, but got {proc_descr.kwargs.ndim}" 

900 ) 

901 elif isinstance( 

902 proc_descr, 

903 (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), 

904 ): 

905 return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id) 

906 else: 

907 assert_never(proc_descr)