Coverage for bioimageio/core/proc_ops.py: 77%

321 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import collections.abc 

2from abc import ABC, abstractmethod 

3from dataclasses import InitVar, dataclass, field 

4from typing import ( 

5 Collection, 

6 Literal, 

7 Mapping, 

8 Optional, 

9 Sequence, 

10 Set, 

11 Tuple, 

12 Union, 

13) 

14 

15import numpy as np 

16import xarray as xr 

17from typing_extensions import Self, assert_never 

18 

19from bioimageio.spec.model import v0_4, v0_5 

20 

21from ._op_base import BlockedOperator, Operator 

22from .axis import AxisId, PerAxis 

23from .block import Block 

24from .common import DTypeStr, MemberId 

25from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

26from .stat_calculators import StatsCalculator 

27from .stat_measures import ( 

28 DatasetMean, 

29 DatasetMeasure, 

30 DatasetPercentile, 

31 DatasetStd, 

32 MeanMeasure, 

33 Measure, 

34 MeasureValue, 

35 SampleMean, 

36 SampleQuantile, 

37 SampleStd, 

38 Stat, 

39 StdMeasure, 

40) 

41from .tensor import Tensor 

42 

43 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert an axes specification into a tuple of `AxisId`s.

    A non-string `axes` is returned as a tuple without further processing.
    A string is split into one `AxisId` per character; for
    `mode="per_dataset"` the axis id "b" is put in front.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        leading = []
    elif mode == "per_dataset":
        leading = [AxisId("b")]
    else:
        assert_never(mode)

    return tuple(leading + [AxisId(char) for char in axes])

60 

61 

@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators that map one `input` member to one `output` member."""

    # id of the sample member that is read
    input: MemberId
    # id of the sample member that is written
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        """Measures this operator reads from the sample statistics (none by default)."""
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply the operator to `sample` in place.

        Does nothing if `sample` has no member named `self.input`.
        """
        if self.input not in sample.members:
            return

        result = self._apply(sample.members[self.input], sample.stat)

        # an already existing output member must not change its shape
        if self.output in sample.members:
            existing_shape = sample.members[self.output].tagged_shape
            assert existing_shape == result.tagged_shape

        if isinstance(sample, Sample):
            sample.members[self.output] = result
            return

        if isinstance(sample, SampleBlock):
            # the output block reuses the block layout of the input block
            source_block = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=result,
                inner_slice=source_block.inner_slice,
                halo=source_block.halo,
                block_index=source_block.block_index,
                blocks_in_sample=source_block.blocks_in_sample,
            )
            return

        assert_never(sample)

    @abstractmethod
    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...

103 

104 

@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Operator that copies already known dataset statistics into a sample's `stat`."""

    # measure -> value pairs to add to every processed sample
    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        """This operator does not read any measures."""
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Update `sample.stat` with all entries of `dataset_stats`."""
        known_stats = self.dataset_stats.items()
        sample.stat.update(known_stats)

115 

116 

117# @dataclass 

118# class UpdateStats(Operator): 

119# """Calculates sample and/or dataset measures""" 

120 

121# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

122# """sample and dataset `measures` to be calculated by this operator. Initial/fixed 

123# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

124# """ 

125# keep_updating_dataset_stats: Optional[bool] = None 

126# """indicates if operator calls should keep updating dataset statistics or not 

127 

128# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

129# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

130# initial measure values) dataset statistics are updated by each processed sample. 

131# """ 

132# _keep_updating_dataset_stats: bool = field(init=False) 

133# _stats_calculator: StatsCalculator = field(init=False) 

134 

135# @property 

136# def required_measures(self) -> Set[Measure]: 

137# return set() 

138 

139# def __post_init__(self): 

140# self._stats_calculator = StatsCalculator(self.measures) 

141# if self.keep_updating_dataset_stats is None: 

142# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

143# else: 

144# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

145 

146# def __call__(self, sample_block: SampleBlockWithOrigin) -> None: 

147# if self._keep_updating_dataset_stats: 

148# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

149# else: 

150# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

151 

152 

@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        """This operator does not read any measures."""
        return set()

    def __post_init__(self):
        """Decide once whether calls update the dataset statistics."""
        keep_updating = self.keep_updating_initial_dataset_stats
        if not keep_updating:
            # without initial dataset measures there is nothing to preserve
            keep_updating = not self.stats_calculator.has_dataset_measures

        self._keep_updating_dataset_stats = keep_updating

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """Write the calculated measures into `sample.stat`.

        For a `SampleBlockWithOrigin` only the block with `block_index == 0`
        triggers a calculation, which is based on the block's `origin`.
        """
        origin = sample
        if isinstance(sample, SampleBlockWithOrigin):
            # update stats with whole sample on first block
            if sample.block_index != 0:
                return

            origin = sample.origin

        if self._keep_updating_dataset_stats:
            calculate = self.stats_calculator.update_and_get_all
        else:
            calculate = self.stats_calculator.skip_update_and_get_all

        sample.stat.update(calculate(origin))

190 

191 

@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    # NOTE(review): `axis` is stored but not used by `_apply`; confirm that a
    # sequence `threshold` is compared along the intended axis.
    axis: Optional[AxisId] = None

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Compare `input` against `threshold`; `stat` is not used."""
        return input > self.threshold

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

        if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=kwargs.threshold,
                axis=kwargs.axis,
            )

        assert_never(kwargs)

224 

225 

@dataclass
class Clip(_SimpleOperator):
    """Clip tensor values to the range given by `min` and/or `max`."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        """Check that at least one bound is given and that `min < max` if both are."""
        lower, upper = self.min, self.max
        assert not (lower is None and upper is None), "missing min or max value"
        assert (
            lower is None or upper is None or lower < upper
        ), f"expected min < max, but {self.min} !< {self.max}"

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Clip `input`; `stat` is not used."""
        return input.clip(self.min, self.max)

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        clip_kwargs = descr.kwargs
        return cls(
            input=member_id,
            output=member_id,
            min=clip_kwargs.min,
            max=clip_kwargs.max,
        )

257 

258 

@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast a tensor to `dtype`."""

    dtype: DTypeStr

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Cast `input` to `dtype`; `stat` is not used."""
        target_dtype = self.dtype
        return input.astype(target_dtype)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        """Return the `v0_5.EnsureDtypeDescr` describing this operator."""
        descr_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=descr_kwargs)

277 

278 

@dataclass
class ScaleLinear(_SimpleOperator):
    """'output = input * gain + offset'."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Scale and shift `input`; `stat` is not used."""
        return input * self.gain + self.offset

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`.

        With an axis given, gain and offset become 1-d `xr.DataArray`s along
        that axis; otherwise both must be a scalar or a single-element sequence.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            return cls(
                input=member_id,
                output=member_id,
                gain=xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis),
                offset=xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis),
            )

        scalar_types = (float, int)

        raw_gain = kwargs.gain
        assert isinstance(raw_gain, scalar_types) or len(raw_gain) == 1, raw_gain
        gain = raw_gain if isinstance(raw_gain, scalar_types) else raw_gain[0]

        raw_offset = kwargs.offset
        assert isinstance(raw_offset, scalar_types) or len(raw_offset) == 1
        offset = raw_offset if isinstance(raw_offset, scalar_types) else raw_offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

327 

328 

@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Scale a tensor to the mean and std of a reference tensor.

    'output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean'
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        """Mean and std measures of the input and of the reference tensor."""
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        """Create the four measures; the reference defaults to the input itself."""
        axes = None if self.axes is None else tuple(self.axes)
        reference = self.reference_tensor or self.input

        # NOTE(review): `axes=None` selects per-sample measures here, whereas
        # `_get_axes` reports dataset mode for `axes=None`, and axis ids
        # converted from v0_4 strings use "b" rather than "batch" -- confirm
        # that this is intended.
        per_dataset = axes is not None and AxisId("batch") in axes
        if per_dataset:
            mean_cls, std_cls = DatasetMean, DatasetStd
        else:
            mean_cls, std_cls = SampleMean, SampleStd

        self.mean = mean_cls(member_id=self.input, axes=axes)
        self.std = std_cls(member_id=self.input, axes=axes)
        self.ref_mean = mean_cls(member_id=reference, axes=axes)
        self.ref_std = std_cls(member_id=reference, axes=axes)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Normalize `input` with its own statistics, then apply the reference statistics."""
        own_mean = stat[self.mean]
        own_std = stat[self.std] + self.eps
        target_mean = stat[self.ref_mean]
        target_std = stat[self.ref_std] + self.eps
        return (input - own_mean) / own_std * target_std + target_mean

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        kwargs = descr.kwargs
        # the dataset-mode flag is not needed; `__post_init__` derives it from `axes`
        _, axes = _get_axes(descr.kwargs)
        reference = MemberId(str(kwargs.reference_tensor))

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=reference,
            axes=axes,
            eps=kwargs.eps,
        )

386 

387 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the axes of `kwargs` and whether they imply dataset mode.

    Returns:
        A `(dataset_mode, axes)` pair. `axes=None` yields `(True, None)`.
        String axes are converted with `_convert_axis_ids` and checked for the
        axis id "b"; sequence axes are checked for the axis id "batch".
    """
    axes = kwargs.axes
    if axes is None:
        return True, None

    if isinstance(axes, str):
        converted = _convert_axis_ids(axes, kwargs["mode"])
        return AxisId("b") in converted, converted

    if isinstance(axes, collections.abc.Sequence):
        as_tuple = tuple(axes)
        return AxisId("batch") in as_tuple, as_tuple

    assert_never(axes)

408 

409 

@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale a tensor with a lower and an upper percentile measure.

    'output = (input - lower) / (upper - lower + eps)'
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        """Set `lower`/`upper`, falling back to dataset percentiles with q=0.0/q=1.0.

        Both measures must refer to the same member and axes, and
        `lower.q` must be smaller than `upper.q`.
        """
        if lower_percentile is None:
            # follow the member of the given upper measure, else use the input
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        """The lower and upper percentile measures."""
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create an in-place operator (input == output == `member_id`) from `descr`.

        The percentiles are computed on `kwargs.reference_tensor` if given,
        otherwise on `member_id`. Percentiles given in [0, 100] are converted
        to quantiles in [0, 1].
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # forward the described epsilon instead of silently using the default
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Rescale `input` using the measure values found in `stat`."""
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Return the `v0_5.ScaleRangeDescr` describing this operator."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

495 

496 

@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Apply the sigmoid element-wise, keeping the dims of `input`; `stat` is not used."""
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        """This operator does not read any measures."""
        # an empty set (not `{}`, which is a dict), consistent with the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Return the `v0_5.SigmoidDescr` describing this operator."""
        return v0_5.SigmoidDescr()

522 

523 

@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        """Check that mean and std are measured over the same axes."""
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        """The mean and std measures."""
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create an in-place operator (input == output == `member_id`) from `descr`.

        Dataset or sample measures are chosen according to `_get_axes`.
        """
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # forward the described epsilon instead of silently using the default
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Normalize `input` using the measure values found in `stat`."""
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        """Return the `v0_5.ZeroMeanUnitVarianceDescr` describing this operator."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

576 

577 

@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        """Check that array-valued mean and std share the same dims."""
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The output shape equals the input shape."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`.

        Scalar kwargs become 0-d `xr.DataArray`s, along-axis kwargs become
        1-d `xr.DataArray`s along the described axis.
        """
        # NOTE(review): `get_proc_class` also returns this class for
        # `v0_4.ZeroMeanUnitVarianceDescr` with mode "fixed", whose kwargs are
        # not handled here -- confirm how such descriptions reach this method.
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (descr.kwargs.axis,)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Return the `v0_5.FixedZeroMeanUnitVarianceDescr` describing this operator.

        Scalar values (plain numbers or 0-d arrays) yield
        `FixedZeroMeanUnitVarianceKwargs`; 1-d arrays yield
        `FixedZeroMeanUnitVarianceAlongAxisKwargs`.
        """
        mean, std = self.mean, self.std
        # `from_proc_descr` stores scalar kwargs as 0-d arrays; unwrap them so
        # that a scalar description survives the round trip
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats rather than 0-d array elements
                mean=[float(value) for value in mean],
                std=[float(value) for value in std],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Normalize `input` with the fixed mean and std; `stat` is not used."""
        return (input - self.mean) / (self.std + self.eps)

636 

637 

# Union of the pre- and postprocessing description types of both spec
# versions; this is the parameter type of `get_proc_class`.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

644 

# Union of the operator classes defined in this module.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]

658 

659 

def get_proc_class(proc_spec: ProcDescr):
    """Return the operator class that implements the processing description `proc_spec`.

    A `v0_4.ZeroMeanUnitVarianceDescr` with mode "fixed" maps to
    `FixedZeroMeanUnitVariance`; every other zero-mean-unit-variance
    description maps to `ZeroMeanUnitVariance`.
    """
    if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize

    if isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip

    if isinstance(proc_spec, v0_5.EnsureDtypeDescr):
        return EnsureDtype

    if isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance

    if isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear

    if isinstance(
        proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance

    if isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange

    if isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid

    if isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr):
        # v0_4 expresses precomputed statistics through mode "fixed"
        if proc_spec.kwargs.mode == "fixed":
            return FixedZeroMeanUnitVariance

        return ZeroMeanUnitVariance

    if isinstance(proc_spec, v0_5.ZeroMeanUnitVarianceDescr):
        return ZeroMeanUnitVariance

    assert_never(proc_spec)