Coverage for bioimageio/core/proc_ops.py: 81%

332 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import collections.abc 

2from abc import ABC, abstractmethod 

3from dataclasses import InitVar, dataclass, field 

4from typing import ( 

5 Collection, 

6 Literal, 

7 Mapping, 

8 Optional, 

9 Sequence, 

10 Set, 

11 Tuple, 

12 Union, 

13) 

14 

15import numpy as np 

16import xarray as xr 

17from typing_extensions import Self, assert_never 

18 

19from bioimageio.core.digest_spec import get_member_id 

20from bioimageio.spec.model import v0_4, v0_5 

21from bioimageio.spec.model.v0_5 import ( 

22 _convert_proc, # pyright: ignore [reportPrivateUsage] 

23) 

24 

25from ._op_base import BlockedOperator, Operator 

26from .axis import AxisId, PerAxis 

27from .block import Block 

28from .common import DTypeStr, MemberId 

29from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

30from .stat_calculators import StatsCalculator 

31from .stat_measures import ( 

32 DatasetMean, 

33 DatasetMeasure, 

34 DatasetPercentile, 

35 DatasetStd, 

36 MeanMeasure, 

37 Measure, 

38 MeasureValue, 

39 SampleMean, 

40 SampleQuantile, 

41 SampleStd, 

42 Stat, 

43 StdMeasure, 

44) 

45from .tensor import Tensor 

46 

47 

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Turn (v0.4 style) `axes` into a tuple of `AxisId`s.

    Non-string `axes` are simply converted to a tuple. An axes string is split
    into one `AxisId` per character; for `mode == "per_dataset"` the batch axis
    (`v0_5.BATCH_AXIS_ID`) is put in front.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        leading: Tuple[AxisId, ...] = ()
    elif mode == "per_dataset":
        leading = (v0_5.BATCH_AXIS_ID,)
    else:
        assert_never(mode)

    return leading + tuple(AxisId(letter) for letter in axes)

64 

65 

@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Common base for operators mapping one `input` member to one `output` member.

    Subclasses implement `_apply` (the tensor computation) and
    `get_output_shape` (the shape of the result given the input shape).
    """

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # no statistics are needed unless a subclass overrides this
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        if self.input not in sample.members:
            # samples without our input member are left untouched
            return

        result = self._apply(sample.members[self.input], sample.stat)

        if self.output in sample.members:
            # an existing output member may only be replaced by a tensor of
            # identical (tagged) shape
            existing = sample.members[self.output]
            assert existing.tagged_shape == result.tagged_shape

        if isinstance(sample, Sample):
            sample.members[self.output] = result
        elif isinstance(sample, SampleBlock):
            # the output block inherits the block layout of the input block
            source = sample.blocks[self.input]
            out_shape = self.get_output_shape(sample.shape[self.input])
            sample.blocks[self.output] = Block(
                sample_shape=out_shape,
                data=result,
                inner_slice=source.inner_slice,
                halo=source.halo,
                block_index=source.block_index,
                blocks_in_sample=source.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...

107 

108 

@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Add the given, precomputed dataset statistics to every sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # statistics are supplied, not computed, so nothing is required
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known = self.dataset_stats.items()
        sample.stat.update(known)

119 

120 

121# @dataclass 

122# class UpdateStats(Operator): 

123# """Calculates sample and/or dataset measures""" 

124 

125# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] 

126# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed 

127# dataset measure values may be given, see `keep_updating_dataset_stats` for details. 

128# """ 

129# keep_updating_dataset_stats: Optional[bool] = None 

130# """indicates if operator calls should keep updating dataset statistics or not 

131 

132# default (None): if `measures` is a `Mapping` (i.e. initial measure values are 

133# given) no further updates to dataset statistics is conducted, otherwise (w.o. 

134# initial measure values) dataset statistics are updated by each processed sample. 

135# """ 

136# _keep_updating_dataset_stats: bool = field(init=False) 

137# _stats_calculator: StatsCalculator = field(init=False) 

138 

139# @property 

140# def required_measures(self) -> Set[Measure]: 

141# return set() 

142 

143# def __post_init__(self): 

144# self._stats_calculator = StatsCalculator(self.measures) 

145# if self.keep_updating_dataset_stats is None: 

146# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) 

147# else: 

148# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats 

149 

150# def __call__(self, sample_block: SampleBlockWithOrigin> None: 

151# if self._keep_updating_dataset_stats: 

152# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) 

153# else: 

154# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) 

155 

156 

@dataclass
class UpdateStats(Operator):
    """Compute sample and/or dataset measures and add them to `sample.stat`."""

    stats_calculator: StatsCalculator
    """The `StatsCalculator` that computes the measures."""
    keep_updating_initial_dataset_stats: bool = False
    """Whether calls should keep updating initial dataset statistics.

    Dataset statistics are always updated with every new sample if the
    `stats_calculator` was created without initial dataset statistics.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if not isinstance(sample, SampleBlockWithOrigin):
            origin = sample
        elif sample.block_index == 0:
            # for blocked processing, statistics are computed from the whole
            # originating sample, once, when its first block comes through
            origin = sample.origin
        else:
            return

        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            new_stats = calculator.update_and_get_all(origin)
        else:
            new_stats = calculator.skip_update_and_get_all(origin)

        sample.stat.update(new_stats)

194 

195 

@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    `threshold` is either a single value applied to all elements or, together
    with `axis`, a sequence holding one threshold per entry along `axis`.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        threshold: Union[float, Sequence[float], xr.DataArray] = self.threshold
        if self.axis is not None and not isinstance(threshold, (int, float)):
            # Label per-axis thresholds with their axis so that xarray
            # broadcasts them against the matching dimension by name. A bare
            # sequence would be broadcast positionally (numpy-style) against
            # the trailing dimension of `input` instead.
            threshold = xr.DataArray(np.asarray(threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise comparison: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place `Binarize` operator for `member_id` from `descr`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)

228 

229 

@dataclass
class Clip(_SimpleOperator):
    """Clip tensor values to the range given by `min` and/or `max`."""

    min: Optional[float] = None
    """minimum value for clipping (optional if `max` is given)"""
    max: Optional[float] = None
    """maximum value for clipping (optional if `min` is given)"""

    def __post_init__(self):
        lo, hi = self.min, self.max
        assert lo is not None or hi is not None, "missing min or max value"
        assert (
            lo is None or hi is None or lo < hi
        ), f"expected min < max, but {lo} !< {hi}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is element-wise: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        kwargs = descr.kwargs
        return cls(input=member_id, output=member_id, min=kwargs.min, max=kwargs.max)

261 

262 

@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        ensure_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=ensure_kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # a dtype cast does not change the shape
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)

281 

282 

@dataclass
class ScaleLinear(_SimpleOperator):
    """'output = input * gain + offset'."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise affine transform: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place `ScaleLinear` operator from a processing description.

        Raises:
            NotImplementedError: for v0.4 kwargs that specify `axes`.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if not axis:
            # scalar case: a single value, possibly wrapped in a 1-element list
            raw_gain, raw_offset = kwargs.gain, kwargs.offset
            assert isinstance(raw_gain, (float, int)) or len(raw_gain) == 1, raw_gain
            assert isinstance(raw_offset, (float, int)) or len(raw_offset) == 1
            gain = raw_gain if isinstance(raw_gain, (float, int)) else raw_gain[0]
            offset = (
                raw_offset if isinstance(raw_offset, (float, int)) else raw_offset[0]
            )
        else:
            # per-axis case: label values with `axis` for broadcasting by name
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)

337 

338 

@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match the mean and standard deviation of `input` to those of a reference.

    output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = tuple(self.axes) if self.axes is not None else None
        # without an explicit reference tensor, the input is its own reference
        ref_tensor = self.reference_tensor or self.input

        # dataset statistics only when the batch axis is among the reduced axes
        per_dataset = axes is not None and AxisId("batch") in axes
        if per_dataset:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        in_mean = stat[self.mean]
        in_std = stat[self.std] + self.eps
        target_mean = stat[self.ref_mean]
        target_std = stat[self.ref_std] + self.eps
        return (input - in_mean) / in_std * target_std + target_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise rescaling: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        # the dataset/sample decision is made in `__post_init__` from `axes`
        _, axes = _get_axes(kwargs)
        reference = MemberId(str(kwargs.reference_tensor))
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=reference,
            axes=axes,
            eps=kwargs.eps,
        )

396 

397 

def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes from normalization kwargs.

    Returns:
        `(dataset_mode, axes)`: `dataset_mode` is True if no axes are given
        (reduce over all axes) or if the batch axis is among the reduction
        axes; `axes` are the reduction axes, or None for all axes.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0.4 axes string: `_convert_axis_ids` prepends `v0_5.BATCH_AXIS_ID`
        # for mode "per_dataset", so check for exactly that id (an axes
        # string itself never yields a batch axis).
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)

418 

419 

@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale the input using a lower and an upper percentile of a reference.

    output = (input - lower) / (upper - lower + eps)
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        # a missing bound defaults to a DatasetPercentile with q=0.0 (lower) or
        # q=1.0 (upper) of the same member as the other bound (or the input)
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise rescaling: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place `ScaleRange` operator from a processing description.

        Percentiles in `descr` are given in [0, 100] and converted to
        quantiles `q` in [0, 1]. The description's `eps` is forwarded.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # previously dropped, silently replacing the model's eps by 1e-6
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )

505 

506 

@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # no statistics needed; return an empty *set* like the other operators
        # (`{}` is an empty dict, not a set)
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise function: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()

532 

533 

@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    output = (input - mean) / (std + eps), where `mean` and `std` are looked up
    in the sample's statistics.
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise normalization: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place operator; dataset vs. sample statistics follow
        from the description's axes (see `_get_axes`). `eps` is forwarded."""
        kwargs = descr.kwargs
        dataset_mode, axes = _get_axes(kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # previously dropped, silently replacing the model's eps by 1e-6
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

586 

587 

@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    output = (input - mean) / (std + eps)
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise normalization: shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Describe this operator as a v0.5 processing description."""
        mean = self.mean
        std = self.std
        # `from_proc_descr` wraps scalar kwargs in 0-d DataArrays; treat those
        # as plain scalars so the description round-trips instead of failing
        # the single-axis assertion below.
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # `.values.tolist()` yields plain Python floats; iterating the
                # DataArray itself would yield 0-d DataArrays
                mean=mean.values.tolist(),
                std=std.values.tolist(),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)

643 

644 def _apply(self, input: Tensor, stat: Stat) -> Tensor: 

645 return (input - self.mean) / (self.std + self.eps) 

646 

647 

# Any pre- or postprocessing description (spec v0.4 or v0.5) accepted by
# `get_proc`.
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

# The concrete operator types defined in this module.
Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]

668 

669 

def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Build the operator implementing `proc_descr` for the tensor `tensor_descr`.

    The returned operator reads from and writes to the member identified by
    `tensor_descr`.

    Raises:
        TypeError: if a v0.4 fixed-mode zero-mean-unit-variance description is
            paired with a tensor description that is not v0.4.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)

    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)

    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    # v0.4 "fixed" mode is converted to a v0.5 FixedZeroMeanUnitVarianceDescr;
    # this check must precede the generic zero-mean-unit-variance branch below
    if (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)