Coverage for src/bioimageio/core/proc_ops.py: 77%

350 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-22 09:21 +0000

1import collections.abc 

2from abc import ABC, abstractmethod 

3from dataclasses import InitVar, dataclass, field 

4from typing import ( 

5 Collection, 

6 Literal, 

7 Mapping, 

8 Optional, 

9 Sequence, 

10 Set, 

11 Tuple, 

12 Union, 

13) 

14 

15import numpy as np 

16import scipy # pyright: ignore[reportMissingTypeStubs] 

17import xarray as xr 

18from typing_extensions import Self, assert_never 

19 

20from bioimageio.core.digest_spec import get_member_id 

21from bioimageio.spec.model import v0_4, v0_5 

22from bioimageio.spec.model.v0_5 import ( 

23 _convert_proc, # pyright: ignore [reportPrivateUsage] 

24) 

25 

26from ._op_base import BlockedOperator, Operator 

27from .axis import AxisId, PerAxis 

28from .block import Block 

29from .common import DTypeStr, MemberId 

30from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

31from .stat_calculators import StatsCalculator 

32from .stat_measures import ( 

33 DatasetMean, 

34 DatasetMeasure, 

35 DatasetPercentile, 

36 DatasetStd, 

37 MeanMeasure, 

38 Measure, 

39 MeasureValue, 

40 SampleMean, 

41 SampleQuantile, 

42 SampleStd, 

43 Stat, 

44 StdMeasure, 

45) 

46from .tensor import Tensor 

47 

48 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Normalize v0_4 axes to a tuple of `AxisId`s.

    A non-string `axes` value is simply converted to a tuple. For a v0_4
    axes string (single-letter axis names), each letter becomes an `AxisId`;
    in "per_dataset" mode the batch axis is additionally prepended.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        prefix = []
    elif mode == "per_dataset":
        # dataset-wide statistics aggregate over the batch axis as well
        prefix = [v0_5.BATCH_AXIS_ID]
    else:
        assert_never(mode)

    return tuple(prefix) + tuple(AxisId(a) for a in axes)

65 

66 

@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators mapping one input member to one output member.

    Subclasses implement `_apply` (the tensor transformation) and
    `get_output_shape` (how the per-axis sample shape changes, if at all).
    """

    input: MemberId  # member to read from the sample
    output: MemberId  # member to write the result to (may equal `input`)

    @property
    def required_measures(self) -> Collection[Measure]:
        # no statistics needed by default; subclasses with stat-dependent
        # transformations override this
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply the operator to `sample` in place.

        Does nothing if the input member is absent from the sample.
        """
        if self.input not in sample.members:
            return

        input_tensor = sample.members[self.input]
        output_tensor = self._apply(input_tensor, sample.stat)

        if self.output in sample.members:
            # overwriting an existing member must not change its tagged shape
            assert (
                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
            )

        if isinstance(sample, Sample):
            sample.members[self.output] = output_tensor
        elif isinstance(sample, SampleBlock):
            # wrap the result in a new Block that mirrors the input block's
            # position (slice, halo, index) within the sample
            b = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=output_tensor,
                inner_slice=b.inner_slice,
                halo=b.halo,
                block_index=b.block_index,
                blocks_in_sample=b.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor: ...

108 

109 

@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Inject precomputed dataset statistics into every processed sample."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # nothing to compute; all statistics are supplied up front
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        # copy the known measure values into the sample's stat mapping
        for measure, value in self.dataset_stats.items():
            sample.stat[measure] = value

120 

121 

122# @dataclass 

123# class UpdateStats(Operator): 

124# """Calculates sample and/or dataset measures""" 

125 

126# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

127# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed 

128# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

129# """ 

130# keep_updating_dataset_stats: Optional[bool] = None 

131# """indicates if operator calls should keep updating dataset statistics or not 

132 

133# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

134# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

135# initial measure values) dataset statistics are updated by each processed sample. 

136# """ 

137# _keep_updating_dataset_stats: bool = field(init=False) 

138# _stats_calculator: StatsCalculator = field(init=False) 

139 

140# @property 

141# def required_measures(self) -> Set[Measure]: 

142# return set() 

143 

144# def __post_init__(self): 

145# self._stats_calculator = StatsCalculator(self.measures) 

146# if self.keep_updating_dataset_stats is None: 

147# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

148# else: 

149# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

150 

# def __call__(self, sample_block: SampleBlockWithOrigin) -> None:

152# if self._keep_updating_dataset_stats: 

153# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

154# else: 

155# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

156 

157 

@dataclass
class UpdateStats(Operator):
    """Compute sample and/or dataset measures and record them on the sample."""

    stats_calculator: StatsCalculator
    """calculator that produces the measures to record"""
    keep_updating_initial_dataset_stats: bool = False
    """whether dataset statistics that were provided initially should keep
    being refined; without initial dataset statistics they are always updated
    with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        # keep updating unless fixed dataset measures were provided and the
        # caller did not explicitly opt in to further updates
        self._keep_updating_dataset_stats = (
            not self.stats_calculator.has_dataset_measures
            or self.keep_updating_initial_dataset_stats
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if not isinstance(sample, SampleBlockWithOrigin):
            origin = sample
        elif sample.block_index == 0:
            # for blocked samples, compute stats once from the whole origin
            # sample when the first block comes through
            origin = sample.origin
        else:
            return

        calc = self.stats_calculator
        if self._keep_updating_dataset_stats:
            sample.stat.update(calc.update_and_get_all(origin))
        else:
            sample.stat.update(calc.skip_update_and_get_all(origin))

195 

196 

@dataclass
class Binarize(_SimpleOperator):
    """Threshold the tensor: element-wise `input > threshold`."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # thresholding is element-wise; the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )
        elif isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)
        else:
            assert_never(kwargs)

229 

230 

@dataclass
class Clip(_SimpleOperator):
    """Clamp tensor values to the interval [min, max]."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        lo, hi = self.min, self.max
        # at least one bound is required; if both are given they must be ordered
        assert lo is not None or hi is not None, "missing min or max value"
        if lo is not None and hi is not None:
            assert lo < hi, f"expected min < max, but {lo} !< {hi}"

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # a None bound leaves that side unclipped
        return x.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        kwargs = descr.kwargs
        return cls(input=member_id, output=member_id, min=kwargs.min, max=kwargs.max)

262 

263 

@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to a fixed dtype."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # dtype casts never change the shape
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)

282 

283 

@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine transformation: `output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # xr.DataArray gain/offset broadcast along their axis
        return x * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleLinear` from a v0_4/v0_5 processing description.

        Raises:
            NotImplementedError: for v0_4 kwargs with `axes` set.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            # per-axis gain/offset: wrap as 1d arrays labeled with `axis`
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            # without an axis, gain/offset must be scalar
            # (or a length-1 sequence, which is unwrapped)
            assert (
                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
            ), kwargs.gain
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

338 

339 

@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Scale the tensor to match the mean/variance of a reference tensor.

    output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean
    """

    axes: Optional[Sequence[AxisId]] = None  # axes to aggregate over (None: all)
    reference_tensor: Optional[MemberId] = None  # defaults to `input` when unset
    eps: float = 1e-6  # numerical stability term added to both std values
    # the following measures are derived in __post_init__
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        # sample statistics suffice unless aggregation spans the batch axis
        # NOTE(review): axes=None selects *sample* measures here, whereas
        # `_get_axes` treats axes=None as dataset mode -- confirm this
        # divergence is intended.
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        # the dataset-mode flag is discarded; __post_init__ re-derives the
        # sample/dataset choice from `axes`
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )

397 

398 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Normalize `kwargs.axes` to a tuple of `AxisId`s.

    Returns:
        (dataset_mode, axes) where `dataset_mode` indicates that statistics
        are to be computed over the whole dataset (batch axis included) and
        `axes` is None to aggregate over all axes.
    """
    if kwargs.axes is None:
        # no axes given: aggregate over everything, dataset-wide
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0_4-style axes string (e.g. "czyx"); "mode" decides batch inclusion
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        # NOTE(review): this looks for AxisId("b"), but `_convert_axis_ids`
        # prepends `v0_5.BATCH_AXIS_ID` in "per_dataset" mode -- confirm the
        # two identifiers match, otherwise dataset mode is never detected here.
        return AxisId("b") in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        # v0_5-style sequence of axis ids
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)

419 

420 

@dataclass
class ScaleRange(_SimpleOperator):
    """Normalize to a percentile range:
    output = (input - lower) / (upper - lower + eps)
    """

    # init-only arguments; resolved to `lower`/`upper` in __post_init__
    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6  # avoids division by zero for (near-)constant tensors

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            # default lower bound: 0th dataset percentile of whichever member
            # the upper bound refers to (falling back to `input`)
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            # default upper bound: 100th dataset percentile of the same member
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        # both bounds must refer to the same member and axes, and be ordered
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            # spec percentiles live in [0, 100]; measures use q in [0, 1]
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

506 

507 

@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # re-wrap as Tensor to preserve the input's axis labels
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: was `return {}` -- an empty *dict* literal; return an empty
        # set for consistency with the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise activation: output shape equals input shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()

533 

534 

@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")  # axis along which to normalize

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # raises ValueError via list.index if `self.axis` is not in x.dims
        axis_idx = x.dims.index(self.axis)
        # NOTE(review): relies on `scipy.special` being importable via the
        # top-level `import scipy` (lazy submodule loading) -- confirm the
        # supported scipy versions provide this.
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # fix: was `return {}` -- an empty *dict* literal; return an empty
        # set for consistency with the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # softmax normalizes along an axis without changing the shape
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))

563 

564 

@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """Normalize the tensor to zero mean and unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # mean and std must be aggregated over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise normalization keeps the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)

        # pick dataset-wide or per-sample statistics
        Mean = DatasetMean if dataset_mode else SampleMean
        Std = DatasetStd if dataset_mode else SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - stat[self.mean]) / (stat[self.std] + self.eps)

    def get_descr(self):
        kwargs = v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        return v0_5.ZeroMeanUnitVarianceDescr(kwargs=kwargs)

617 

618 

@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """Normalize with precomputed values: output = (input - mean) / (std + eps)."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # if both mean and std are arrays, their dims must agree
        mean_is_scalar = isinstance(self.mean, (int, float))
        std_is_scalar = isinstance(self.std, (int, float))
        assert mean_is_scalar or std_is_scalar or self.mean.dims == self.std.dims

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise normalization keeps the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            # per-axis values: label the arrays with the given axis
            dims = (AxisId(kwargs.axis),)
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(kwargs.mean, dims=dims),
            std=xr.DataArray(kwargs.std, dims=dims),
        )

    def get_descr(self):
        if isinstance(self.mean, (int, float)):
            # scalar mean implies scalar std
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)

677 

678 

# any supported pre-/postprocessing description (spec v0_4 or v0_5)
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# operator implementations corresponding to the descriptions above
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]

700 

701 

def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the `Processing` operator described by `proc_descr`.

    Args:
        proc_descr: pre- or postprocessing description (spec v0_4 or v0_5).
        tensor_descr: description of the tensor the processing belongs to;
            provides the member id and, for v0_4 zero-mean/unit-variance in
            "fixed" mode, the axes needed for conversion to v0_5.

    Raises:
        TypeError: if a v0_4 `ZeroMeanUnitVarianceDescr` with mode "fixed"
            is paired with a non-v0_4 tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        # convert the v0_4 "fixed" description to v0_5 (the conversion is
        # expected to yield a FixedZeroMeanUnitVarianceDescr) and reuse the
        # v0_5 operator
        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    else:
        assert_never(proc_descr)