Coverage for src / bioimageio / core / proc_ops.py: 76%

395 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 09:46 +0000

1import collections.abc 

2from abc import ABC, abstractmethod 

3from dataclasses import InitVar, dataclass, field 

4from functools import partial 

5from typing import ( 

6 Collection, 

7 Literal, 

8 Mapping, 

9 Optional, 

10 Sequence, 

11 Set, 

12 Tuple, 

13 Union, 

14) 

15 

16import numpy as np 

17import scipy # pyright: ignore[reportMissingTypeStubs] 

18import xarray as xr 

19from typing_extensions import Self, assert_never 

20 

21from bioimageio.core.digest_spec import get_member_id 

22from bioimageio.spec.model import v0_4, v0_5 

23from bioimageio.spec.model.v0_5 import ( 

24 _convert_proc, # pyright: ignore [reportPrivateUsage] 

25) 

26 

27from ._op_base import BlockedOperator, Operator 

28from .axis import AxisId, PerAxis 

29from .block import Block 

30from .common import DTypeStr, MemberId 

31from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

32from .stat_calculators import StatsCalculator 

33from .stat_measures import ( 

34 DatasetMean, 

35 DatasetMeasure, 

36 DatasetQuantile, 

37 DatasetStd, 

38 MeanMeasure, 

39 Measure, 

40 MeasureValue, 

41 SampleMean, 

42 SampleQuantile, 

43 SampleStd, 

44 Stat, 

45 StdMeasure, 

46) 

47from .tensor import Tensor 

48 

49 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Translate v0.4 style axes into a tuple of `AxisId`s.

    A string such as ``"xy"`` is split into one `AxisId` per character. In
    ``"per_dataset"`` mode `v0_5.BATCH_AXIS_ID` is put in front of them.
    Any non-string input is simply returned as a tuple.
    """
    if isinstance(axes, str):
        if mode == "per_sample":
            leading: Tuple[AxisId, ...] = ()
        elif mode == "per_dataset":
            leading = (v0_5.BATCH_AXIS_ID,)
        else:
            assert_never(mode)

        return leading + tuple(AxisId(char) for char in axes)

    return tuple(axes)

66 

67 

@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators mapping one input member to one output member.

    Subclasses implement the transformation in `_apply` and describe its effect
    on the shape in `get_output_shape`. `__call__` reads the input member,
    applies the transformation and stores the result as the output member; for
    a `SampleBlock` the result is wrapped in a `Block` that copies the input
    block's slice, halo and index information.
    """

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # by default an operator needs no statistics
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # samples without the input member are left untouched
        if self.input not in sample.members:
            return

        transformed = self._apply(sample.members[self.input], sample.stat)

        # overwriting an existing output member must not change its shape
        if self.output in sample.members:
            previous = sample.members[self.output]
            assert previous.tagged_shape == transformed.tagged_shape

        if isinstance(sample, Sample):
            sample.members[self.output] = transformed
        elif isinstance(sample, SampleBlock):
            source_block = sample.blocks[self.input]
            output_sample_shape = self.get_output_shape(sample.shape[self.input])
            sample.blocks[self.output] = Block(
                sample_shape=output_sample_shape,
                data=transformed,
                inner_slice=source_block.inner_slice,
                halo=source_block.halo,
                block_index=source_block.block_index,
                blocks_in_sample=source_block.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor: ...

109 

110 

@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Add precomputed dataset statistics to a sample's (or block's) `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # the values are given up front, nothing has to be computed first
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known_values = self.dataset_stats.items()
        sample.stat.update(known_values)

121 

122 

123# @dataclass 

124# class UpdateStats(Operator): 

125# """Calculates sample and/or dataset measures""" 

126 

127# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

128# """sample and dataset `measures` to be calculated by this operator. Initial/fixed

129# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

130# """ 

131# keep_updating_dataset_stats: Optional[bool] = None 

132# """indicates if operator calls should keep updating dataset statistics or not 

133 

134# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

135# given) no further updates to dataset statistics are conducted, otherwise (w.o.

136# initial measure values) dataset statistics are updated by each processed sample. 

137# """ 

138# _keep_updating_dataset_stats: bool = field(init=False) 

139# _stats_calculator: StatsCalculator = field(init=False) 

140 

141# @property 

142# def required_measures(self) -> Set[Measure]: 

143# return set() 

144 

145# def __post_init__(self): 

146# self._stats_calculator = StatsCalculator(self.measures) 

147# if self.keep_updating_dataset_stats is None: 

148# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

149# else: 

150# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

151 

152# def __call__(self, sample_block: SampleBlockWithOrigin> None: 

153# if self._keep_updating_dataset_stats: 

154# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

155# else: 

156# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

157 

158 

@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        keep_updating = self.keep_updating_initial_dataset_stats
        if not keep_updating:
            # without initial dataset statistics there is nothing to preserve,
            # so dataset statistics are updated anyway
            keep_updating = not self.stats_calculator.has_dataset_measures
        self._keep_updating_dataset_stats = keep_updating

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            # statistics are computed from the whole origin sample, and only
            # when its first block (block_index 0) comes along
            if sample.block_index != 0:
                return
            full_sample = sample.origin
        else:
            full_sample = sample

        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            computed = calculator.update_and_get_all(full_sample)
        else:
            computed = calculator.skip_update_and_get_all(full_sample)
        sample.stat.update(computed)

196 

197 

@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    `threshold` is either a single value for the whole tensor or, together with
    `axis`, a sequence holding one threshold per entry along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        threshold = self.threshold
        if self.axis is not None and isinstance(threshold, collections.abc.Sequence):
            # Label the per-entry thresholds with `axis` so they are matched to
            # the tensor by dimension name (as done for `ScaleLinear.gain`)
            # instead of being broadcast positionally against the last axis.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))
        return x > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)

230 

231 

@dataclass
class Clip(_SimpleOperator):
    """Clip tensor values to the interval [`min`, `max`].

    Each bound is optional (but at least one is required). It can be a fixed
    number or a quantile measure whose value is looked up in the sample
    statistics when the operator is applied.
    """

    min: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """minimum value for clipping"""
    max: Optional[Union[float, SampleQuantile, DatasetQuantile]] = None
    """maximum value for clipping"""

    def __post_init__(self):
        if self.min is None and self.max is None:
            raise ValueError("missing min or max value")

        # Fixed bounds may be given as int as well as float; validate both so
        # that e.g. `Clip(min=5, max=1)` is rejected too.
        if (
            isinstance(self.min, (int, float))
            and isinstance(self.max, (int, float))
            and self.min >= self.max
        ):
            raise ValueError(f"expected min < max, but {self.min} >= {self.max}")

        if isinstance(self.min, (SampleQuantile, DatasetQuantile)) and isinstance(
            self.max, (SampleQuantile, DatasetQuantile)
        ):
            if self.min.axes != self.max.axes:
                raise NotImplementedError(
                    f"expected min and max quantiles with same axes, but got {self.min.axes} and {self.max.axes}"
                )
            if self.min.q >= self.max.q:
                raise ValueError(
                    f"expected min quantile < max quantile, but {self.min.q} >= {self.max.q}"
                )

    @property
    def required_measures(self):
        # only quantile bounds need to be computed beforehand
        return {
            arg
            for arg in (self.min, self.max)
            if isinstance(arg, (SampleQuantile, DatasetQuantile))
        }

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        if isinstance(self.min, (SampleQuantile, DatasetQuantile)):
            min_value = stat[self.min]
            if isinstance(min_value, (int, float)):
                # use clip for scalar value
                min_clip_arg = min_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data >= min_value.data, min_value.data)
                )
                min_clip_arg = None
        else:
            min_clip_arg = self.min

        if isinstance(self.max, (SampleQuantile, DatasetQuantile)):
            max_value = stat[self.max]
            if isinstance(max_value, (int, float)):
                # use clip for scalar value
                max_clip_arg = max_value
            else:
                # clip does not support non-scalar values
                x = Tensor.from_xarray(
                    x.data.where(x.data <= max_value.data, max_value.data)
                )
                max_clip_arg = None
        else:
            max_clip_arg = self.max

        if min_clip_arg is not None or max_clip_arg is not None:
            x = x.clip(min_clip_arg, max_clip_arg)

        return x

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr, v0_5.ClipDescr):
            dataset_mode, axes = _get_axes(descr.kwargs)
            if dataset_mode:
                Quantile = DatasetQuantile
            else:
                Quantile = partial(SampleQuantile, method="inverted_cdf")

            # an explicit bound takes precedence over a percentile bound
            if descr.kwargs.min is not None:
                min_arg = descr.kwargs.min
            elif descr.kwargs.min_percentile is not None:
                min_arg = Quantile(
                    q=descr.kwargs.min_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                min_arg = None

            if descr.kwargs.max is not None:
                max_arg = descr.kwargs.max
            elif descr.kwargs.max_percentile is not None:
                max_arg = Quantile(
                    q=descr.kwargs.max_percentile / 100,
                    axes=axes,
                    member_id=member_id,
                )
            else:
                max_arg = None

        elif isinstance(descr, v0_4.ClipDescr):
            min_arg = descr.kwargs.min
            max_arg = descr.kwargs.max
        else:
            assert_never(descr)

        return cls(
            input=member_id,
            output=member_id,
            min=min_arg,
            max=max_arg,
        )

354 

355 

@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # a dtype cast leaves the shape unchanged
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        descr_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=descr_kwargs)

374 

375 

@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine transform ``output = input * gain + offset``.

    `gain` and `offset` are scalars or `xr.DataArray`s (one value per entry
    along a single axis when built from an along-axis description).
    """

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        axis: Optional[AxisId]
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # one gain/offset value per entry along `axis`
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            # without an axis only a scalar or a single-element list is valid
            gain = kwargs.gain
            if not isinstance(gain, (float, int)):
                assert len(gain) == 1, gain
                gain = gain[0]

            offset = kwargs.offset
            if not isinstance(offset, (float, int)):
                assert len(offset) == 1
                offset = offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

430 

431 

@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match mean and standard deviation of the input to those of a reference.

    ``output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean``;
    the reference member defaults to the input member itself.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input

        # Dataset statistics are used only if the batch axis is reduced over.
        # NOTE(review): `axes=None` selects *sample* statistics here, while
        # `_get_axes` reports dataset mode for `axes=None` — confirm intended.
        per_dataset = axes is not None and AxisId("batch") in axes
        if per_dataset:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        eps = self.eps
        mean = stat[self.mean]
        std = stat[self.std] + eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        # the dataset/sample decision is re-derived from `axes` in __post_init__
        _, axes = _get_axes(kwargs)
        reference = MemberId(str(kwargs.reference_tensor))
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=reference,
            axes=axes,
            eps=kwargs.eps,
        )

489 

490 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
        v0_5.ClipKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes from processing kwargs.

    Returns:
        ``(dataset_mode, axes)``: `dataset_mode` is True when `kwargs.axes` is
        None or the resulting axes contain the batch axis; `axes` is None or a
        tuple of `AxisId`s (v0.4 axes strings are converted with
        `_convert_axis_ids`, using the kwargs' `mode`).
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # `_convert_axis_ids` marks "per_dataset" mode by prepending
        # `v0_5.BATCH_AXIS_ID`, so test for exactly that id. (A literal "b" is
        # never produced for a v0.4 `AxesInCZYX` string, so testing for it
        # would silently turn per-dataset mode into per-sample mode.)
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return v0_5.BATCH_AXIS_ID in axes, axes
    else:
        assert_never(kwargs.axes)

512 

513 

@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale values so that `lower` maps to 0 and `upper` maps to 1.

    ``output = (input - lower) / (upper - lower + eps)`` where `lower` and
    `upper` are quantile measures looked up in the sample statistics. A missing
    quantile defaults to the dataset quantile q=0.0 (lower) or q=1.0 (upper).
    """

    lower_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    upper_quantile: InitVar[Optional[Union[SampleQuantile, DatasetQuantile]]] = None
    lower: Union[SampleQuantile, DatasetQuantile] = field(init=False)
    upper: Union[SampleQuantile, DatasetQuantile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
        upper_quantile: Optional[Union[SampleQuantile, DatasetQuantile]],
    ):
        if lower_quantile is None:
            # default lower bound refers to the same member as the upper bound
            tid = self.input if upper_quantile is None else upper_quantile.member_id
            self.lower = DatasetQuantile(q=0.0, member_id=tid)
        else:
            self.lower = lower_quantile

        if upper_quantile is None:
            self.upper = DatasetQuantile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_quantile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Quantile = DatasetQuantile
        else:
            Quantile = partial(SampleQuantile, method="linear")

        return cls(
            input=member_id,
            output=member_id,
            lower_quantile=Quantile(
                q=kwargs.min_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            upper_quantile=Quantile(
                q=kwargs.max_percentile / 100,
                axes=axes,
                member_id=ref_tensor,
            ),
            # honor the description's eps instead of silently using the default
            # (`get_descr` writes `self.eps` back, so this keeps a round-trip)
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

603 

604 

@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty *set* (not `{}`, which is an empty dict), consistent with
        # the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()

630 

631 

@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function.

    The softmax is computed along `axis` (default: the channel axis).
    """

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        axis_idx = x.dims.index(self.axis)
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty *set* (not `{}`, which is an empty dict), consistent with
        # the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))

660 

661 

@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    ``output = (input - mean) / (std + eps)`` with `mean` and `std` looked up
    in the sample statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # honor the description's eps instead of silently using the default
            # (`get_descr` writes `self.eps` back, so this keeps a round-trip)
            eps=descr.kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

714 

715 

@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    ``output = (input - mean) / (std + eps)``. `mean`/`std` are scalars or
    `xr.DataArray`s; `from_proc_descr` builds 0-d arrays for scalar values and
    1-d arrays along a single axis for along-axis values.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # array-valued mean and std must be labeled with the same dimensions
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        def unwrap_scalar(value: Union[float, xr.DataArray]) -> Union[float, xr.DataArray]:
            # `from_proc_descr` stores scalar values as 0-d DataArrays; turn
            # them back into plain floats so they round-trip as scalar kwargs
            if isinstance(value, xr.DataArray) and value.ndim == 0:
                return float(value)
            return value

        mean = unwrap_scalar(self.mean)
        std = unwrap_scalar(self.std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats: iterating a DataArray yields 0-d DataArrays
                mean=[float(v) for v in mean.values],
                std=[float(v) for v in std.values],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)

774 

775 

# Any v0.4 or v0.5 pre-/postprocessing description; the type `get_proc` accepts.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# Any operator class defined in this module. `get_proc` returns one of these
# (it never creates `AddKnownDatasetStats` or `UpdateStats` itself).
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]

797 

798 

def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the operator implementing `proc_descr` for `tensor_descr`.

    The operator reads from and writes to the member id derived from
    `tensor_descr`. A v0.4 `ZeroMeanUnitVarianceDescr` with ``mode == "fixed"``
    is first converted to a v0.5 `FixedZeroMeanUnitVarianceDescr`; this
    requires a v0.4 tensor description and raises `TypeError` otherwise.
    """
    mid = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, mid)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, mid)

    # must come before the generic zero-mean-unit-variance check below
    is_fixed_v04_zmuv = (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    )
    if is_fixed_v04_zmuv:
        is_v04_tensor = isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        )
        if not is_v04_tensor:
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, mid)

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, mid)

    assert_never(proc_descr)