bioimageio.core.proc_ops
import collections.abc
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from typing import (
    Collection,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import numpy as np

# `scipy.special` is imported explicitly (instead of only `import scipy`) because
# `Softmax` uses `scipy.special.softmax`; depending on the SciPy version,
# submodules are not guaranteed to be loaded by `import scipy` alone.
import scipy.special  # pyright: ignore[reportMissingTypeStubs]
import xarray as xr
from typing_extensions import Self, assert_never

from bioimageio.core.digest_spec import get_member_id
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import (
    _convert_proc,  # pyright: ignore [reportPrivateUsage]
)

from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
from .block import Block
from .common import DTypeStr, MemberId
from .sample import Sample, SampleBlock, SampleBlockWithOrigin
from .stat_calculators import StatsCalculator
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetPercentile,
    DatasetStd,
    MeanMeasure,
    Measure,
    MeasureValue,
    SampleMean,
    SampleQuantile,
    SampleStd,
    Stat,
    StdMeasure,
)
from .tensor import Tensor


def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert model.v0_4 axes to a tuple of `AxisId`s.

    A v0_4 axes string is split into one `AxisId` per character. In
    "per_dataset" mode `v0_5.BATCH_AXIS_ID` is prepended.
    Non-string `axes` are returned unchanged as a tuple.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        prefix: Tuple[AxisId, ...] = ()
    elif mode == "per_dataset":
        prefix = (v0_5.BATCH_AXIS_ID,)
    else:
        assert_never(mode)

    return prefix + tuple(AxisId(a) for a in axes)


@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators that compute member `output` from member `input`.

    Subclasses implement `_apply` (the tensor transformation) and
    `get_output_shape` (shape of the result given the input shape).
    """

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # no statistics needed by default; subclasses override as required
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply the operator to `sample` in place.

        Does nothing if `sample` has no member `input`. If member `output`
        already exists, its tagged shape must match the newly computed tensor.
        For a `SampleBlock` the result is stored as a new `Block` that reuses
        the inner slice, halo, block index and block count of the input block.
        """
        if self.input not in sample.members:
            return

        input_tensor = sample.members[self.input]
        output_tensor = self._apply(input_tensor, sample.stat)

        if self.output in sample.members:
            assert (
                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
            )

        if isinstance(sample, Sample):
            sample.members[self.output] = output_tensor
        elif isinstance(sample, SampleBlock):
            b = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=output_tensor,
                inner_slice=b.inner_slice,
                halo=b.halo,
                block_index=b.block_index,
                blocks_in_sample=b.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor: ...


@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Add the given precomputed `dataset_stats` to the `stat` of every sample."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        sample.stat.update(self.dataset_stats.items())


@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """Compute measures and add them to `sample.stat`.

        For a `SampleBlockWithOrigin` only the first block (block index 0) is
        processed, and the statistics are computed from its origin sample.
        """
        if isinstance(sample, SampleBlockWithOrigin):
            # update stats with whole sample on first block
            if sample.block_index != 0:
                return

            origin = sample.origin
        else:
            origin = sample

        if self._keep_updating_dataset_stats:
            sample.stat.update(self.stats_calculator.update_and_get_all(origin))
        else:
            sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin))


@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is set, `threshold` holds one value per entry along that axis.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        threshold: Union[float, Sequence[float], xr.DataArray] = self.threshold
        if self.axis is not None and not isinstance(threshold, (int, float)):
            # Label the per-entry thresholds with their axis so they are matched
            # by axis name. A plain list would be broadcast positionally
            # against the last dimension instead.
            threshold = xr.DataArray(list(threshold), dims=(self.axis,))

        return x > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Build an in-place (`input == output == member_id`) operator from `descr`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)


@dataclass
class Clip(_SimpleOperator):
    """Limit values to [`min`, `max`] via `Tensor.clip`; either bound may be omitted."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        assert self.min is not None or self.max is not None, "missing min or max value"
        assert self.min is None or self.max is None or self.min < self.max, (
            f"expected min < max, but {self.min} !< {self.max}"
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Build an in-place (`input == output == member_id`) operator from `descr`."""
        return cls(
            input=member_id,
            output=member_id,
            min=descr.kwargs.min,
            max=descr.kwargs.max,
        )


@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to `dtype` via `Tensor.astype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)


@dataclass
class ScaleLinear(_SimpleOperator):
    """'output = tensor * gain + offset'."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Build an in-place operator from `descr`.

        With an axis, gain/offset become 1-d DataArrays along that axis;
        otherwise they must be scalars or single-element sequences.

        Raises:
            NotImplementedError: for model.v0_4 kwargs that specify `axes`.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis is not None:
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
                kwargs.gain
            )
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)


@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match mean and std of `input` to those of `reference_tensor`.

    output = (x - mean) / (std + eps) * (ref_std + eps) + ref_mean.
    Dataset statistics are used if `axes` contains the batch axis; otherwise
    (including `axes=None`) sample statistics are used.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (x - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )


def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Return `(dataset_mode, axes)` for normalization kwargs.

    `dataset_mode` is True if `axes` is None or contains the batch axis.
    v0_4 axes strings are converted with `_convert_axis_ids`, which prepends
    the batch axis in "per_dataset" mode.
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        # `_convert_axis_ids` marks "per_dataset" by prepending
        # `v0_5.BATCH_AXIS_ID`; check for exactly that id.
        return v0_5.BATCH_AXIS_ID in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)


@dataclass
class ScaleRange(_SimpleOperator):
    """'output = (tensor - lower) / (upper - lower + eps)' with percentile bounds.

    Missing percentiles default to dataset percentiles q=0.0 (lower) and
    q=1.0 (upper).
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Build an in-place operator from `descr` (percentiles given in 0..100)."""
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # forward eps so that a non-default value survives a
            # from_proc_descr/get_descr round trip
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )


@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty *set* (`{}` would be an empty dict)
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()


@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function."""

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        axis_idx = x.dims.index(self.axis)
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty *set* (`{}` would be an empty dict)
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))


@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance: (x - mean) / (std + eps)."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # forward eps so that a non-default value survives a
            # from_proc_descr/get_descr round trip
            eps=descr.kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )


@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    `mean`/`std` are numbers, 0-d DataArrays, or 1-d DataArrays along one axis.
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Build an in-place operator from `descr`.

        Scalar kwargs yield 0-d DataArrays; along-axis kwargs yield 1-d
        DataArrays labeled with that axis.
        """
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Describe this operator as a v0_5 `FixedZeroMeanUnitVarianceDescr`.

        1-d DataArrays map to the along-axis kwargs. Plain numbers and 0-d
        DataArrays (as created by `from_proc_descr` for scalar kwargs) map to
        the scalar kwargs.
        """
        mean, std = self.mean, self.std
        if isinstance(mean, xr.DataArray) and mean.ndim == 1:
            assert isinstance(std, xr.DataArray)
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats, not 0-d DataArray items
                mean=[float(v) for v in mean.values],
                std=[float(v) for v in std.values],
            )
        else:
            assert not isinstance(mean, xr.DataArray) or mean.ndim == 0
            assert not isinstance(std, xr.DataArray) or std.ndim == 0
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(
                mean=float(mean), std=float(std)
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)


ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    Softmax,
    UpdateStats,
    ZeroMeanUnitVariance,
]


def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the operator for `proc_descr`, acting in place on `tensor_descr`'s member.

    A model.v0_4 zero-mean-unit-variance description in "fixed" mode is first
    converted to a v0_5 `FixedZeroMeanUnitVarianceDescr`, using the tensor's axes.

    Raises:
        TypeError: if a v0_4 "fixed" zero-mean-unit-variance description is
            paired with a non-v0_4 tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)
    else:
        assert_never(proc_descr)
@dataclass
class
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Operator that injects fixed, precomputed dataset statistics.

    Nothing is computed here: the entries of `dataset_stats` are merged into
    the `stat` mapping of every sample (or sample block) passed in.
    """

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # this operator supplies statistics; it does not consume any
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known = self.dataset_stats.items()
        sample.stat.update(known)
AddKnownDatasetStats( dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]])
dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
@dataclass
class
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        # dataset statistics keep being updated either on request or when the
        # calculator was not seeded with any dataset measures
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """Add computed measures to `sample.stat`.

        Blocks other than the first one (block index 0) are skipped; for the
        first block the measures are computed from its origin sample.
        """
        if not isinstance(sample, SampleBlockWithOrigin):
            source = sample
        elif sample.block_index == 0:
            source = sample.origin
        else:
            return

        calculator = self.stats_calculator
        compute = (
            calculator.update_and_get_all
            if self._keep_updating_dataset_stats
            else calculator.skip_update_and_get_all
        )
        sample.stat.update(compute(source))
Calculates sample and/or dataset measures
UpdateStats( stats_calculator: bioimageio.core.stat_calculators.StatsCalculator, keep_updating_initial_dataset_stats: bool = False)
stats_calculator: bioimageio.core.stat_calculators.StatsCalculator
StatsCalculator
to be used by this operator.
keep_updating_initial_dataset_stats: bool =
False
indicates if operator calls should keep updating initial dataset statistics or not;
if the stats_calculator
was not provided with any initial dataset statistics,
these are always updated with every new sample.
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is set, `threshold` holds one value per entry along that axis.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        import xarray as xr

        threshold: Union[float, Sequence[float], xr.DataArray] = self.threshold
        if self.axis is not None and not isinstance(threshold, (int, float)):
            # Label the per-entry thresholds with their axis so they are matched
            # by axis name. A plain list would be broadcast positionally
            # against the last dimension instead.
            threshold = xr.DataArray(list(threshold), dims=(self.axis,))

        return x > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Build an in-place (`input == output == member_id`) operator from `descr`."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
'output = tensor > threshold'.
Binarize( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, threshold: Union[float, Sequence[float]], axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_5.BinarizeDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
) -> Self:
    """Create an in-place binarization (input and output are `member_id`).

    Scalar-threshold kwargs set only `threshold`; along-axis kwargs also set
    `axis`.
    """
    kwargs = descr.kwargs
    if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
        return cls(input=member_id, output=member_id, threshold=kwargs.threshold)

    if isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
        return cls(
            input=member_id,
            output=member_id,
            threshold=kwargs.threshold,
            axis=kwargs.axis,
        )

    assert_never(kwargs)
Inherited Members
@dataclass
class Clip(_SimpleOperator):
    """Limit tensor values to the range given by `min` and/or `max`.

    At least one bound must be given; if both are given, `min < max` is required.
    """

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        lo, hi = self.min, self.max
        assert not (lo is None and hi is None), "missing min or max value"
        assert lo is None or hi is None or lo < hi, (
            f"expected min < max, but {self.min} !< {self.max}"
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lo, hi = self.min, self.max
        return x.clip(lo, hi)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping keeps the shape
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place clip operator (input and output are `member_id`)."""
        kwargs = descr.kwargs
        return cls(input=member_id, output=member_id, min=kwargs.min, max=kwargs.max)
Clip( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, min: Optional[float] = None, max: Optional[float] = None)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_5.ClipDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to `dtype` (via `Tensor.astype`)."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create an in-place cast (input and output are `member_id`)."""
        target = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target)

    def get_descr(self):
        """Describe this operator as a v0_5 `EnsureDtypeDescr`."""
        kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # casting keeps the shape
        return input_shape

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return x.astype(self.dtype)
EnsureDtype( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.EnsureDtypeDescr, member_id: bioimageio.spec.model.v0_5.TensorId):
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
Inherited Members
@dataclass
class ScaleLinear(_SimpleOperator):
    """Affine intensity transform ``x * gain + offset``.

    `gain` and `offset` are scalars, or (as built by `from_proc_descr` for
    per-axis kwargs) 1-D `xr.DataArray`s whose dimension names the axis along
    which the values vary.
    """

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        scaled = x * self.gain
        return scaled + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise transform: shape is preserved
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Build a `ScaleLinear` operating in place on tensor `member_id`.

        Raises:
            NotImplementedError: for v0.4 kwargs that specify `axes`.
        """
        kwargs = descr.kwargs
        axis = None
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            pass  # scalar gain/offset, no axis
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
        else:
            assert_never(kwargs)

        if axis:
            # per-axis values become 1-D arrays labelled with that axis
            return cls(
                input=member_id,
                output=member_id,
                gain=xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis),
                offset=xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis),
            )

        def _single(value):
            # unwrap a one-element sequence into its only element
            return value if isinstance(value, (float, int)) else value[0]

        assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
            kwargs.gain
        )
        gain_value = _single(kwargs.gain)
        assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
        offset_value = _single(kwargs.offset)

        return cls(
            input=member_id, output=member_id, gain=gain_value, offset=offset_value
        )
ScaleLinear( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, gain: Union[float, xarray.core.dataarray.DataArray] = 1.0, offset: Union[float, xarray.core.dataarray.DataArray] = 0.0)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
    member_id: MemberId,
) -> Self:
    """Build a `ScaleLinear` operating in place on tensor `member_id`.

    Raises:
        NotImplementedError: for v0.4 kwargs that specify `axes`.
    """
    kwargs = descr.kwargs
    axis = None
    if isinstance(kwargs, v0_5.ScaleLinearKwargs):
        pass  # scalar gain/offset, no axis
    elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
        axis = kwargs.axis
    elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
        if kwargs.axes is not None:
            raise NotImplementedError(
                "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
            )
    else:
        assert_never(kwargs)

    if axis:
        # per-axis values become 1-D arrays labelled with that axis
        return cls(
            input=member_id,
            output=member_id,
            gain=xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis),
            offset=xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis),
        )

    def _single(value):
        # unwrap a one-element sequence into its only element
        return value if isinstance(value, (float, int)) else value[0]

    assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, (
        kwargs.gain
    )
    gain_value = _single(kwargs.gain)
    assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
    offset_value = _single(kwargs.offset)

    return cls(input=member_id, output=member_id, gain=gain_value, offset=offset_value)
Inherited Members
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Give the input the mean/std of a reference tensor.

    Computes ``(x - mean) / (std + eps) * (ref_std + eps) + ref_mean``.
    `reference_tensor` defaults to the input itself. Dataset-wide measures are
    used when `axes` contains the ``"batch"`` axis, per-sample measures otherwise.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        per_dataset = axes is not None and AxisId("batch") in axes
        if per_dataset:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        src_mean = stat[self.mean]
        src_std = stat[self.std] + self.eps
        dst_mean = stat[self.ref_mean]
        dst_std = stat[self.ref_std] + self.eps
        standardized = (x - src_mean) / src_std
        return standardized * dst_std + dst_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise transform: shape is preserved
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Build a `ScaleMeanVariance` operating in place on tensor `member_id`."""
        kwargs = descr.kwargs
        _, axes = _get_axes(kwargs)
        reference = MemberId(str(kwargs.reference_tensor))
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=reference,
            axes=axes,
            eps=kwargs.eps,
        )
ScaleMeanVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None, reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None, eps: float = 1e-06)
ref_mean: Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
    member_id: MemberId,
) -> Self:
    """Build a `ScaleMeanVariance` operating in place on tensor `member_id`."""
    kwargs = descr.kwargs
    _, axes = _get_axes(kwargs)
    reference = MemberId(str(kwargs.reference_tensor))
    return cls(
        input=member_id,
        output=member_id,
        reference_tensor=reference,
        axes=axes,
        eps=kwargs.eps,
    )
Inherited Members
@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale values between a lower and an upper percentile.

    Computes ``(x - lower) / (upper - lower + eps)``, where `lower` and `upper`
    are quantile/percentile measures looked up in the sample statistics. A
    missing `lower_percentile` defaults to a q=0.0 `DatasetPercentile`, a missing
    `upper_percentile` to a q=1.0 `DatasetPercentile`.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            # default lower bound refers to the same tensor as the upper bound (if given)
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        # both bounds must describe the same tensor and axes, with lower < upper
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise transform: shape is preserved
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Build a `ScaleRange` operating in place on tensor `member_id`.

        The description's percentiles (0-100 scale) are divided by 100 to get
        quantiles, and its `eps` is carried over so that a non-default value is
        honored (and round-trips through `get_descr`).
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (x - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Describe this operator as a v0.5 `ScaleRangeDescr`."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
ScaleRange( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, eps: float = 1e-06)
lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] =
None
upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] =
None
lower: Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile]
upper: Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
    member_id: MemberId,
):
    """Build a `ScaleRange` operating in place on tensor `member_id`.

    The description's percentiles (0-100 scale) are divided by 100 to get
    quantiles, and its `eps` is carried over so that a non-default value is
    honored (and round-trips through `get_descr`).
    """
    kwargs = descr.kwargs
    ref_tensor = (
        member_id
        if kwargs.reference_tensor is None
        else MemberId(str(kwargs.reference_tensor))
    )
    dataset_mode, axes = _get_axes(descr.kwargs)
    if dataset_mode:
        Percentile = DatasetPercentile
    else:
        Percentile = SampleQuantile

    return cls(
        input=member_id,
        output=member_id,
        lower_percentile=Percentile(
            q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
        ),
        upper_percentile=Percentile(
            q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
        ),
        eps=kwargs.eps,
    )
def
get_descr(self):
def get_descr(self):
    """Describe this operator as a v0.5 `ScaleRangeDescr`.

    Quantiles are multiplied by 100 to get the description's percentiles.
    """
    lo, hi = self.lower, self.upper
    assert lo.axes == hi.axes
    assert lo.member_id == hi.member_id

    range_kwargs = v0_5.ScaleRangeKwargs(
        axes=lo.axes,
        min_percentile=lo.q * 100,
        max_percentile=hi.q * 100,
        eps=self.eps,
        reference_tensor=lo.member_id,
    )
    return v0_5.ScaleRangeDescr(kwargs=range_kwargs)
Inherited Members
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-x)), dims=x.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # Element-wise op: no statistics needed. Return an empty *set* (`{}`
        # would be an empty dict), matching `_SimpleOperator.required_measures`.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise transform: shape is preserved
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Build a `Sigmoid` operating in place on tensor `member_id`."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Describe this operator as a v0.5 `SigmoidDescr`."""
        return v0_5.SigmoidDescr()
1 / (1 + e^(-input)).
Sigmoid( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId)
required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_5.SigmoidDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class Softmax(_SimpleOperator):
    """Softmax activation function.

    Normalizes the tensor along `axis` (default ``"channel"``) with
    `scipy.special.softmax`.
    """

    axis: AxisId = AxisId("channel")

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        # Import the submodule explicitly: a bare `import scipy` may not expose
        # `scipy.special` on older SciPy versions without lazy submodule loading.
        import scipy.special  # pyright: ignore[reportMissingTypeStubs]

        axis_idx = x.dims.index(self.axis)
        result = scipy.special.softmax(x.data, axis=axis_idx)
        result_xr = xr.DataArray(result, dims=x.dims)
        return Tensor.from_xarray(result_xr)

    @property
    def required_measures(self) -> Collection[Measure]:
        # No statistics needed. Return an empty *set* (`{}` would be an empty
        # dict), matching `_SimpleOperator.required_measures`.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # softmax keeps the shape of its input
        return input_shape

    @classmethod
    def from_proc_descr(cls, descr: v0_5.SoftmaxDescr, member_id: MemberId) -> Self:
        """Build a `Softmax` operating in place on tensor `member_id`."""
        assert isinstance(descr, v0_5.SoftmaxDescr)
        return cls(input=member_id, output=member_id, axis=descr.kwargs.axis)

    def get_descr(self):
        """Describe this operator as a v0.5 `SoftmaxDescr`."""
        return v0_5.SoftmaxDescr(kwargs=v0_5.SoftmaxKwargs(axis=self.axis))
Softmax activation function.
Softmax( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axis: bioimageio.spec.model.v0_5.AxisId = 'channel')
required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.SoftmaxDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance.

    Computes ``(x - mean) / (std + eps)`` with `mean` and `std` looked up in the
    sample statistics (per-sample or dataset-wide measures).
    """

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # both measures must be computed over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise transform: shape is preserved
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Build a `ZeroMeanUnitVariance` operating in place on tensor `member_id`.

        Dataset-wide measures are used in dataset mode, per-sample measures
        otherwise. The description's `eps` is carried over so that a non-default
        value is honored (and round-trips through `get_descr`).
        """
        kwargs = descr.kwargs
        dataset_mode, axes = _get_axes(kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            eps=kwargs.eps,
        )

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (x - mean) / (std + self.eps)

    def get_descr(self):
        """Describe this operator as a v0.5 `ZeroMeanUnitVarianceDescr`."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
normalize to zero mean, unit variance.
ZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean], std: Union[bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.DatasetStd], eps: float = 1e-06)
required_measures: Set[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.DatasetStd]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
    member_id: MemberId,
):
    """Build a `ZeroMeanUnitVariance` operating in place on tensor `member_id`.

    Dataset-wide measures are used in dataset mode, per-sample measures
    otherwise. The description's `eps` is carried over so that a non-default
    value is honored (and round-trips through `get_descr`).
    """
    kwargs = descr.kwargs
    dataset_mode, axes = _get_axes(kwargs)

    if dataset_mode:
        Mean = DatasetMean
        Std = DatasetStd
    else:
        Mean = SampleMean
        Std = SampleStd

    return cls(
        input=member_id,
        output=member_id,
        mean=Mean(axes=axes, member_id=member_id),
        std=Std(axes=axes, member_id=member_id),
        eps=kwargs.eps,
    )
Inherited Members
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values.

    Computes ``(x - mean) / (std + eps)``. `mean`/`std` are scalars or
    `xr.DataArray`s (0-d for scalar kwargs, 1-D along one axis otherwise, as
    built by `from_proc_descr`).
    """

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise transform: shape is preserved
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Build a `FixedZeroMeanUnitVariance` operating in place on `member_id`."""
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Describe this operator as a v0.5 `FixedZeroMeanUnitVarianceDescr`.

        Scalar values (plain numbers or 0-d arrays) yield
        `FixedZeroMeanUnitVarianceKwargs`; 1-D arrays yield
        `FixedZeroMeanUnitVarianceAlongAxisKwargs` for their single dimension.
        """
        mean, std = self.mean, self.std
        # `from_proc_descr` stores scalar kwargs as 0-d DataArrays; unwrap them so
        # they take the scalar branch instead of failing the 1-D assertion below.
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain Python floats rather than 0-d DataArray elements
                mean=[float(v) for v in mean.values],
                std=[float(v) for v in std.values],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        return (x - self.mean) / (self.std + self.eps)
normalize to zero mean, unit variance with precomputed values.
FixedZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[float, xarray.core.dataarray.DataArray], std: Union[float, xarray.core.dataarray.DataArray], eps: float = 1e-06)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: v0_5.FixedZeroMeanUnitVarianceDescr,
    member_id: MemberId,
) -> Self:
    """Build a `FixedZeroMeanUnitVariance` operating in place on `member_id`.

    Scalar kwargs produce 0-d arrays; per-axis kwargs produce 1-D arrays
    labelled with that axis.
    """
    fixed_kwargs = descr.kwargs
    if isinstance(fixed_kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
        value_dims = None
    elif isinstance(fixed_kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
        value_dims = (AxisId(fixed_kwargs.axis),)
    else:
        assert_never(fixed_kwargs)

    mean_arr = xr.DataArray(fixed_kwargs.mean, dims=value_dims)
    std_arr = xr.DataArray(fixed_kwargs.std, dims=value_dims)
    return cls(input=member_id, output=member_id, mean=mean_arr, std=std_arr)
def
get_descr(self):
def get_descr(self):
    """Describe this operator as a v0.5 `FixedZeroMeanUnitVarianceDescr`.

    Scalar values (plain numbers or 0-d arrays) yield
    `FixedZeroMeanUnitVarianceKwargs`; 1-D arrays yield
    `FixedZeroMeanUnitVarianceAlongAxisKwargs` for their single dimension.
    """
    mean, std = self.mean, self.std
    # `from_proc_descr` stores scalar kwargs as 0-d DataArrays; unwrap them so
    # they take the scalar branch instead of failing the 1-D assertion below.
    if isinstance(mean, xr.DataArray) and mean.ndim == 0:
        mean = float(mean)
    if isinstance(std, xr.DataArray) and std.ndim == 0:
        std = float(std)

    if isinstance(mean, (int, float)):
        assert isinstance(std, (int, float))
        kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
    else:
        assert isinstance(std, xr.DataArray)
        assert len(mean.dims) == 1
        kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
            axis=AxisId(str(mean.dims[0])),
            # plain Python floats rather than 0-d DataArray elements
            mean=[float(v) for v in mean.values],
            std=[float(v) for v in std.values],
        )

    return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
Inherited Members
ProcDescr =
typing.Union[typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, 
custom_error_message=None, custom_error_context=None)]]
Processing =
typing.Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, Softmax, UpdateStats, ZeroMeanUnitVariance]
def
get_proc( proc_descr: Union[Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.SoftmaxDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, 
custom_error_context=None)]], tensor_descr: Union[bioimageio.spec.model.v0_4.InputTensorDescr, bioimageio.spec.model.v0_4.OutputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr, bioimageio.spec.model.v0_5.OutputTensorDescr]) -> Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, Softmax, UpdateStats, ZeroMeanUnitVariance]:
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the operator that implements `proc_descr` for `tensor_descr`.

    The operator reads from and writes to the tensor identified by
    `get_member_id(tensor_descr)`. A v0.4 zero-mean-unit-variance description in
    ``"fixed"`` mode is first converted to its v0.5 fixed counterpart.

    Raises:
        TypeError: if a fixed-mode v0.4 description is paired with a non-v0.4
            tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)

    is_fixed_v0_4_zmuv = (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    )
    if is_fixed_v0_4_zmuv:
        # the v0.4 -> v0.5 conversion needs the v0.4 tensor axes
        v0_4_tensor_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        if not isinstance(tensor_descr, v0_4_tensor_types):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        converted = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(converted, member_id)

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    if isinstance(proc_descr, v0_5.SoftmaxDescr):
        return Softmax.from_proc_descr(proc_descr, member_id)

    assert_never(proc_descr)