bioimageio.core.proc_ops
import collections.abc
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from typing import (
    Collection,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import numpy as np
import xarray as xr
from typing_extensions import Self, assert_never

from bioimageio.core.digest_spec import get_member_id
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import (
    _convert_proc,  # pyright: ignore [reportPrivateUsage]
)

from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
from .block import Block
from .common import DTypeStr, MemberId
from .sample import Sample, SampleBlock, SampleBlockWithOrigin
from .stat_calculators import StatsCalculator
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetPercentile,
    DatasetStd,
    MeanMeasure,
    Measure,
    MeasureValue,
    SampleMean,
    SampleQuantile,
    SampleStd,
    Stat,
    StdMeasure,
)
from .tensor import Tensor


def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    """Turn v0_4-style `axes` into a tuple of `AxisId`s.

    A string is split into single-character axis ids; in "per_dataset" mode the
    batch axis (`v0_5.BATCH_AXIS_ID`) is put in front. Any non-string `axes`
    value is simply converted to a tuple.
    """
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        leading: Tuple[AxisId, ...] = ()
    elif mode == "per_dataset":
        leading = (v0_5.BATCH_AXIS_ID,)
    else:
        assert_never(mode)

    return leading + tuple(AxisId(char) for char in axes)


@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    """Base class for operators that map the `input` member to the `output` member.

    Subclasses provide `_apply` (the tensor transformation) and
    `get_output_shape`; `__call__` handles both whole samples and sample blocks.
    """

    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # no statistics are needed by default
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        if self.input not in sample.members:
            # nothing to do for samples (blocks) without our input tensor
            return

        result = self._apply(sample.members[self.input], sample.stat)

        if self.output in sample.members:
            # replacing an existing member must not alter its shape
            assert sample.members[self.output].tagged_shape == result.tagged_shape

        if isinstance(sample, Sample):
            sample.members[self.output] = result
        elif isinstance(sample, SampleBlock):
            # the output block reuses the block layout of the input block
            source_block = sample.blocks[self.input]
            output_sample_shape = self.get_output_shape(sample.shape[self.input])
            sample.blocks[self.output] = Block(
                sample_shape=output_sample_shape,
                data=result,
                inner_slice=source_block.inner_slice,
                halo=source_block.halo,
                block_index=source_block.block_index,
                blocks_in_sample=source_block.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...


@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Adds the given, precomputed dataset statistics to each sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known_items = self.dataset_stats.items()
        sample.stat.update(known_items)


# @dataclass
# class UpdateStats(Operator):
#     """Calculates sample and/or dataset measures"""

#     measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
#     """sample and dataset `measuers` to be calculated by this operator. Initial/fixed
#     dataset measure values may be given, see `keep_updating_dataset_stats` for details.
#     """
#     keep_updating_dataset_stats: Optional[bool] = None
#     """indicates if operator calls should keep updating dataset statistics or not

#     default (None): if `measures` is a `Mapping` (i.e. initial measure values are
#     given) no further updates to dataset statistics is conducted, otherwise (w.o.
#     initial measure values) dataset statistics are updated by each processed sample.
#     """
#     _keep_updating_dataset_stats: bool = field(init=False)
#     _stats_calculator: StatsCalculator = field(init=False)

#     @property
#     def required_measures(self) -> Set[Measure]:
#         return set()

#     def __post_init__(self):
#         self._stats_calculator = StatsCalculator(self.measures)
#         if self.keep_updating_dataset_stats is None:
#             self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
#         else:
#             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats

#     def __call__(self, sample_block: SampleBlockWithOrigin> None:
#         if self._keep_updating_dataset_stats:
#             sample.stat.update(self._stats_calculator.update_and_get_all(sample))
#         else:
#             sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))


@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            # update stats with whole sample on first block
            if sample.block_index != 0:
                return

            origin = sample.origin
        else:
            origin = sample

        if self._keep_updating_dataset_stats:
            sample.stat.update(self.stats_calculator.update_and_get_all(origin))
        else:
            sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin))


@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is given, `threshold` holds one value per entry along `axis` and
    every slice along `axis` is compared against its own threshold.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        if self.axis is None:
            threshold = self.threshold
        else:
            # label the thresholds with `axis` so they are broadcast by dimension
            # name (like `ScaleLinear.gain`), not positionally against the last dim
            threshold = xr.DataArray(np.atleast_1d(self.threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)


@dataclass
class Clip(_SimpleOperator):
    """Clip values to the interval [`min`, `max`] (either bound may be omitted)."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        assert self.min is not None or self.max is not None, "missing min or max value"
        assert (
            self.min is None or self.max is None or self.min < self.max
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        return cls(
            input=member_id,
            output=member_id,
            min=descr.kwargs.min,
            max=descr.kwargs.max,
        )


@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)


@dataclass
class ScaleLinear(_SimpleOperator):
    """`output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearKwargs):
            axis = None
        elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
            if kwargs.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            assert (
                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
            ), kwargs.gain
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)


@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """`output = (input - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        # reducing over the batch axis means computing dataset statistics
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (input - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )


def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    """Extract the reduction axes from processing kwargs.

    Returns:
        `(dataset_mode, axes)`: `dataset_mode` is True if the reduction includes
        the batch axis (statistics span the dataset); `axes` is None when no
        axes are specified (all axes).
    """
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        # v0_4 axes string (characters of "czyx" only); `_convert_axis_ids`
        # prepends the "batch" axis id for mode "per_dataset", so checking for
        # "b" (as done previously) could never detect dataset mode
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
    else:
        assert_never(kwargs.axes)

    return AxisId("batch") in axes, axes


@dataclass
class ScaleRange(_SimpleOperator):
    """`output = (input - lower) / (upper - lower + eps)`.

    `lower` and `upper` are quantile measures; if not given they default to
    the dataset percentiles q=0.0 and q=1.0.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            # previously dropped, silently replacing a described eps with the default
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )


@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # an empty *set* (`{}` would be an empty dict), like the other operators
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()


@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            # previously dropped, silently replacing a described eps with the default
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )


@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        mean = self.mean
        std = self.std
        # `from_proc_descr` wraps scalar kwargs in 0-d DataArrays; describe those
        # with scalar kwargs instead of failing the 1-d assertion below
        if isinstance(mean, xr.DataArray) and mean.ndim == 0:
            mean = float(mean)
        if isinstance(std, xr.DataArray) and std.ndim == 0:
            std = float(std)

        if isinstance(mean, (int, float)):
            assert isinstance(std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
        else:
            assert isinstance(std, xr.DataArray)
            assert len(mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(mean.dims[0])),
                # plain floats; iterating a DataArray would yield 0-d DataArrays
                mean=mean.values.tolist(),
                std=std.values.tolist(),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)


ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]


def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Create the processing operator described by `proc_descr` for `tensor_descr`.

    Raises:
        TypeError: for a v0_4 "fixed" zero-mean-unit-variance description paired
            with a non-v0_4 tensor description.
    """
    member_id = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, member_id)
    elif isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, member_id)
    elif isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, member_id)
    elif (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    ):
        if not isinstance(
            tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
        ):
            raise TypeError(
                "Expected v0_4 tensor description for v0_4 processing description"
            )

        # v0_4 "fixed" mode is converted to the v0_5 fixed-values description
        v5_proc_descr = _convert_proc(proc_descr, tensor_descr.axes)
        assert isinstance(v5_proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr)
        return FixedZeroMeanUnitVariance.from_proc_descr(v5_proc_descr, member_id)
    elif isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, member_id)
    else:
        assert_never(proc_descr)
@dataclass
class
@dataclass
class AddKnownDatasetStats(BlockedOperator):
    """Adds the given, precomputed dataset statistics to each sample's `stat`."""

    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        # the statistics are supplied directly; nothing must be computed first
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        known_items = self.dataset_stats.items()
        sample.stat.update(known_items)
AddKnownDatasetStats( dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]])
dataset_stats: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
@dataclass
class
@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        # dataset stats keep being updated if explicitly requested, or if the
        # calculator holds no dataset measures yet
        if self.keep_updating_initial_dataset_stats:
            self._keep_updating_dataset_stats = True
        else:
            self._keep_updating_dataset_stats = (
                not self.stats_calculator.has_dataset_measures
            )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            if sample.block_index != 0:
                # only the first block triggers a computation on the whole sample
                return
            whole_sample = sample.origin
        else:
            whole_sample = sample

        calculator = self.stats_calculator
        compute = (
            calculator.update_and_get_all
            if self._keep_updating_dataset_stats
            else calculator.skip_update_and_get_all
        )
        sample.stat.update(compute(whole_sample))
Calculates sample and/or dataset measures
UpdateStats( stats_calculator: bioimageio.core.stat_calculators.StatsCalculator, keep_updating_initial_dataset_stats: bool = False)
stats_calculator: bioimageio.core.stat_calculators.StatsCalculator
StatsCalculator
to be used by this operator.
keep_updating_initial_dataset_stats: bool =
False
indicates if operator calls should keep updating initial dataset statistics or not;
if the stats_calculator
was not provided with any initial dataset statistics,
these are always updated with every new sample.
required_measures: Set[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'.

    If `axis` is given, `threshold` holds one value per entry along `axis` and
    every slice along `axis` is compared against its own threshold.
    """

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        if self.axis is None:
            threshold = self.threshold
        else:
            # label the thresholds with `axis` so they are broadcast by dimension
            # name (like `ScaleLinear.gain`), not positionally against the last dim
            threshold = xr.DataArray(np.atleast_1d(self.threshold), dims=(self.axis,))

        return input > threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
'output = tensor > threshold'.
Binarize( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, threshold: Union[float, Sequence[float]], axis: Optional[bioimageio.spec.model.v0_5.AxisId] = None)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_5.BinarizeDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
) -> Self:
    """Create a `Binarize` operator for tensor `member_id` from `descr`."""
    kw = descr.kwargs
    if isinstance(kw, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
        # a single threshold for the whole tensor
        return cls(input=member_id, output=member_id, threshold=kw.threshold)
    if isinstance(kw, v0_5.BinarizeAlongAxisKwargs):
        # one threshold per entry along `kw.axis`
        return cls(
            input=member_id,
            output=member_id,
            threshold=kw.threshold,
            axis=kw.axis,
        )
    assert_never(kw)
Inherited Members
@dataclass
class Clip(_SimpleOperator):
    """Clip values to [`min`, `max`]; at least one of the bounds must be set."""

    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        low, high = self.min, self.max
        assert not (low is None and high is None), "missing min or max value"
        both_given = low is not None and high is not None
        assert (
            not both_given or low < high
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # clipping is element-wise
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        kw = descr.kwargs
        return cls(input=member_id, output=member_id, min=kw.min, max=kw.max)
Clip( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, min: Optional[float] = None, max: Optional[float] = None)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_5.ClipDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class EnsureDtype(_SimpleOperator):
    """Cast the input tensor to `dtype`."""

    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        target_dtype = descr.kwargs.dtype
        return cls(input=member_id, output=member_id, dtype=target_dtype)

    def get_descr(self):
        descr_kwargs = v0_5.EnsureDtypeKwargs(dtype=self.dtype)
        return v0_5.EnsureDtypeDescr(kwargs=descr_kwargs)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # casting never changes the shape
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)
EnsureDtype( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.EnsureDtypeDescr, member_id: bioimageio.spec.model.v0_5.TensorId):
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
Inherited Members
@dataclass
class ScaleLinear(_SimpleOperator):
    """`output = input * gain + offset`."""

    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        scaled = input * self.gain
        return scaled + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        """Create a `ScaleLinear` operator for tensor `member_id` from `descr`.

        Raises:
            NotImplementedError: for v0_4 kwargs that specify `axes`.
        """
        kw = descr.kwargs
        if isinstance(kw, v0_5.ScaleLinearKwargs):
            along = None
        elif isinstance(kw, v0_5.ScaleLinearAlongAxisKwargs):
            along = kw.axis
        elif isinstance(kw, v0_4.ScaleLinearKwargs):
            if kw.axes is not None:
                raise NotImplementedError(
                    "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
                )
            along = None
        else:
            assert_never(kw)

        if along:
            # one gain/offset value per entry along `along`
            gain = xr.DataArray(np.atleast_1d(kw.gain), dims=along)
            offset = xr.DataArray(np.atleast_1d(kw.offset), dims=along)
            return cls(input=member_id, output=member_id, gain=gain, offset=offset)

        # without an axis only scalars (or single-element sequences) are accepted
        scalar_gain = isinstance(kw.gain, (float, int))
        assert scalar_gain or len(kw.gain) == 1, kw.gain
        gain = kw.gain if scalar_gain else kw.gain[0]

        scalar_offset = isinstance(kw.offset, (float, int))
        assert scalar_offset or len(kw.offset) == 1
        offset = kw.offset if scalar_offset else kw.offset[0]

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
ScaleLinear( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, gain: Union[float, xarray.core.dataarray.DataArray] = 1.0, offset: Union[float, xarray.core.dataarray.DataArray] = 0.0)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
    member_id: MemberId,
) -> Self:
    """Create a `ScaleLinear` operator from a v0_4 or v0_5 description.

    The operator reads from and writes to `member_id`.

    Raises:
        NotImplementedError: for v0_4 kwargs that specify `axes`.
    """
    kwargs = descr.kwargs
    if isinstance(kwargs, v0_5.ScaleLinearKwargs):
        axis = None
    elif isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
        axis = kwargs.axis
    elif isinstance(kwargs, v0_4.ScaleLinearKwargs):
        if kwargs.axes is not None:
            raise NotImplementedError(
                "model.v0_4.ScaleLinearKwargs with axes not implemented, please consider updating the model to v0_5."
            )
        axis = None
    else:
        assert_never(kwargs)

    if axis:
        # Only list-valued parameters vary along `axis`. A scalar stays a scalar
        # so it broadcasts; a size-1 DataArray along `axis` would conflict with
        # the input's size along that axis (xarray does not broadcast size-1 dims).
        gain = (
            kwargs.gain
            if isinstance(kwargs.gain, (float, int))
            else xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
        )
        offset = (
            kwargs.offset
            if isinstance(kwargs.offset, (float, int))
            else xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        )
    else:
        # without an axis only scalars (or single-element lists) are supported
        assert (
            isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
        ), kwargs.gain
        gain = kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
        assert (
            isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
        ), kwargs.offset
        offset = (
            kwargs.offset
            if isinstance(kwargs.offset, (float, int))
            else kwargs.offset[0]
        )

    return cls(input=member_id, output=member_id, gain=gain, offset=offset)
Inherited Members
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    """Match the mean and standard deviation of `input` to a reference tensor.

    Output: `(input - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
    Statistics are per dataset when `axes` contains the batch axis, otherwise
    per sample.
    """

    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return set((self.mean, self.std, self.ref_mean, self.ref_std))

    def __post_init__(self):
        reduce_axes = tuple(self.axes) if self.axes is not None else None
        reference = self.reference_tensor or self.input

        # batch axis among the reduced axes -> statistics over the whole dataset
        per_dataset = reduce_axes is not None and AxisId("batch") in reduce_axes
        if per_dataset:
            mean_kind, std_kind = DatasetMean, DatasetStd
        else:
            mean_kind, std_kind = SampleMean, SampleStd

        self.mean = mean_kind(member_id=self.input, axes=reduce_axes)
        self.std = std_kind(member_id=self.input, axes=reduce_axes)
        self.ref_mean = mean_kind(member_id=reference, axes=reduce_axes)
        self.ref_std = std_kind(member_id=reference, axes=reduce_axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mu, sigma, ref_mu, ref_sigma = (
            stat[measure]
            for measure in (self.mean, self.std, self.ref_mean, self.ref_std)
        )
        return (input - mu) / (sigma + self.eps) * (ref_sigma + self.eps) + ref_mu

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        """Create the operator from a v0_4 or v0_5 description acting on `member_id`."""
        proc_kwargs = descr.kwargs
        _, reduce_axes = _get_axes(descr.kwargs)
        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(proc_kwargs.reference_tensor)),
            axes=reduce_axes,
            eps=proc_kwargs.eps,
        )
ScaleMeanVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]] = None, reference_tensor: Optional[bioimageio.spec.model.v0_5.TensorId] = None, eps: float = 1e-06)
ref_mean: Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
    member_id: MemberId,
) -> Self:
    """Create a `ScaleMeanVariance` operator acting in place on `member_id`.

    The reduction axes come from `_get_axes`; its dataset-mode flag is not used
    here.
    """
    proc_kwargs = descr.kwargs
    _, reduce_axes = _get_axes(descr.kwargs)
    return cls(
        input=member_id,
        output=member_id,
        reference_tensor=MemberId(str(proc_kwargs.reference_tensor)),
        axes=reduce_axes,
        eps=proc_kwargs.eps,
    )
Inherited Members
@dataclass
class ScaleRange(_SimpleOperator):
    """Rescale `input` between a lower and an upper percentile.

    Output: `(input - lower) / (upper - lower + eps)`, with `lower`/`upper` looked
    up from the sample statistics.
    """

    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None and upper_percentile is None:
            # nothing given: q=0.0 and q=1.0 dataset percentiles of the input
            self.lower = DatasetPercentile(q=0.0, member_id=self.input)
            self.upper = DatasetPercentile(q=1.0, member_id=self.input)
        elif lower_percentile is None:
            assert upper_percentile is not None
            self.upper = upper_percentile
            # mirror the given percentile's kind, axes and member so the
            # consistency checks below hold
            self.lower = type(upper_percentile)(
                q=0.0,
                axes=upper_percentile.axes,
                member_id=upper_percentile.member_id,
            )
        elif upper_percentile is None:
            self.lower = lower_percentile
            self.upper = type(lower_percentile)(
                q=1.0,
                axes=lower_percentile.axes,
                member_id=lower_percentile.member_id,
            )
        else:
            self.lower = lower_percentile
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        """Create a `ScaleRange` operator from a v0_4 or v0_5 description.

        Percentiles are computed on `reference_tensor` (default: `member_id`).
        `eps` is taken from the description so it round-trips via `get_descr`.
        """
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            eps=kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        """Describe this operator as a `v0_5.ScaleRangeDescr`."""
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
ScaleRange( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] = None, eps: float = 1e-06)
lower_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] =
None
upper_percentile: dataclasses.InitVar[typing.Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile, NoneType]] =
None
lower: Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile]
upper: Union[bioimageio.core.stat_measures.SampleQuantile, bioimageio.core.stat_measures.DatasetPercentile]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
    member_id: MemberId,
):
    """Create a `ScaleRange` operator from a v0_4 or v0_5 description.

    Percentiles are computed on `reference_tensor` (default: `member_id`).
    `eps` is taken from the description so it is not silently reset to the
    dataclass default.
    """
    kwargs = descr.kwargs
    ref_tensor = (
        member_id
        if kwargs.reference_tensor is None
        else MemberId(str(kwargs.reference_tensor))
    )
    dataset_mode, axes = _get_axes(descr.kwargs)
    if dataset_mode:
        Percentile = DatasetPercentile
    else:
        Percentile = SampleQuantile

    return cls(
        input=member_id,
        output=member_id,
        lower_percentile=Percentile(
            q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
        ),
        upper_percentile=Percentile(
            q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
        ),
        eps=kwargs.eps,
    )
def
get_descr(self):
def get_descr(self):
    """Return the equivalent `v0_5.ScaleRangeDescr` (percentiles in 0..100)."""
    lo, hi = self.lower, self.upper
    assert lo.axes == hi.axes
    assert lo.member_id == hi.member_id

    range_kwargs = v0_5.ScaleRangeKwargs(
        axes=lo.axes,
        min_percentile=lo.q * 100,
        max_percentile=hi.q * 100,
        eps=self.eps,
        reference_tensor=lo.member_id,
    )
    return v0_5.ScaleRangeDescr(kwargs=range_kwargs)
Inherited Members
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        # No statistics needed. Return an empty *set* (not `{}`, which is a dict)
        # to match the base class and the other operators.
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        """Create a `Sigmoid` operator acting in place on `member_id`."""
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        """Describe this operator as a `v0_5.SigmoidDescr`."""
        return v0_5.SigmoidDescr()
1 / (1 + e^(-input)).
Sigmoid( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId)
required_measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_5.SigmoidDescr], member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
Inherited Members
@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        # mean and std must be reduced over the same axes
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        """Create the operator from a v0_4 or v0_5 description acting on `member_id`.

        `eps` is taken from the description so it round-trips via `get_descr`.
        """
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
            eps=descr.kwargs.eps,
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        """Describe this operator as a `v0_5.ZeroMeanUnitVarianceDescr`."""
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )
normalize to zero mean, unit variance.
ZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean], std: Union[bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.DatasetStd], eps: float = 1e-06)
required_measures: Set[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.DatasetStd]]
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: Union[bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr], member_id: bioimageio.spec.model.v0_5.TensorId):
@classmethod
def from_proc_descr(
    cls,
    descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
    member_id: MemberId,
):
    """Create a `ZeroMeanUnitVariance` operator acting in place on `member_id`.

    Dataset or sample statistics are chosen by `_get_axes`. `eps` is taken from
    the description so it is not silently reset to the dataclass default.
    """
    dataset_mode, axes = _get_axes(descr.kwargs)

    if dataset_mode:
        Mean = DatasetMean
        Std = DatasetStd
    else:
        Mean = SampleMean
        Std = SampleStd

    return cls(
        input=member_id,
        output=member_id,
        mean=Mean(axes=axes, member_id=member_id),
        std=Std(axes=axes, member_id=member_id),
        eps=descr.kwargs.eps,
    )
Inherited Members
@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        # array-valued mean and std must share their dimensions
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        # element-wise operation: the shape is unchanged
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        """Create the operator from a v0_5 description acting on `member_id`.

        Scalar kwargs become 0-d `xr.DataArray`s; along-axis kwargs become 1-d
        `xr.DataArray`s with that axis as their single dimension.
        """
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (AxisId(descr.kwargs.axis),)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        """Describe this operator as a `v0_5.FixedZeroMeanUnitVarianceDescr`.

        Scalars (Python numbers or 0-d `xr.DataArray`s, as produced by
        `from_proc_descr`) map to `FixedZeroMeanUnitVarianceKwargs`; 1-d arrays
        map to `FixedZeroMeanUnitVarianceAlongAxisKwargs`.
        """
        if isinstance(self.mean, (int, float)) or self.mean.ndim == 0:
            assert isinstance(self.std, (int, float)) or self.std.ndim == 0
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(
                mean=float(self.mean), std=float(self.std)
            )
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                # `.values.tolist()` yields plain floats; `list(DataArray)` would
                # yield 0-d DataArrays instead
                mean=self.mean.values.tolist(),
                std=self.std.values.tolist(),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)
normalize to zero mean, unit variance with precomputed values.
FixedZeroMeanUnitVariance( input: bioimageio.spec.model.v0_5.TensorId, output: bioimageio.spec.model.v0_5.TensorId, mean: Union[float, xarray.core.dataarray.DataArray], std: Union[float, xarray.core.dataarray.DataArray], eps: float = 1e-06)
def
get_output_shape( self, input_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
@classmethod
def
from_proc_descr( cls, descr: bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, member_id: bioimageio.spec.model.v0_5.TensorId) -> Self:
@classmethod
def from_proc_descr(
    cls,
    descr: v0_5.FixedZeroMeanUnitVarianceDescr,
    member_id: MemberId,
) -> Self:
    """Build a `FixedZeroMeanUnitVariance` operator acting in place on `member_id`.

    Scalar kwargs produce dimensionless arrays; along-axis kwargs produce arrays
    whose single dimension is the given axis.
    """
    fixed = descr.kwargs
    if isinstance(fixed, v0_5.FixedZeroMeanUnitVarianceKwargs):
        value_dims = None
    elif isinstance(fixed, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
        value_dims = (AxisId(fixed.axis),)
    else:
        assert_never(fixed)

    mean_values = xr.DataArray(fixed.mean, dims=value_dims)
    std_values = xr.DataArray(fixed.std, dims=value_dims)
    return cls(input=member_id, output=member_id, mean=mean_values, std=std_values)
def
get_descr(self):
def get_descr(self):
    """Describe this operator as a `v0_5.FixedZeroMeanUnitVarianceDescr`.

    Scalars (Python numbers or 0-d `xr.DataArray`s, as produced by
    `from_proc_descr` for scalar kwargs) map to `FixedZeroMeanUnitVarianceKwargs`;
    1-d arrays map to `FixedZeroMeanUnitVarianceAlongAxisKwargs`.
    """
    if isinstance(self.mean, (int, float)) or self.mean.ndim == 0:
        assert isinstance(self.std, (int, float)) or self.std.ndim == 0
        kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(
            mean=float(self.mean), std=float(self.std)
        )
    else:
        assert isinstance(self.std, xr.DataArray)
        assert len(self.mean.dims) == 1
        kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
            axis=AxisId(str(self.mean.dims[0])),
            # plain floats; `list(DataArray)` would yield 0-d DataArrays
            mean=self.mean.values.tolist(),
            std=self.std.values.tolist(),
        )

    return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)
Inherited Members
ProcDescr =
typing.Union[typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
Processing =
typing.Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, UpdateStats, ZeroMeanUnitVariance]
def
get_proc( proc_descr: Union[Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_4.BinarizeDescr, bioimageio.spec.model.v0_4.ClipDescr, bioimageio.spec.model.v0_4.ScaleLinearDescr, bioimageio.spec.model.v0_4.SigmoidDescr, bioimageio.spec.model.v0_4.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_4.ScaleRangeDescr, bioimageio.spec.model.v0_4.ScaleMeanVarianceDescr], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BinarizeDescr, bioimageio.spec.model.v0_5.ClipDescr, bioimageio.spec.model.v0_5.EnsureDtypeDescr, bioimageio.spec.model.v0_5.ScaleLinearDescr, bioimageio.spec.model.v0_5.SigmoidDescr, bioimageio.spec.model.v0_5.FixedZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ZeroMeanUnitVarianceDescr, bioimageio.spec.model.v0_5.ScaleRangeDescr, bioimageio.spec.model.v0_5.ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], tensor_descr: Union[bioimageio.spec.model.v0_4.InputTensorDescr, 
bioimageio.spec.model.v0_4.OutputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr, bioimageio.spec.model.v0_5.OutputTensorDescr]) -> Union[AddKnownDatasetStats, Binarize, Clip, EnsureDtype, FixedZeroMeanUnitVariance, ScaleLinear, ScaleMeanVariance, ScaleRange, Sigmoid, UpdateStats, ZeroMeanUnitVariance]:
def get_proc(
    proc_descr: ProcDescr,
    tensor_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> Processing:
    """Instantiate the operator described by `proc_descr` for `tensor_descr`'s member.

    A v0_4 zero-mean-unit-variance description in "fixed" mode is first converted
    to a v0_5 `FixedZeroMeanUnitVarianceDescr`.

    Raises:
        TypeError: if a "fixed" v0_4 zero-mean-unit-variance description is paired
            with a tensor description that is not v0_4.
    """
    mid = get_member_id(tensor_descr)

    if isinstance(proc_descr, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, v0_5.EnsureDtypeDescr):
        return EnsureDtype.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear.from_proc_descr(proc_descr, mid)
    if isinstance(
        proc_descr, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange.from_proc_descr(proc_descr, mid)
    if isinstance(proc_descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid.from_proc_descr(proc_descr, mid)

    is_fixed_v04_zmuv = (
        isinstance(proc_descr, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_descr.kwargs.mode == "fixed"
    )
    if is_fixed_v04_zmuv:
        # converting needs the v0_4 tensor axes
        if isinstance(tensor_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
            converted = _convert_proc(proc_descr, tensor_descr.axes)
            assert isinstance(converted, v0_5.FixedZeroMeanUnitVarianceDescr)
            return FixedZeroMeanUnitVariance.from_proc_descr(converted, mid)

        raise TypeError(
            "Expected v0_4 tensor description for v0_4 processing description"
        )

    if isinstance(
        proc_descr,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance.from_proc_descr(proc_descr, mid)

    assert_never(proc_descr)