Coverage for bioimageio/core/proc_ops.py: 77%
321 statements
coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
import collections.abc
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from typing import (
    Collection,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import numpy as np
import xarray as xr
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_4, v0_5

from ._op_base import BlockedOperator, Operator
from .axis import AxisId, PerAxis
from .block import Block
from .common import DTypeStr, MemberId
from .sample import Sample, SampleBlock, SampleBlockWithOrigin
from .stat_calculators import StatsCalculator
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetPercentile,
    DatasetStd,
    MeanMeasure,
    Measure,
    MeasureValue,
    SampleMean,
    SampleQuantile,
    SampleStd,
    Stat,
    StdMeasure,
)
from .tensor import Tensor

def _convert_axis_ids(
    axes: v0_4.AxesInCZYX,
    mode: Literal["per_sample", "per_dataset"],
) -> Tuple[AxisId, ...]:
    if not isinstance(axes, str):
        return tuple(axes)

    if mode == "per_sample":
        ret = []
    elif mode == "per_dataset":
        ret = [AxisId("b")]
    else:
        assert_never(mode)

    ret.extend([AxisId(a) for a in axes])
    return tuple(ret)
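
# Informal example of the conversion above (values follow directly from the code):
#   _convert_axis_ids("zyx", "per_sample")  -> (AxisId("z"), AxisId("y"), AxisId("x"))
#   _convert_axis_ids("zyx", "per_dataset") -> (AxisId("b"), AxisId("z"), AxisId("y"), AxisId("x"))
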
@dataclass
class _SimpleOperator(BlockedOperator, ABC):
    input: MemberId
    output: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        if self.input not in sample.members:
            return

        input_tensor = sample.members[self.input]
        output_tensor = self._apply(input_tensor, sample.stat)

        if self.output in sample.members:
            assert (
                sample.members[self.output].tagged_shape == output_tensor.tagged_shape
            )

        if isinstance(sample, Sample):
            sample.members[self.output] = output_tensor
        elif isinstance(sample, SampleBlock):
            b = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=output_tensor,
                inner_slice=b.inner_slice,
                halo=b.halo,
                block_index=b.block_index,
                blocks_in_sample=b.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...

@dataclass
class AddKnownDatasetStats(BlockedOperator):
    dataset_stats: Mapping[DatasetMeasure, MeasureValue]

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        sample.stat.update(self.dataset_stats.items())

# @dataclass
# class UpdateStats(Operator):
#     """Calculates sample and/or dataset measures"""
#
#     measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]]
#     """sample and dataset `measures` to be calculated by this operator. Initial/fixed
#     dataset measure values may be given, see `keep_updating_dataset_stats` for details.
#     """
#     keep_updating_dataset_stats: Optional[bool] = None
#     """indicates if operator calls should keep updating dataset statistics or not
#
#     default (None): if `measures` is a `Mapping` (i.e. initial measure values are
#     given) no further updates to dataset statistics are conducted, otherwise (without
#     initial measure values) dataset statistics are updated by each processed sample.
#     """
#     _keep_updating_dataset_stats: bool = field(init=False)
#     _stats_calculator: StatsCalculator = field(init=False)
#
#     @property
#     def required_measures(self) -> Set[Measure]:
#         return set()
#
#     def __post_init__(self):
#         self._stats_calculator = StatsCalculator(self.measures)
#         if self.keep_updating_dataset_stats is None:
#             self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
#         else:
#             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
#
#     def __call__(self, sample: SampleBlockWithOrigin) -> None:
#         if self._keep_updating_dataset_stats:
#             sample.stat.update(self._stats_calculator.update_and_get_all(sample))
#         else:
#             sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))

@dataclass
class UpdateStats(Operator):
    """Calculates sample and/or dataset measures"""

    stats_calculator: StatsCalculator
    """`StatsCalculator` to be used by this operator."""
    keep_updating_initial_dataset_stats: bool = False
    """indicates if operator calls should keep updating initial dataset statistics or not;
    if the `stats_calculator` was not provided with any initial dataset statistics,
    these are always updated with every new sample.
    """
    _keep_updating_dataset_stats: bool = field(init=False)

    @property
    def required_measures(self) -> Set[Measure]:
        return set()

    def __post_init__(self):
        self._keep_updating_dataset_stats = (
            self.keep_updating_initial_dataset_stats
            or not self.stats_calculator.has_dataset_measures
        )

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        if isinstance(sample, SampleBlockWithOrigin):
            # update stats with whole sample on first block
            if sample.block_index != 0:
                return

            origin = sample.origin
        else:
            origin = sample

        if self._keep_updating_dataset_stats:
            sample.stat.update(self.stats_calculator.update_and_get_all(origin))
        else:
            sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin))
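
# Hypothetical usage sketch (illustration only; assumes `sample` is a Sample with a
# member "raw", that StatsCalculator accepts an iterable of measures, and that
# SampleMean's axes default to None):
#   calc = StatsCalculator({SampleMean(member_id=MemberId("raw"))})
#   update = UpdateStats(stats_calculator=calc)
#   update(sample)  # computes the requested measures and writes them into sample.stat
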
@dataclass
class Binarize(_SimpleOperator):
    """'output = tensor > threshold'."""

    threshold: Union[float, Sequence[float]]
    axis: Optional[AxisId] = None

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input > self.threshold

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
    ) -> Self:
        if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
            return cls(
                input=member_id, output=member_id, threshold=descr.kwargs.threshold
            )
        elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
            return cls(
                input=member_id,
                output=member_id,
                threshold=descr.kwargs.threshold,
                axis=descr.kwargs.axis,
            )
        else:
            assert_never(descr.kwargs)
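
# Hypothetical usage sketch (assumes `sample` is a Sample whose members include "prob"):
#   binarize = Binarize(input=MemberId("prob"), output=MemberId("prob"), threshold=0.5)
#   binarize(sample)  # replaces sample.members["prob"] with the boolean mask (values > 0.5)
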
@dataclass
class Clip(_SimpleOperator):
    min: Optional[float] = None
    """minimum value for clipping"""
    max: Optional[float] = None
    """maximum value for clipping"""

    def __post_init__(self):
        assert self.min is not None or self.max is not None, "missing min or max value"
        assert (
            self.min is None or self.max is None or self.min < self.max
        ), f"expected min < max, but {self.min} !< {self.max}"

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.clip(self.min, self.max)

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
    ) -> Self:
        return cls(
            input=member_id,
            output=member_id,
            min=descr.kwargs.min,
            max=descr.kwargs.max,
        )

@dataclass
class EnsureDtype(_SimpleOperator):
    dtype: DTypeStr

    @classmethod
    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)

    def get_descr(self):
        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input.astype(self.dtype)

@dataclass
class ScaleLinear(_SimpleOperator):
    gain: Union[float, xr.DataArray] = 1.0
    """multiplicative factor"""

    offset: Union[float, xr.DataArray] = 0.0
    """additive term"""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return input * self.gain + self.offset

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
            axis = kwargs.axis
        elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)):
            axis = None
        else:
            assert_never(kwargs)

        if axis:
            gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis)
            offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis)
        else:
            assert (
                isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1
            ), kwargs.gain
            gain = (
                kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0]
            )
            assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1
            offset = (
                kwargs.offset
                if isinstance(kwargs.offset, (float, int))
                else kwargs.offset[0]
            )

        return cls(input=member_id, output=member_id, gain=gain, offset=offset)
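
# Hypothetical sketch of a per-channel gain (member id "raw" and the two gain values are
# made up for illustration):
#   gain = xr.DataArray([1.0, 2.0], dims=AxisId("channel"))
#   scale = ScaleLinear(input=MemberId("raw"), output=MemberId("raw"), gain=gain)
#   scale(sample)  # broadcasts the per-channel gain over the remaining axes
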
@dataclass
class ScaleMeanVariance(_SimpleOperator):
    axes: Optional[Sequence[AxisId]] = None
    reference_tensor: Optional[MemberId] = None
    eps: float = 1e-6
    mean: Union[SampleMean, DatasetMean] = field(init=False)
    std: Union[SampleStd, DatasetStd] = field(init=False)
    ref_mean: Union[SampleMean, DatasetMean] = field(init=False)
    ref_std: Union[SampleStd, DatasetStd] = field(init=False)

    @property
    def required_measures(self):
        return {self.mean, self.std, self.ref_mean, self.ref_std}

    def __post_init__(self):
        axes = None if self.axes is None else tuple(self.axes)
        ref_tensor = self.reference_tensor or self.input
        if axes is None or AxisId("batch") not in axes:
            Mean = SampleMean
            Std = SampleStd
        else:
            Mean = DatasetMean
            Std = DatasetStd

        self.mean = Mean(member_id=self.input, axes=axes)
        self.std = Std(member_id=self.input, axes=axes)
        self.ref_mean = Mean(member_id=ref_tensor, axes=axes)
        self.ref_std = Std(member_id=ref_tensor, axes=axes)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std] + self.eps
        ref_mean = stat[self.ref_mean]
        ref_std = stat[self.ref_std] + self.eps
        return (input - mean) / std * ref_std + ref_mean

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr],
        member_id: MemberId,
    ) -> Self:
        kwargs = descr.kwargs
        _, axes = _get_axes(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            reference_tensor=MemberId(str(kwargs.reference_tensor)),
            axes=axes,
            eps=kwargs.eps,
        )
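
# `_get_axes` below returns a `(dataset_mode, axes)` tuple: `dataset_mode` is True when the
# reduction includes the batch axis (or `axes` is None), i.e. statistics are computed over the
# whole dataset rather than per sample.
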
def _get_axes(
    kwargs: Union[
        v0_4.ZeroMeanUnitVarianceKwargs,
        v0_5.ZeroMeanUnitVarianceKwargs,
        v0_4.ScaleRangeKwargs,
        v0_5.ScaleRangeKwargs,
        v0_4.ScaleMeanVarianceKwargs,
        v0_5.ScaleMeanVarianceKwargs,
    ],
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
    if kwargs.axes is None:
        return True, None
    elif isinstance(kwargs.axes, str):
        axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
        return AxisId("b") in axes, axes
    elif isinstance(kwargs.axes, collections.abc.Sequence):
        axes = tuple(kwargs.axes)
        return AxisId("batch") in axes, axes
    else:
        assert_never(kwargs.axes)

@dataclass
class ScaleRange(_SimpleOperator):
    lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None
    lower: Union[SampleQuantile, DatasetPercentile] = field(init=False)
    upper: Union[SampleQuantile, DatasetPercentile] = field(init=False)

    eps: float = 1e-6

    def __post_init__(
        self,
        lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
        upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]],
    ):
        if lower_percentile is None:
            tid = self.input if upper_percentile is None else upper_percentile.member_id
            self.lower = DatasetPercentile(q=0.0, member_id=tid)
        else:
            self.lower = lower_percentile

        if upper_percentile is None:
            self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id)
        else:
            self.upper = upper_percentile

        assert self.lower.member_id == self.upper.member_id
        assert self.lower.q < self.upper.q
        assert self.lower.axes == self.upper.axes

    @property
    def required_measures(self):
        return {self.lower, self.upper}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr],
        member_id: MemberId,
    ):
        kwargs = descr.kwargs
        ref_tensor = (
            member_id
            if kwargs.reference_tensor is None
            else MemberId(str(kwargs.reference_tensor))
        )
        dataset_mode, axes = _get_axes(descr.kwargs)
        if dataset_mode:
            Percentile = DatasetPercentile
        else:
            Percentile = SampleQuantile

        return cls(
            input=member_id,
            output=member_id,
            lower_percentile=Percentile(
                q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor
            ),
            upper_percentile=Percentile(
                q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor
            ),
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        lower = stat[self.lower]
        upper = stat[self.upper]
        return (input - lower) / (upper - lower + self.eps)

    def get_descr(self):
        assert self.lower.axes == self.upper.axes
        assert self.lower.member_id == self.upper.member_id

        return v0_5.ScaleRangeDescr(
            kwargs=v0_5.ScaleRangeKwargs(
                axes=self.lower.axes,
                min_percentile=self.lower.q * 100,
                max_percentile=self.upper.q * 100,
                eps=self.eps,
                reference_tensor=self.lower.member_id,
            )
        )
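
# Hypothetical sketch: rescale member "raw" to its 1st/99th sample quantiles (assumes
# SampleQuantile's axes default to None; the required quantile measures must be present in
# sample.stat, e.g. via UpdateStats, before calling the operator):
#   scale = ScaleRange(
#       input=MemberId("raw"),
#       output=MemberId("raw"),
#       lower_percentile=SampleQuantile(q=0.01, member_id=MemberId("raw")),
#       upper_percentile=SampleQuantile(q=0.99, member_id=MemberId("raw")),
#   )
#   scale(sample)  # (raw - lower) / (upper - lower + eps)
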
@dataclass
class Sigmoid(_SimpleOperator):
    """1 / (1 + e^(-input))."""

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims)

    @property
    def required_measures(self) -> Collection[Measure]:
        return set()

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId
    ) -> Self:
        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
        return cls(input=member_id, output=member_id)

    def get_descr(self):
        return v0_5.SigmoidDescr()

@dataclass
class ZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance."""

    mean: MeanMeasure
    std: StdMeasure

    eps: float = 1e-6

    def __post_init__(self):
        assert self.mean.axes == self.std.axes

    @property
    def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]:
        return {self.mean, self.std}

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
        member_id: MemberId,
    ):
        dataset_mode, axes = _get_axes(descr.kwargs)

        if dataset_mode:
            Mean = DatasetMean
            Std = DatasetStd
        else:
            Mean = SampleMean
            Std = SampleStd

        return cls(
            input=member_id,
            output=member_id,
            mean=Mean(axes=axes, member_id=member_id),
            std=Std(axes=axes, member_id=member_id),
        )

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        mean = stat[self.mean]
        std = stat[self.std]
        return (input - mean) / (std + self.eps)

    def get_descr(self):
        return v0_5.ZeroMeanUnitVarianceDescr(
            kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)
        )

@dataclass
class FixedZeroMeanUnitVariance(_SimpleOperator):
    """normalize to zero mean, unit variance with precomputed values."""

    mean: Union[float, xr.DataArray]
    std: Union[float, xr.DataArray]

    eps: float = 1e-6

    def __post_init__(self):
        assert (
            isinstance(self.mean, (int, float))
            or isinstance(self.std, (int, float))
            or self.mean.dims == self.std.dims
        )

    def get_output_shape(
        self, input_shape: Mapping[AxisId, int]
    ) -> Mapping[AxisId, int]:
        return input_shape

    @classmethod
    def from_proc_descr(
        cls,
        descr: v0_5.FixedZeroMeanUnitVarianceDescr,
        member_id: MemberId,
    ) -> Self:
        if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
            dims = None
        elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs):
            dims = (descr.kwargs.axis,)
        else:
            assert_never(descr.kwargs)

        return cls(
            input=member_id,
            output=member_id,
            mean=xr.DataArray(descr.kwargs.mean, dims=dims),
            std=xr.DataArray(descr.kwargs.std, dims=dims),
        )

    def get_descr(self):
        if isinstance(self.mean, (int, float)):
            assert isinstance(self.std, (int, float))
            kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std)
        else:
            assert isinstance(self.std, xr.DataArray)
            assert len(self.mean.dims) == 1
            kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs(
                axis=AxisId(str(self.mean.dims[0])),
                mean=list(self.mean),
                std=list(self.std),
            )

        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs)

    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
        return (input - self.mean) / (self.std + self.eps)
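
# Hypothetical sketch of per-channel normalization with precomputed statistics (member id
# "raw" and the values are made up for illustration):
#   norm = FixedZeroMeanUnitVariance(
#       input=MemberId("raw"),
#       output=MemberId("raw"),
#       mean=xr.DataArray([0.5, 0.3], dims=AxisId("channel")),
#       std=xr.DataArray([0.2, 0.1], dims=AxisId("channel")),
#   )
#   norm(sample)  # (raw - mean) / (std + eps), broadcast over the other axes
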
ProcDescr = Union[
    v0_4.PreprocessingDescr,
    v0_4.PostprocessingDescr,
    v0_5.PreprocessingDescr,
    v0_5.PostprocessingDescr,
]

Processing = Union[
    AddKnownDatasetStats,
    Binarize,
    Clip,
    EnsureDtype,
    FixedZeroMeanUnitVariance,
    ScaleLinear,
    ScaleMeanVariance,
    ScaleRange,
    Sigmoid,
    UpdateStats,
    ZeroMeanUnitVariance,
]

def get_proc_class(proc_spec: ProcDescr):
    if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
        return Binarize
    elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
        return Clip
    elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
        return EnsureDtype
    elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
        return FixedZeroMeanUnitVariance
    elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
        return ScaleLinear
    elif isinstance(
        proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)
    ):
        return ScaleMeanVariance
    elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
        return ScaleRange
    elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
        return Sigmoid
    elif (
        isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr)
        and proc_spec.kwargs.mode == "fixed"
    ):
        return FixedZeroMeanUnitVariance
    elif isinstance(
        proc_spec,
        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
    ):
        return ZeroMeanUnitVariance
    else:
        assert_never(proc_spec)
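
# Hypothetical usage sketch (assumes `ipt` is an input tensor description of a loaded v0_5
# model, `member_id` its tensor id, and `sample` a Sample containing that member; not part
# of this module's API):
#   ops = []
#   for proc_descr in ipt.preprocessing:
#       proc_class = get_proc_class(proc_descr)
#       ops.append(proc_class.from_proc_descr(proc_descr, member_id))
#   for op in ops:
#       op(sample)  # apply each preprocessing operator to the sample in turn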