bioimageio.core.proc_ops
"""Pre- and postprocessing operators created from bioimageio processing descriptions."""

import collections.abc
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from typing import (
    Collection,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import numpy as np
import xarray as xr
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_4, v0_5

from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
from .block import Block
from .common import DTypeStr, MemberId
from .sample import Sample, SampleBlock, SampleBlockWithOrigin
from .stat_calculators import StatsCalculator
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetPercentile,
    DatasetStd,
    MeanMeasure,
    Measure,
    MeasureValue,
    SampleMean,
    SampleQuantile,
    SampleStd,
    Stat,
    StdMeasure,
)
from .tensor import Tensor


def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Convert a v0.4 axes string into a tuple of `AxisId`s.

    Non-string input is returned unchanged as a tuple. For a string, every
    character becomes one `AxisId`; in "per_dataset" mode `AxisId("b")` is
    prepended.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        ret = []
    elif mode == "per_dataset":
        ret = [AxisId("b")]
    else:
        assert_never(mode)

    ret.extend([AxisId(a) for a in axes])
    return tuple(ret)


@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators mapping one `input` member to one `output` member."""

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        """Measures that must be present in `sample.stat`; none by default."""
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply the operator to `sample` in place.

        Does nothing if `sample` has no member named `self.input`.
        """
        if self.input not in sample.members:
            return

        input_tensor = sample.members[self.input]
        output_tensor = self._apply(input_tensor, sample.stat)

        # an already existing output member must keep its tagged shape
        if self.output in sample.members:
            assert (
                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
            )

        if isinstance(sample, Sample):
            sample.members[self.output] = output_tensor
        elif isinstance(sample, SampleBlock):
            # the output block reuses the block geometry of the input block
            b = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=output_tensor,
                inner_slice=b.inner_slice,
                halo=b.halo,
                block_index=b.block_index,
                blocks_in_sample=b.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...


@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Adds the given `dataset_stats` to the `stat` of every processed sample."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        sample.stat.update(self.dataset_stats.items())


@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """Update `sample.stat` with the measures of the stats calculator.

        For a `SampleBlockWithOrigin` only the block with index 0 triggers a
        calculation, which is done on the block's `origin`.
        """
        if isinstance(sample, SampleBlockWithOrigin):
            # update stats with whole sample on first block
            if sample.block_index != 0:
                return

            origin = sample.origin
        else:
            origin = sample

        if self._keep_updating_dataset_stats:
            sample.stat.update(self.stats_calculator.update_and_get_all(origin))
        else:
            sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin))


@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        # NOTE(review): `self.axis` is not used here; a sequence `threshold`
        # is compared as is -- confirm it is applied along the intended axis.
        return input > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place `Binarize` operator (input == output == `member_id`)."""
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)


@dataclass
class Clip(_SimpleOperator):
    """Limit tensor values to the range given by `min` and/or `max`."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        assert self.min is not None or self.max is not None, "missing min or max value"
        assert (
            self.min is None or self.max is None or self.min < self.max
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place `Clip` operator (input == output == `member_id`)."""
        return cls(
            input=member_id,
            output=member_id,
            min=descr.kwargs.min,
            max=descr.kwargs.max,
        )


@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create an in-place `EnsureDtype` operator (input == output == `member_id`)."""
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)


@dataclass
class ScaleLinear(_SimpleOperator):
    """'output = input * gain + offset'."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place `ScaleLinear` operator (input == output == `member_id`).

        With an axis, gain and offset become 1-d `xr.DataArray`s along that
        axis; without one they must be scalars or sequences of length one.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            assert (
                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
            ), kwargs.gain
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)


@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """'output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean'."""

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        # without a reference tensor the input member is its own reference
        ref_tensor = self.reference_tensor or self.input
        # NOTE(review): only AxisId("batch") selects dataset measures here, while
        # `_convert_axis_ids` marks the batch axis as AxisId("b") -- confirm these
        # two ids are equivalent for v0.4 "per_dataset" descriptions.
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (input - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place `ScaleMeanVariance` operator (input == output == `member_id`)."""
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )


def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract `(dataset_mode, axes)` from processing kwargs.

    `dataset_mode` is True if `kwargs.axes` is None or if the axes include the
    batch axis (`AxisId("b")` for string axes, `AxisId("batch")` for sequences).
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return AxisId("b") in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)


@dataclass
class ScaleRange(_SimpleOperator):
    """'output = (input - lower) / (upper - lower + eps)'."""

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        # a missing bound defaults to the dataset-wide q=0.0 / q=1.0 measure of
        # the same member as the other bound (or of `input` if both are missing)
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create an in-place `ScaleRange` operator (input == output == `member_id`).

        The description's percentiles (0-100) are converted to quantiles (0-1).
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # forward the described eps instead of silently using the default
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )


@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty set (not an empty dict), consistent with the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place `Sigmoid` operator (input == output == `member_id`)."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()


@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create an in-place `ZeroMeanUnitVariance` operator (input == output == `member_id`)."""
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # forward the described eps instead of silently using the default
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )


@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create an in-place `FixedZeroMeanUnitVariance` operator.

        Scalar kwargs are kept as plain numbers and per-axis kwargs become 1-d
        `xr.DataArray`s, matching the two cases `get_descr` distinguishes.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            # a 0-d DataArray here would fail the 1-d assertion in `get_descr`
            return cls(
                input=member_id,
                output=member_id,
                mean=kwargs.mean,
                std=kwargs.std,
            )
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (kwargs.axis,)
            return cls(
                input=member_id,
                output=member_id,
                mean=xr.DataArray(kwargs.mean, dims=dims),
                std=xr.DataArray(kwargs.std, dims=dims),
            )
        else:
            assert_never(kwargs)

    def get_descr(self):
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                # plain floats rather than 0-d DataArray elements
                mean=[float(m) for m in self.mean],
                std=[float(s) for s in self.std],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)


ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]


def get_proc_class(proc_spec: ProcDescr):
    """Return the operator class that implements the given processing description."""
    if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize
    elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip
    elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
        return EnsureDtype
    elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance
    elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear
    elif isinstance(
        proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance
    elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange
    elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid
    elif (
        isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_spec.kwargs.mode == "fixed"
    ):
        # NOTE(review): `FixedZeroMeanUnitVariance.from_proc_descr` only handles
        # v0_5 kwargs -- confirm v0_4 "fixed" descriptions are converted first.
        return FixedZeroMeanUnitVariance
    elif isinstance(
        proc_spec,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance
    else:
        assert_never(proc_spec)
@dataclass
class
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Operator that writes a fixed mapping of dataset measures into `sample.stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        """This operator needs no measures to be present beforehand."""
        nothing_required: Set[Measure] = set()
        return nothing_required

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Update `sample.stat` in place with all `dataset_stats` entries."""
        known_entries = self.dataset_stats.items()
        sample.stat.update(known_entries)
AddKnownDatasetStats( dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]])
dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
@dataclass
class
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        """This operator needs no measures to be present beforehand."""
        return set()

    def __post_init__(self):
        # updating is on if explicitly requested; otherwise it is on exactly
        # when the calculator reports no dataset measures
        if self.keep_updating_initial_dataset_stats:
            self._keep_updating_dataset_stats = (
                self.keep_updating_initial_dataset_stats
            )
        else:
            self._keep_updating_dataset_stats = (
                not self.stats_calculator.has_dataset_measures
            )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """Write the calculator's measures into `sample.stat`.

        Blocks with a non-zero `block_index` are skipped; for block 0 the
        measures are computed on the block's `origin`.
        """
        if not isinstance(sample, SampleBlockWithOrigin):
            origin = sample
        elif sample.block_index != 0:
            return
        else:
            # update stats with whole sample on first block
            origin = sample.origin

        calculator = self.stats_calculator
        if self._keep_updating_dataset_stats:
            sample.stat.update(calculator.update_and_get_all(origin))
        else:
            sample.stat.update(calculator.skip_update_and_get_all(origin))
Calculates sample and/or dataset measures
UpdateStats( stats_calculator: bioimageio.core.stat_calculators.StatsCalculator, keep_updating_initial_dataset_stats: bool = False)
stats_calculator: bioimageio.core.stat_calculators.StatsCalculator
StatsCalculator
to be used by this operator.
keep_updating_initial_dataset_stats: bool =
False
indicates if operator calls should keep updating initial dataset statistics or not;
if the stats_calculator
was not provided with any initial dataset statistics,
these are always updated with every new sample.
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Compare `input` element-wise against `threshold`; `stat` is unused."""
        # NOTE(review): `axis` is stored but never read here -- confirm a
        # sequence `threshold` is compared along the intended axis.
        binarized = input > self.threshold
        return binarized

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The shape is not changed by this operator."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        kwargs = descr.kwargs
        if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            axis = None
        elif isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
            axis = kwargs.axis
        else:
            assert_never(kwargs)

        return cls(
            input=member_id,
            output=member_id,
            threshold=kwargs.threshold,
            axis=axis,
        )
'output = tensor > threshold'.
Binarize( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, threshold: Union[float, Sequence[float]], axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_5.BinarizeDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
) -> Self:
    """Create an in-place operator (input == output == `member_id`) from `descr`."""
    kwargs = descr.kwargs
    if isinstance(kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
        axis = None
    elif isinstance(kwargs, v0_5.BinarizeAlongAxisKwargs):
        axis = kwargs.axis
    else:
        assert_never(kwargs)

    return cls(
        input=member_id,
        output=member_id,
        threshold=kwargs.threshold,
        axis=axis,
    )
Inherited Members
@dataclass
class Clip(_SimpleOperator):
    """Limit tensor values to the range given by `min` and/or `max`."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        lower, upper = self.min, self.max
        # at least one bound has to be given
        assert not (lower is None and upper is None), "missing min or max value"
        # if both bounds are given they must be strictly ordered
        both_given = lower is not None and upper is not None
        assert (
            not both_given or lower < upper
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Clip `input` to `[min, max]`; `stat` is unused."""
        clipped = input.clip(self.min, self.max)
        return clipped

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The shape is not changed by this operator."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        kwargs = descr.kwargs
        return cls(input=member_id, output=member_id, min=kwargs.min, max=kwargs.max)
Clip( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, min: Optional[float] = None, max: Optional[float] = None)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_5.ClipDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        """Create an in-place operator (input == output == `member_id`) from `descr`."""
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        """Return the v0.5 description equivalent to this operator."""
        descr_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=descr_kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The shape is not changed by this operator."""
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Return `input` cast to `dtype`; `stat` is unused."""
        converted = input.astype(self.dtype)
        return converted
EnsureDtype( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.EnsureDtypeDescr, member_id: bioimageio.spec.model.v0_5.TensorId):
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
Inherited Members
@dataclass
class ScaleLinear(_SimpleOperator):
    """'output = input * gain + offset'."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        """Scale by `gain`, then shift by `offset`; `stat` is unused."""
        scaled = input * self.gain
        return scaled + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        """The shape is not changed by this operator."""
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create an in-place operator (input == output == `member_id`) from `descr`.

        With an axis, gain and offset become 1-d `xr.DataArray`s along that
        axis; without one they must be scalars or sequences of length one.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
            axis = None
        else:
            assert_never(kwargs)

        raw_gain = kwargs.gain
        raw_offset = kwargs.offset
        if axis:
            gain = xr.DataArray(np.atleast_1d(raw_gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(raw_offset), dims=axis)
        else:
            # unwrap single-element sequences to plain scalars
            if isinstance(raw_gain, (float, int)):
                gain = raw_gain
            else:
                assert len(raw_gain) == 1, raw_gain
                gain = raw_gain[0]

            if isinstance(raw_offset, (float, int)):
                offset = raw_offset
            else:
                assert len(raw_offset) == 1
                offset = raw_offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
ScaleLinear( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, gain: Union[float, xarray.core.dataarray.DataArray] = 1.0, offset: Union[float, xarray.core.dataarray.DataArray] = 0.0)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
    member_id: MemberId,
) -> Self:
    """Create the operator from a scale-linear processing description.

    Args:
        descr: processing description holding `gain` and `offset`
            (and an `axis` for the along-axis kwargs variant).
        member_id: tensor to transform; used as both input and output.
    """
    kwargs = descr.kwargs
    if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
        axis = kwargs.axis
    elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
        axis = None
    else:
        assert_never(kwargs)

    if not axis:
        # Without an axis only a single gain/offset value is supported;
        # a one-element sequence is unwrapped to its only item.
        raw_gain = kwargs.gain
        gain_is_scalar = isinstance(raw_gain, (float, int))
        assert gain_is_scalar or len(raw_gain) == 1, raw_gain
        gain = raw_gain if gain_is_scalar else raw_gain[0]

        raw_offset = kwargs.offset
        offset_is_scalar = isinstance(raw_offset, (float, int))
        assert offset_is_scalar or len(raw_offset) == 1
        offset = raw_offset if offset_is_scalar else raw_offset[0]
    else:
        # One value per entry along `axis`, labelled with that axis.
        gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
        offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)

    return cls(input=member_id, output=member_id, gain=gain, offset=offset)
Inherited Members
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match the input's mean and standard deviation to a reference tensor.

    Computes `(input - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    # The four measures below are derived in `__post_init__`.
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        # Falls back to the input tensor when no reference tensor is given.
        ref_tensor = self.reference_tensor or self.input

        # Dataset measures are used only if the batch axis is among `axes`.
        per_dataset = axes is not None and AxisId("batch") in axes
        if per_dataset:
            Mean, Std = DatasetMean, DatasetStd
        else:
            Mean, Std = SampleMean, SampleStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        centered = input - mean
        normalized = centered / std
        return normalized * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # Element-wise operation; the shape is passed through unchanged.
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create the operator from a scale-mean-variance description.

        Args:
            descr: processing description providing `reference_tensor`,
                `eps` and the axes (resolved via `_get_axes`).
            member_id: tensor to transform; used as both input and output.
        """
        kwargs = descr.kwargs
        # The first item returned by `_get_axes` is not needed here.
        _, reduce_axes = _get_axes(kwargs)
        reference = MemberId(str(kwargs.reference_tensor))

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=reference,
            axes=reduce_axes,
            eps=kwargs.eps,
        )
ScaleMeanVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None, reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None, eps: float = 1e-06)
ref_mean: Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
    member_id: MemberId,
) -> Self:
    """Create the operator from a scale-mean-variance description.

    Args:
        descr: processing description providing `reference_tensor`,
            `eps` and the axes (resolved via `_get_axes`).
        member_id: tensor to transform; used as both input and output.
    """
    kwargs = descr.kwargs
    # The first item returned by `_get_axes` is not needed here.
    _, reduce_axes = _get_axes(kwargs)
    reference = MemberId(str(kwargs.reference_tensor))

    return cls(
        input=member_id,
        output=member_id,
        reference_tensor=reference,
        axes=reduce_axes,
        eps=kwargs.eps,
    )
Inherited Members
@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale a tensor by the range between a lower and an upper quantile.

    Computes `(input - lower) / (upper - lower + eps)`, where `lower` and
    `upper` are the values of the corresponding measures in the sample
    statistics.
    """

    # Init-only arguments; the resolved measures are stored in `lower`/`upper`.
    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    # Added to the denominator in `_apply`.
    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        # NOTE(review): a missing bound defaults to a `DatasetPercentile`
        # constructed without `axes`; combining it with a given measure that
        # has different axes trips the axes assertion below.
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        # Both bounds must describe the same tensor and axes, in order.
        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # Element-wise operation; the shape is passed through unchanged.
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create the operator from a scale-range processing description.

        Args:
            descr: processing description providing `min_percentile`,
                `max_percentile` (both in percent), `eps`, the axes and an
                optional `reference_tensor`.
            member_id: tensor to transform; used as both input and output,
                and as the reference tensor if the description names none.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(kwargs)
        Percentile = DatasetPercentile if dataset_mode else SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            # The description uses percent; the measures use fractions.
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # Forward the described epsilon instead of silently falling back
            # to the class default (keeps `get_descr` a faithful round trip).
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Return a `v0_5.ScaleRangeDescr` equivalent to this operator."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
ScaleRange( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, eps: float = 1e-06)
lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] =
None
upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] =
None
lower: Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile]
upper: Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
    member_id: MemberId,
):
    """Create the operator from a scale-range processing description.

    Args:
        descr: processing description providing `min_percentile`,
            `max_percentile` (both in percent), `eps`, the axes and an
            optional `reference_tensor`.
        member_id: tensor to transform; used as both input and output,
            and as the reference tensor if the description names none.
    """
    kwargs = descr.kwargs
    ref_tensor = (
        member_id
        if kwargs.reference_tensor is None
        else MemberId(str(kwargs.reference_tensor))
    )
    dataset_mode, axes = _get_axes(kwargs)
    Percentile = DatasetPercentile if dataset_mode else SampleQuantile

    return cls(
        input=member_id,
        output=member_id,
        # The description uses percent; the measures use fractions.
        lower_percentile=Percentile(
            q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
        ),
        upper_percentile=Percentile(
            q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
        ),
        # Forward the described epsilon instead of silently falling back
        # to the class default.
        eps=kwargs.eps,
    )
def
get_descr(self):
def get_descr(self):
    """Return a `v0_5.ScaleRangeDescr` equivalent to this operator."""
    lower, upper = self.lower, self.upper
    # Both bounds must refer to the same axes and tensor to be expressible
    # as a single description.
    assert lower.axes == upper.axes
    assert lower.member_id == upper.member_id

    # Measures store fractions; the description expects percent.
    range_kwargs = v0_5.ScaleRangeKwargs(
        axes=lower.axes,
        min_percentile=lower.q * 100,
        max_percentile=upper.q * 100,
        eps=self.eps,
        reference_tensor=lower.member_id,
    )
    return v0_5.ScaleRangeDescr(kwargs=range_kwargs)
Inherited Members
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        # `stat` is unused; the result keeps the input's dims.
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # Return an empty set; `{}` would be an empty dict, unlike every
        # other operator's `required_measures`.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # Element-wise operation; the shape is passed through unchanged.
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create the operator for `member_id` (used as input and output)."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Return the `v0_5.SigmoidDescr` for this operator."""
        return v0_5.SigmoidDescr()
1 / (1 + e^(-input)).
Sigmoid( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId)
required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_5.SigmoidDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    # Added to the standard deviation in `_apply`.
    eps: float = 1e-6

    def __post_init__(self):
        # Mean and std must be computed over the same axes.
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # Element-wise operation; the shape is passed through unchanged.
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create the operator from a zero-mean-unit-variance description.

        Args:
            descr: processing description providing the axes (resolved via
                `_get_axes`) and `eps`.
            member_id: tensor to normalize; used as input, output and as the
                tensor the mean/std measures refer to.
        """
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # Forward the described epsilon instead of silently falling back
            # to the class default (keeps `get_descr` a faithful round trip).
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        """Return a `v0_5.ZeroMeanUnitVarianceDescr` for this operator."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
normalize to zero mean, unit variance.
ZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean], std: Union[bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.DatasetStd], eps: float = 1e-06)
required_measures: Set[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.DatasetStd]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
    member_id: MemberId,
):
    """Create the operator from a zero-mean-unit-variance description.

    Args:
        descr: processing description providing the axes (resolved via
            `_get_axes`) and `eps`.
        member_id: tensor to normalize; used as input, output and as the
            tensor the mean/std measures refer to.
    """
    dataset_mode, axes = _get_axes(descr.kwargs)

    if dataset_mode:
        Mean = DatasetMean
        Std = DatasetStd
    else:
        Mean = SampleMean
        Std = SampleStd

    return cls(
        input=member_id,
        output=member_id,
        mean=Mean(axes=axes, member_id=member_id),
        std=Std(axes=axes, member_id=member_id),
        # Forward the described epsilon instead of silently falling back
        # to the class default.
        eps=descr.kwargs.eps,
    )
Inherited Members
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    # Added to `std` in `_apply`.
    eps: float = 1e-6

    def __post_init__(self):
        # If both values are arrays they must share the same dims.
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # Element-wise operation; the shape is passed through unchanged.
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create the operator from a fixed zero-mean-unit-variance description.

        Scalar kwargs are kept as plain numbers so that `get_descr` can map
        the operator back to scalar kwargs; along-axis kwargs become 1-d
        arrays labelled with the given axis.

        Args:
            descr: processing description holding `mean` and `std`
                (and an `axis` for the along-axis kwargs variant).
            member_id: tensor to normalize; used as both input and output.
        """
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            mean = kwargs.mean
            std = kwargs.std
        elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (kwargs.axis,)
            mean = xr.DataArray(kwargs.mean, dims=dims)
            std = xr.DataArray(kwargs.std, dims=dims)
        else:
            assert_never(kwargs)

        return cls(input=member_id, output=member_id, mean=mean, std=std)

    def get_descr(self):
        """Return a `v0_5.FixedZeroMeanUnitVarianceDescr` for this operator.

        Raises:
            AssertionError: if `mean` and `std` are not both scalars or both
                arrays, or if an array `mean` is not one-dimensional.
        """
        mean, std = self.mean, self.std
        # Treat 0-d arrays as the scalars they hold.
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # Iterating a DataArray yields 0-d arrays; emit plain floats.
                mean=[float(m) for m in mean],
                std=[float(s) for s in std],
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        # `stat` is unused: mean and std are fixed at construction time.
        return (input - self.mean) / (self.std + self.eps)
normalize to zero mean, unit variance with precomputed values.
FixedZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[float, xarray.core.dataarray.DataArray], std: Union[float, xarray.core.dataarray.DataArray], eps: float = 1e-06)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: v0_5.FixedZeroMeanUnitVarianceDescr,
    member_id: MemberId,
) -> Self:
    """Create the operator from a fixed zero-mean-unit-variance description.

    Scalar kwargs are kept as plain numbers (not wrapped in 0-d arrays) so
    the operator can be mapped back to scalar kwargs; along-axis kwargs
    become 1-d arrays labelled with the given axis.

    Args:
        descr: processing description holding `mean` and `std`
            (and an `axis` for the along-axis kwargs variant).
        member_id: tensor to normalize; used as both input and output.
    """
    kwargs = descr.kwargs
    if isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
        mean = kwargs.mean
        std = kwargs.std
    elif isinstance(kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
        dims = (kwargs.axis,)
        mean = xr.DataArray(kwargs.mean, dims=dims)
        std = xr.DataArray(kwargs.std, dims=dims)
    else:
        assert_never(kwargs)

    return cls(input=member_id, output=member_id, mean=mean, std=std)
def
get_descr(self):
def get_descr(self):
    """Return a `v0_5.FixedZeroMeanUnitVarianceDescr` for this operator.

    Raises:
        AssertionError: if `mean` and `std` are not both scalars or both
            arrays, or if an array `mean` is not one-dimensional.
    """
    mean, std = self.mean, self.std
    # Treat 0-d arrays as the scalars they hold, so operators built from
    # scalar values wrapped in arrays still map to scalar kwargs.
    if isinstance(mean, xr.DataArray) and mean.ndim == 0:
        mean = float(mean)
    if isinstance(std, xr.DataArray) and std.ndim == 0:
        std = float(std)

    if isinstance(mean, (int, float)):
        assert isinstance(std, (int, float))
        kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
    else:
        assert isinstance(std, xr.DataArray)
        assert len(mean.dims) == 1
        kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
            axis=AxisId(str(mean.dims[0])),
            # Iterating a DataArray yields 0-d arrays; emit plain floats.
            mean=[float(m) for m in mean],
            std=[float(s) for s in std],
        )

    return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
Inherited Members
ProcDescr =
typing.Union[typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
Processing =
typing.Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, UpdateStats, ZeroMeanUnitVariance]
def
get_proc_class( proc_spec: Union[Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]):
def get_proc_class(proc_spec: ProcDescr):
    """Return the operator class that implements `proc_spec`.

    The checks run in a fixed order; in particular a v0_4
    zero-mean-unit-variance description with `mode == "fixed"` is mapped to
    `FixedZeroMeanUnitVariance` before the general zero-mean-unit-variance
    check is reached.
    """
    if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize

    if isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip

    if isinstance(proc_spec, v0_5.EnsureDtypeDescr):
        return EnsureDtype

    if isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance

    if isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear

    if isinstance(
        proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance

    if isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange

    if isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid

    # v0_4 expresses precomputed statistics through `mode == "fixed"`.
    if (
        isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_spec.kwargs.mode == "fixed"
    ):
        return FixedZeroMeanUnitVariance

    if isinstance(
        proc_spec,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance

    assert_never(proc_spec)