bioimageio.core.stat_calculators

  1from __future__ import annotations
  2
  3import collections
  4import warnings
  5from itertools import product
  6from typing import (
  7    Any,
  8    Collection,
  9    Dict,
 10    Iterable,
 11    Iterator,
 12    List,
 13    Mapping,
 14    Optional,
 15    OrderedDict,
 16    Sequence,
 17    Set,
 18    Tuple,
 19    Type,
 20    Union,
 21)
 22
 23import numpy as np
 24import xarray as xr
 25from loguru import logger
 26from numpy.typing import NDArray
 27from typing_extensions import assert_never
 28
 29from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID
 30
 31from .axis import AxisId, PerAxis
 32from .common import MemberId
 33from .sample import Sample
 34from .stat_measures import (
 35    DatasetMean,
 36    DatasetMeasure,
 37    DatasetMeasureBase,
 38    DatasetPercentile,
 39    DatasetStd,
 40    DatasetVar,
 41    Measure,
 42    MeasureValue,
 43    SampleMean,
 44    SampleMeasure,
 45    SampleQuantile,
 46    SampleStd,
 47    SampleVar,
 48)
 49from .tensor import Tensor
 50
 51try:
 52    import crick  # pyright: ignore[reportMissingImports]
 53
 54except Exception:
 55    crick = None
 56
 57    class TDigest:
 58        def update(self, obj: Any):
 59            pass
 60
 61        def quantile(self, q: Any) -> Any:
 62            pass
 63
 64else:
 65    TDigest = crick.TDigest  # type: ignore
 66
 67
 68class MeanCalculator:
 69    """to calculate sample and dataset mean for in-memory samples"""
 70
 71    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
 72        super().__init__()
 73        self._n: int = 0
 74        self._mean: Optional[Tensor] = None
 75        self._axes = None if axes is None else tuple(axes)
 76        self._member_id = member_id
 77        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
 78        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
 79
 80    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 81        return {self._sample_mean: self._compute_impl(sample)}
 82
 83    def _compute_impl(self, sample: Sample) -> Tensor:
 84        tensor = sample.members[self._member_id].astype("float64", copy=False)
 85        return tensor.mean(dim=self._axes)
 86
 87    def update(self, sample: Sample) -> None:
 88        mean = self._compute_impl(sample)
 89        self._update_impl(sample.members[self._member_id], mean)
 90
 91    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 92        mean = self._compute_impl(sample)
 93        self._update_impl(sample.members[self._member_id], mean)
 94        return {self._sample_mean: mean}
 95
 96    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
 97        assert tensor_mean.dtype == "float64"
 98        # reduced voxel count
 99        n_b = int(tensor.size / tensor_mean.size)
100
101        if self._mean is None:
102            assert self._n == 0
103            self._n = n_b
104            self._mean = tensor_mean
105        else:
106            assert self._n != 0
107            n_a = self._n
108            mean_old = self._mean
109            self._n = n_a + n_b
110            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
111            assert self._mean.dtype == "float64"
112
113    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
114        if self._mean is None:
115            return {}
116        else:
117            return {self._dataset_mean: self._mean}
118
119
class MeanVarStdCalculator:
    """Computes mean, variance and standard deviation for samples and a dataset.

    - `compute` returns the statistics of a single sample.
    - `update`/`finalize` accumulate dataset statistics by merging the running
      (count, mean, M2) with those of each new sample using the pairwise update
      of Chan et al. (`M2 = M2_a + M2_b + d**2 * n_a * n_b / n`).

    Variances are population variances: the sum of squared deviations divided
    by the number of reduced voxels. Dataset statistics are only accumulated if
    `axes` is `None` or contains the batch axis. Otherwise `update` is a no-op
    and `finalize` returns `{}`.
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        # running voxel count, mean and sum of squared deviations (M2)
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    @staticmethod
    def _xr_dot_takes_dims() -> bool:
        """Return whether the installed `xr.dot` should be called with `dims=`.

        Older xarray releases (including pre-CalVer 0.x and 2022.x) only accept
        `dims=`, while newer ones use `dim=`. Comparing the leading (year)
        component keeps the previous choice for 2023 (`dims=`) and 2024+
        (`dim=`). It also routes releases older than 2023 to `dims=`, which a
        plain `startswith("2023")` check sent to `dim=`. Unparsable version
        strings fall back to `dim=`, as before.
        """
        major = xr.__version__.split(".", 1)[0]
        return major.isdigit() and int(major) < 2024

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """Return mean, variance and standard deviation of a single sample."""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        # number of voxels reduced into each output element
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # sum of squared deviations via `xr.dot`; its keyword name depends on
        # the installed xarray version
        if self._xr_dot_takes_dims():
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dims=self._axes) / n
            )
        else:
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dim=self._axes) / n
            )

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """Merge the statistics of `sample` into the running dataset statistics."""
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            # dataset statistics are only defined when reducing over the batch
            return

        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            # pairwise (Chan et al.) combination of the two partial M2 values
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """Return dataset mean, variance and std, or `{}` if none were accumulated."""
        if self._mean is None or (
            self._axes is not None and BATCH_AXIS_ID not in self._axes
        ):
            return {}

        assert self._m2 is not None
        var = self._m2 / self._n
        sqrt = var**0.5
        if isinstance(sqrt, (int, float)):
            # var and mean are scalar tensors, let's keep it consistent
            sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

        assert isinstance(sqrt, Tensor), type(sqrt)
        return {
            DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
            DatasetVar(member_id=self._member_id, axes=self._axes): var,
            DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
        }
214
215
class SamplePercentilesCalculator:
    """Computes quantiles of a single sample member (no dataset accumulation).

    Requested quantiles are deduplicated and sorted in ascending order.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        unique_qs = set(qs)
        self._qs = sorted(unique_qs)
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one `SampleQuantile` entry per requested quantile."""
        member = sample.members[self._member_id]
        quantiles = member.quantile(self._qs, dim=self._axes)
        result: Dict[SampleQuantile, MeasureValue] = {}
        for q, value in zip(self._qs, quantiles):
            key = SampleQuantile(q=q, axes=self._axes, member_id=self._member_id)
            result[key] = value
        return result
238
239
class MeanPercentilesCalculator:
    """Estimates dataset percentiles heuristically by averaging per-sample percentiles.

    Each sample's percentiles are weighted by its reduced voxel count.
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # accumulated weight (reduced voxel count) and running weighted estimates
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """Blend the percentiles of `sample` into the running estimates."""
        tensor = sample.members[self._member_id]
        quantiles = tensor.quantile(self._qs, dim=self._axes)
        estimates = quantiles.astype("float64", copy=False)

        # reduced voxel count: total voxels divided by values per quantile
        n_new = int(tensor.size / np.prod(estimates.shape_tuple[1:]))

        if self._estimates is not None:
            weighted_sum = self._n * self._estimates + n_new * estimates
            self._estimates = weighted_sum / (self._n + n_new)
            assert self._estimates.dtype == "float64"
        else:
            assert self._n == 0
            self._estimates = estimates

        self._n += n_new

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return the averaged percentile estimates (warns that they are approximate)."""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): estimate
            for q, estimate in zip(self._qs, self._estimates)
        }
290
291
class CrickPercentilesCalculator:
    """Estimates dataset percentiles with the experimental [crick library](https://github.com/dask/crick).

    One `TDigest` is kept per element of the non-reduced axes (a single digest
    if `axes` is `None`). Every sample passed to `update` is fed into all
    digests, and `finalize` queries each digest for every requested quantile.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # lazily initialized from the first sample's shape (see `_initialize`)
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # index tuples into the non-reduced axes, aligned with `self._digest`
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """Set up output dims/shape and one digest per non-reduced element."""
        assert crick is not None
        # leading "_percentiles" axis followed by all non-reduced axes
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # Materialize the indices: a bare `itertools.product` iterator is
        # exhausted after one pass, which made every `update` after the first
        # one silently skip all digests.
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """Add the member tensor of `part` to the digests."""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for digest, idx in zip(self._digest, self._indices):
            # NOTE(review): relies on `TDigest.update` accepting the selected
            # tensor slice (e.g. via array conversion) — confirm.
            digest.update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return one `DatasetPercentile` tensor per requested quantile, or `{}`."""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            # rows: quantiles, columns: digests (row-major order of `self._indices`)
            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
363
364
# Use the t-digest based calculator when the optional `crick` import succeeded;
# otherwise fall back to averaging per-sample percentiles.
DatasetPercentilesCalculator: Type[
    Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
] = (MeanPercentilesCalculator if crick is None else CrickPercentilesCalculator)
371
372
class NaiveSampleMeasureCalculator:
    """Adapts a single `SampleMeasure` to the `compute(sample)` calculator interface.

    The measure computes its own value; this wrapper only packs it into a dict.
    """

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Return `{measure: measure.compute(sample)}`."""
        measure = self.measure
        value = measure.compute(sample)
        return {measure: value}
383
384
# Calculators that produce sample statistics through `compute(sample)`.
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]

# Calculators that accumulate dataset statistics through `update`/`finalize`.
DatasetMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    DatasetPercentilesCalculator,
]
394
395
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently.

    Dataset statistics are accumulated by `update` (and `update_and_get_all`)
    and aggregated by `finalize`. The aggregate is cached until the next update
    that actually processes a sample. If `initial_dataset_measures` provides
    every requested dataset measure, `finalize` returns those values until the
    first such update. After that, dataset statistics are derived only from
    the samples passed to this calculator.
    """

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        super().__init__()
        # number of samples processed by `update`/`update_and_get_all`
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        """Whether dataset measures are available without recomputation.

        True if valid `initial_dataset_measures` were given, or if `finalize`
        ran since the last update that processed a sample.
        """
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """Accumulate dataset statistics from one sample or an iterable of samples."""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics

        Sample statistics are computed for the last processed sample.

        Raises:
            ValueError: if `sample` is an iterable that yields no samples.
        """
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        """Feed all given samples to the dataset calculators; return the last one."""
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample: Optional[Sample] = None
        for el in samples:
            # count samples, not calls (an iterable may hold many samples)
            self.sample_count += 1
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        if last_sample is not None:
            # new data invalidates cached (or initial) dataset measures; an
            # empty iterable leaves them untouched
            self._current_dataset_measures = None

        return last_sample
484
485
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Measures that can be computed together share one calculator:
    - mean/var/std of the same member and axes use a single `MeanVarStdCalculator`.
      A requested mean of that member/axes then gets no separate `MeanCalculator`.
    - all quantiles of the same member and axes use a single percentiles calculator.

    Returns:
        A tuple `(sample_calculators, dataset_calculators)`.
    """

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # The mean/var/std sets hold up to three measures (mean, var, std) per
    # (member_id, axes). One `MeanVarStdCalculator` yields all three, so create
    # exactly one calculator per (member_id, axes) instead of one per measure.
    sample_mean_var_std_keys = {
        (m.member_id, m.axes) for m in required_sample_mean_var_std
    }
    for tid, axes in sample_mean_var_std_keys:
        sample_calculators.append(MeanVarStdCalculator(member_id=tid, axes=axes))

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    dataset_mean_var_std_keys = {
        (m.member_id, m.axes) for m in required_dataset_mean_var_std
    }
    for tid, axes in dataset_mean_var_std_keys:
        dataset_calculators.append(MeanVarStdCalculator(member_id=tid, axes=axes))

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
575
576
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """Compute all dataset `measures` over every sample of `dataset`.

    `measures` must only contain dataset measures (asserted).
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)
    assert not sample_calcs

    result: Dict[DatasetMeasure, MeasureValue] = {}

    for sample in dataset:
        for calculator in dataset_calcs:
            calculator.update(sample)

    for calculator in dataset_calcs:
        result.update(calculator.finalize())

    return result
594
595
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """Compute all sample `measures` for the given `sample`.

    `measures` must only contain sample measures (asserted).
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)
    assert not dataset_calcs

    result: Dict[SampleMeasure, MeasureValue] = {}
    for calculator in sample_calcs:
        result.update(calculator.compute(sample))

    return result
608
609
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """Compute all `measures` for the given `dataset`.

    Dataset measures are computed over all samples. Sample measures are
    computed for the last sample in `dataset`.

    Raises:
        ValueError: if `dataset` yields no samples.
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)

    last_sample = None
    for last_sample in dataset:
        for calculator in dataset_calcs:
            calculator.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    result: Dict[Measure, MeasureValue] = {}
    for calculator in dataset_calcs:
        result.update(calculator.finalize())

    for calculator in sample_calcs:
        result.update(calculator.compute(last_sample))

    return result
class MeanCalculator:
    """Computes per-sample means and accumulates a running dataset mean.

    Intended for in-memory samples. Means are computed in float64. The running
    dataset mean is a count-weighted average of the per-sample means, where a
    sample's weight is the number of voxels reduced into each of its mean values.
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._sample_mean = SampleMean(member_id=member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Return the mean of this member of `sample` without accumulating it."""
        sample_mean = self._compute_impl(sample)
        return {self._sample_mean: sample_mean}

    def _compute_impl(self, sample: Sample) -> Tensor:
        """Mean over `self._axes` (all axes if `None`), computed in float64."""
        as_float64 = sample.members[self._member_id].astype("float64", copy=False)
        return as_float64.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """Fold `sample` into the running dataset mean."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Fold `sample` into the running dataset mean and return its sample mean."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels that were averaged into each element of `tensor_mean`
        n_new = int(tensor.size / tensor_mean.size)

        if self._mean is not None:
            assert self._n != 0
            n_old = self._n
            self._n = n_old + n_new
            # count-weighted combination of the old running mean and the new one
            self._mean = (n_old * self._mean + n_new * tensor_mean) / self._n
            assert self._mean.dtype == "float64"
        else:
            assert self._n == 0
            self._n = n_new
            self._mean = tensor_mean

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """Return the accumulated dataset mean, or `{}` if nothing was accumulated."""
        return {} if self._mean is None else {self._dataset_mean: self._mean}
class MeanVarStdCalculator:
    """Computes mean, variance and standard deviation for samples and a dataset.

    - `compute` returns the statistics of a single sample.
    - `update`/`finalize` accumulate dataset statistics by merging the running
      (count, mean, M2) with those of each new sample using the pairwise update
      of Chan et al. (`M2 = M2_a + M2_b + d**2 * n_a * n_b / n`).

    Variances are population variances: the sum of squared deviations divided
    by the number of reduced voxels. Dataset statistics are only accumulated if
    `axes` is `None` or contains the batch axis. Otherwise `update` is a no-op
    and `finalize` returns `{}`.
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        # running voxel count, mean and sum of squared deviations (M2)
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    @staticmethod
    def _xr_dot_takes_dims() -> bool:
        """Return whether the installed `xr.dot` should be called with `dims=`.

        Older xarray releases (including pre-CalVer 0.x and 2022.x) only accept
        `dims=`, while newer ones use `dim=`. Comparing the leading (year)
        component keeps the previous choice for 2023 (`dims=`) and 2024+
        (`dim=`). It also routes releases older than 2023 to `dims=`, which a
        plain `startswith("2023")` check sent to `dim=`. Unparsable version
        strings fall back to `dim=`, as before.
        """
        major = xr.__version__.split(".", 1)[0]
        return major.isdigit() and int(major) < 2024

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """Return mean, variance and standard deviation of a single sample."""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        # number of voxels reduced into each output element
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # sum of squared deviations via `xr.dot`; its keyword name depends on
        # the installed xarray version
        if self._xr_dot_takes_dims():
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dims=self._axes) / n
            )
        else:
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dim=self._axes) / n
            )

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """Merge the statistics of `sample` into the running dataset statistics."""
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            # dataset statistics are only defined when reducing over the batch
            return

        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            # pairwise (Chan et al.) combination of the two partial M2 values
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """Return dataset mean, variance and std, or `{}` if none were accumulated."""
        if self._mean is None or (
            self._axes is not None and BATCH_AXIS_ID not in self._axes
        ):
            return {}

        assert self._m2 is not None
        var = self._m2 / self._n
        sqrt = var**0.5
        if isinstance(sqrt, (int, float)):
            # var and mean are scalar tensors, let's keep it consistent
            sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

        assert isinstance(sqrt, Tensor), type(sqrt)
        return {
            DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
            DatasetVar(member_id=self._member_id, axes=self._axes): var,
            DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
        }
def update(self, sample: bioimageio.core.Sample):
165    def update(self, sample: Sample):
166        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
167            return
168
169        tensor = sample.members[self._member_id].astype("float64", copy=False)
170        mean_b = tensor.mean(dim=self._axes)
171        assert mean_b.dtype == "float64"
172        # reduced voxel count
173        n_b = int(tensor.size / mean_b.size)
174        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
175        assert m2_b.dtype == "float64"
176        if self._mean is None:
177            assert self._m2 is None
178            self._n = n_b
179            self._mean = mean_b
180            self._m2 = m2_b
181        else:
182            n_a = self._n
183            mean_a = self._mean
184            m2_a = self._m2
185            self._n = n = n_a + n_b
186            self._mean = (n_a * mean_a + n_b * mean_b) / n
187            assert self._mean.dtype == "float64"
188            d = mean_b - mean_a
189            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
190            assert self._m2.dtype == "float64"
def finalize( self) -> Dict[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetStd], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
192    def finalize(
193        self,
194    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
195        if (
196            self._axes is not None
197            and BATCH_AXIS_ID not in self._axes
198            or self._mean is None
199        ):
200            return {}
201        else:
202            assert self._m2 is not None
203            var = self._m2 / self._n
204            sqrt = var**0.5
205            if isinstance(sqrt, (int, float)):
206                # var and mean are scalar tensors, let's keep it consistent
207                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))
208
209            assert isinstance(sqrt, Tensor), type(sqrt)
210            return {
211                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
212                DatasetVar(member_id=self._member_id, axes=self._axes): var,
213                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
214            }
class SamplePercentilesCalculator:
217class SamplePercentilesCalculator:
218    """to calculate sample percentiles"""
219
220    def __init__(
221        self,
222        member_id: MemberId,
223        axes: Optional[Sequence[AxisId]],
224        qs: Collection[float],
225    ):
226        super().__init__()
227        assert all(0.0 <= q <= 1.0 for q in qs)
228        self._qs = sorted(set(qs))
229        self._axes = None if axes is None else tuple(axes)
230        self._member_id = member_id
231
232    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
233        tensor = sample.members[self._member_id]
234        ps = tensor.quantile(self._qs, dim=self._axes)
235        return {
236            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
237            for q, p in zip(self._qs, ps)
238        }

to calculate sample percentiles

SamplePercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
220    def __init__(
221        self,
222        member_id: MemberId,
223        axes: Optional[Sequence[AxisId]],
224        qs: Collection[float],
225    ):
226        super().__init__()
227        assert all(0.0 <= q <= 1.0 for q in qs)
228        self._qs = sorted(set(qs))
229        self._axes = None if axes is None else tuple(axes)
230        self._member_id = member_id
def compute( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleQuantile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
232    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
233        tensor = sample.members[self._member_id]
234        ps = tensor.quantile(self._qs, dim=self._axes)
235        return {
236            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
237            for q, p in zip(self._qs, ps)
238        }
class MeanPercentilesCalculator:
241class MeanPercentilesCalculator:
242    """to calculate dataset percentiles heuristically by averaging across samples
243    **note**: the returned dataset percentiles are an estiamte and **not mathematically correct**
244    """
245
246    def __init__(
247        self,
248        member_id: MemberId,
249        axes: Optional[Sequence[AxisId]],
250        qs: Collection[float],
251    ):
252        super().__init__()
253        assert all(0.0 <= q <= 1.0 for q in qs)
254        self._qs = sorted(set(qs))
255        self._axes = None if axes is None else tuple(axes)
256        self._member_id = member_id
257        self._n: int = 0
258        self._estimates: Optional[Tensor] = None
259
260    def update(self, sample: Sample):
261        tensor = sample.members[self._member_id]
262        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
263            "float64", copy=False
264        )
265
266        # reduced voxel count
267        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))
268
269        if self._estimates is None:
270            assert self._n == 0
271            self._estimates = sample_estimates
272        else:
273            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
274                self._n + n
275            )
276            assert self._estimates.dtype == "float64"
277
278        self._n += n
279
280    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
281        if self._estimates is None:
282            return {}
283        else:
284            warnings.warn(
285                "Computed dataset percentiles naively by averaging percentiles of samples."
286            )
287            return {
288                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
289                for q, e in zip(self._qs, self._estimates)
290            }

to calculate dataset percentiles heuristically by averaging across samples. Note: the returned dataset percentiles are an estimate and not mathematically correct

MeanPercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
246    def __init__(
247        self,
248        member_id: MemberId,
249        axes: Optional[Sequence[AxisId]],
250        qs: Collection[float],
251    ):
252        super().__init__()
253        assert all(0.0 <= q <= 1.0 for q in qs)
254        self._qs = sorted(set(qs))
255        self._axes = None if axes is None else tuple(axes)
256        self._member_id = member_id
257        self._n: int = 0
258        self._estimates: Optional[Tensor] = None
def update(self, sample: bioimageio.core.Sample):
260    def update(self, sample: Sample):
261        tensor = sample.members[self._member_id]
262        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
263            "float64", copy=False
264        )
265
266        # reduced voxel count
267        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))
268
269        if self._estimates is None:
270            assert self._n == 0
271            self._estimates = sample_estimates
272        else:
273            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
274                self._n + n
275            )
276            assert self._estimates.dtype == "float64"
277
278        self._n += n
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetPercentile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
280    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
281        if self._estimates is None:
282            return {}
283        else:
284            warnings.warn(
285                "Computed dataset percentiles naively by averaging percentiles of samples."
286            )
287            return {
288                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
289                for q, e in zip(self._qs, self._estimates)
290            }
class CrickPercentilesCalculator:
293class CrickPercentilesCalculator:
294    """to calculate dataset percentiles with the experimental [crick libray](https://github.com/dask/crick)"""
295
296    def __init__(
297        self,
298        member_id: MemberId,
299        axes: Optional[Sequence[AxisId]],
300        qs: Collection[float],
301    ):
302        warnings.warn(
303            "Computing dataset percentiles with experimental 'crick' library."
304        )
305        super().__init__()
306        assert all(0.0 <= q <= 1.0 for q in qs)
307        assert axes is None or "_percentiles" not in axes
308        self._qs = sorted(set(qs))
309        self._axes = None if axes is None else tuple(axes)
310        self._member_id = member_id
311        self._digest: Optional[List[TDigest]] = None
312        self._dims: Optional[Tuple[AxisId, ...]] = None
313        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
314        self._shape: Optional[Tuple[int, ...]] = None
315
316    def _initialize(self, tensor_sizes: PerAxis[int]):
317        assert crick is not None
318        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
319            _percentiles=len(self._qs)
320        )
321        if self._axes is not None:
322            for d, s in tensor_sizes.items():
323                if d not in self._axes:
324                    out_sizes[d] = s
325
326        self._dims, self._shape = zip(*out_sizes.items())
327        assert self._shape is not None
328        d = int(np.prod(self._shape[1:]))
329        self._digest = [TDigest() for _ in range(d)]
330        self._indices = product(*map(range, self._shape[1:]))
331
332    def update(self, part: Sample):
333        tensor = (
334            part.members[self._member_id]
335            if isinstance(part, Sample)
336            else part.members[self._member_id].data
337        )
338        assert "_percentiles" not in tensor.dims
339        if self._digest is None:
340            self._initialize(tensor.tagged_shape)
341
342        assert self._digest is not None
343        assert self._indices is not None
344        assert self._dims is not None
345        for i, idx in enumerate(self._indices):
346            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
347
348    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
349        if self._digest is None:
350            return {}
351        else:
352            assert self._dims is not None
353            assert self._shape is not None
354
355            vs: NDArray[Any] = np.asarray(
356                [[d.quantile(q) for d in self._digest] for q in self._qs]
357            ).reshape(self._shape)
358            return {
359                DatasetPercentile(
360                    q=q, axes=self._axes, member_id=self._member_id
361                ): Tensor(v, dims=self._dims[1:])
362                for q, v in zip(self._qs, vs)
363            }

to calculate dataset percentiles with the experimental crick library

CrickPercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
296    def __init__(
297        self,
298        member_id: MemberId,
299        axes: Optional[Sequence[AxisId]],
300        qs: Collection[float],
301    ):
302        warnings.warn(
303            "Computing dataset percentiles with experimental 'crick' library."
304        )
305        super().__init__()
306        assert all(0.0 <= q <= 1.0 for q in qs)
307        assert axes is None or "_percentiles" not in axes
308        self._qs = sorted(set(qs))
309        self._axes = None if axes is None else tuple(axes)
310        self._member_id = member_id
311        self._digest: Optional[List[TDigest]] = None
312        self._dims: Optional[Tuple[AxisId, ...]] = None
313        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
314        self._shape: Optional[Tuple[int, ...]] = None
def update(self, part: bioimageio.core.Sample):
332    def update(self, part: Sample):
333        tensor = (
334            part.members[self._member_id]
335            if isinstance(part, Sample)
336            else part.members[self._member_id].data
337        )
338        assert "_percentiles" not in tensor.dims
339        if self._digest is None:
340            self._initialize(tensor.tagged_shape)
341
342        assert self._digest is not None
343        assert self._indices is not None
344        assert self._dims is not None
345        for i, idx in enumerate(self._indices):
346            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetPercentile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
348    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
349        if self._digest is None:
350            return {}
351        else:
352            assert self._dims is not None
353            assert self._shape is not None
354
355            vs: NDArray[Any] = np.asarray(
356                [[d.quantile(q) for d in self._digest] for q in self._qs]
357            ).reshape(self._shape)
358            return {
359                DatasetPercentile(
360                    q=q, axes=self._axes, member_id=self._member_id
361                ): Tensor(v, dims=self._dims[1:])
362                for q, v in zip(self._qs, vs)
363            }
class NaiveSampleMeasureCalculator:
374class NaiveSampleMeasureCalculator:
375    """wrapper for measures to match interface of other sample measure calculators"""
376
377    def __init__(self, member_id: MemberId, measure: SampleMeasure):
378        super().__init__()
379        self.tensor_name = member_id
380        self.measure = measure
381
382    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
383        return {self.measure: self.measure.compute(sample)}

wrapper for measures to match interface of other sample measure calculators

NaiveSampleMeasureCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, measure: Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)])
377    def __init__(self, member_id: MemberId, measure: SampleMeasure):
378        super().__init__()
379        self.tensor_name = member_id
380        self.measure = measure
tensor_name
measure
def compute( self, sample: bioimageio.core.Sample) -> Dict[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
382    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
383        return {self.measure: self.measure.compute(sample)}
DatasetMeasureCalculator = typing.Union[MeanCalculator, MeanVarStdCalculator, MeanPercentilesCalculator]
class StatsCalculator:
397class StatsCalculator:
398    """Estimates dataset statistics and computes sample statistics efficiently"""
399
400    def __init__(
401        self,
402        measures: Collection[Measure],
403        initial_dataset_measures: Optional[
404            Mapping[DatasetMeasure, MeasureValue]
405        ] = None,
406    ):
407        super().__init__()
408        self.sample_count = 0
409        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
410            measures
411        )
412        if not initial_dataset_measures:
413            self._current_dataset_measures: Optional[
414                Dict[DatasetMeasure, MeasureValue]
415            ] = None
416        else:
417            missing_dataset_meas = {
418                m
419                for m in measures
420                if isinstance(m, DatasetMeasureBase)
421                and m not in initial_dataset_measures
422            }
423            if missing_dataset_meas:
424                logger.debug(
425                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
426                )
427                self._current_dataset_measures = None
428            else:
429                self._current_dataset_measures = dict(initial_dataset_measures)
430
431    @property
432    def has_dataset_measures(self):
433        return self._current_dataset_measures is not None
434
435    def update(
436        self,
437        sample: Union[Sample, Iterable[Sample]],
438    ) -> None:
439        _ = self._update(sample)
440
441    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
442        """returns aggregated dataset statistics"""
443        if self._current_dataset_measures is None:
444            self._current_dataset_measures = {}
445            for calc in self.dataset_calculators:
446                values = calc.finalize()
447                self._current_dataset_measures.update(values.items())
448
449        return self._current_dataset_measures
450
451    def update_and_get_all(
452        self,
453        sample: Union[Sample, Iterable[Sample]],
454    ) -> Dict[Measure, MeasureValue]:
455        """Returns sample as well as updated dataset statistics"""
456        last_sample = self._update(sample)
457        if last_sample is None:
458            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")
459
460        return {**self._compute(last_sample), **self.finalize()}
461
462    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
463        """Returns sample as well as previously computed dataset statistics"""
464        return {**self._compute(sample), **self.finalize()}
465
466    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
467        ret: Dict[SampleMeasure, MeasureValue] = {}
468        for calc in self.sample_calculators:
469            values = calc.compute(sample)
470            ret.update(values.items())
471
472        return ret
473
474    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
475        self.sample_count += 1
476        samples = [sample] if isinstance(sample, Sample) else sample
477        last_sample = None
478        for el in samples:
479            last_sample = el
480            for calc in self.dataset_calculators:
481                calc.update(el)
482
483        self._current_dataset_measures = None
484        return last_sample

Estimates dataset statistics and computes sample statistics efficiently

StatsCalculator( measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], initial_dataset_measures: Optional[Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None)
400    def __init__(
401        self,
402        measures: Collection[Measure],
403        initial_dataset_measures: Optional[
404            Mapping[DatasetMeasure, MeasureValue]
405        ] = None,
406    ):
407        super().__init__()
408        self.sample_count = 0
409        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
410            measures
411        )
412        if not initial_dataset_measures:
413            self._current_dataset_measures: Optional[
414                Dict[DatasetMeasure, MeasureValue]
415            ] = None
416        else:
417            missing_dataset_meas = {
418                m
419                for m in measures
420                if isinstance(m, DatasetMeasureBase)
421                and m not in initial_dataset_measures
422            }
423            if missing_dataset_meas:
424                logger.debug(
425                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
426                )
427                self._current_dataset_measures = None
428            else:
429                self._current_dataset_measures = dict(initial_dataset_measures)
sample_count
has_dataset_measures
431    @property
432    def has_dataset_measures(self):
433        return self._current_dataset_measures is not None
def update( self, sample: Union[bioimageio.core.Sample, Iterable[bioimageio.core.Sample]]) -> None:
435    def update(
436        self,
437        sample: Union[Sample, Iterable[Sample]],
438    ) -> None:
439        _ = self._update(sample)
def finalize( self) -> Dict[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
441    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
442        """returns aggregated dataset statistics"""
443        if self._current_dataset_measures is None:
444            self._current_dataset_measures = {}
445            for calc in self.dataset_calculators:
446                values = calc.finalize()
447                self._current_dataset_measures.update(values.items())
448
449        return self._current_dataset_measures

returns aggregated dataset statistics

def update_and_get_all( self, sample: Union[bioimageio.core.Sample, Iterable[bioimageio.core.Sample]]) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
451    def update_and_get_all(
452        self,
453        sample: Union[Sample, Iterable[Sample]],
454    ) -> Dict[Measure, MeasureValue]:
455        """Returns sample as well as updated dataset statistics"""
456        last_sample = self._update(sample)
457        if last_sample is None:
458            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")
459
460        return {**self._compute(last_sample), **self.finalize()}

Returns sample as well as updated dataset statistics

def skip_update_and_get_all( self, sample: bioimageio.core.Sample) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
462    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
463        """Returns sample as well as previously computed dataset statistics"""
464        return {**self._compute(sample), **self.finalize()}

Returns sample as well as previously computed dataset statistics

def get_measure_calculators( required_measures: Iterable[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]) -> Tuple[List[Union[MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaiveSampleMeasureCalculator]], List[Union[MeanCalculator, MeanVarStdCalculator, MeanPercentilesCalculator]]]:
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Measures that can share work are grouped so each group gets one calculator:

    - a var or std request schedules mean, std and var for the same member and
      axes; each such (member_id, axes) group gets exactly one
      `MeanVarStdCalculator`;
    - a mean without a matching var/std gets its own `MeanCalculator`;
    - all quantiles (sample scope) / percentiles (dataset scope) of the same
      member and axes share one percentiles calculator covering every `q`.

    Args:
        required_measures: sample and/or dataset measures to compute.

    Returns:
        A tuple ``(sample_calculators, dataset_calculators)``.
    """

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # The mean/var/std set holds up to three measures per (member_id, axes), but
    # a MeanVarStdCalculator is configured only by (member_id, axes). Create one
    # calculator per group rather than one per measure (which would compute the
    # same statistics up to three times). dict.fromkeys dedupes in stable order.
    sample_mean_var_std_groups = dict.fromkeys(
        (rm.member_id, rm.axes) for rm in required_sample_mean_var_std
    )
    for member_id, axes in sample_mean_var_std_groups:
        sample_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # same deduplication as for sample scope; for dataset scope this also avoids
    # feeding every sample of the dataset to redundant calculators
    dataset_mean_var_std_groups = dict.fromkeys(
        (rm.member_id, rm.axes) for rm in required_dataset_mean_var_std
    )
    for member_id, axes in dataset_mean_var_std_groups:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=member_id, axes=axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators

determines which calculators are needed to compute the required measures efficiently

def compute_dataset_measures( measures: Iterable[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], dataset: Iterable[bioimageio.core.Sample]) -> Dict[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`

    Every sample of `dataset` is fed to each dataset calculator once; the
    final values are then collected from each calculator's `finalize()`.
    Only dataset-scope measures are expected (checked with an `assert`).
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)
    assert not sample_calcs

    for each_sample in dataset:
        for calculator in dataset_calcs:
            calculator.update(each_sample)

    results: Dict[DatasetMeasure, MeasureValue] = {}
    for calculator in dataset_calcs:
        results.update(calculator.finalize().items())

    return results

compute all dataset measures for the given dataset

def compute_sample_measures( measures: Iterable[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], sample: bioimageio.core.Sample) -> Dict[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`

    Only sample-scope measures are expected (checked with an `assert`).
    Results of later calculators overwrite earlier ones on key collisions.
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)
    assert not dataset_calcs

    results: Dict[SampleMeasure, MeasureValue] = {
        measure: value
        for calculator in sample_calcs
        for measure, value in calculator.compute(sample).items()
    }
    return results

compute all sample measures for the given sample

def compute_measures( measures: Iterable[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], dataset: Iterable[bioimageio.core.Sample]) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f25f54c6f20>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f25f54c7100>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`

    Dataset measures are accumulated over every sample in `dataset`; sample
    measures are computed for the last sample in `dataset` only.

    Raises:
        ValueError: if `dataset` yields no samples.
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)
    results: Dict[Measure, MeasureValue] = {}

    # iterate the whole dataset: dataset calculators need every sample, and the
    # loop variable is left bound to the last sample for the sample measures
    last_sample = None
    for last_sample in dataset:
        for calculator in dataset_calcs:
            calculator.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    for calculator in dataset_calcs:
        results.update(calculator.finalize().items())

    for calculator in sample_calcs:
        results.update(calculator.compute(last_sample).items())

    return results

compute all measures for the given dataset; sample measures are computed for the last sample in dataset

class TDigest:
class TDigest:
    """Placeholder for `crick.TDigest`, defined when `crick` cannot be imported.

    Both methods accept their arguments and do nothing.
    """

    def update(self, obj: Any):
        """Ignore `obj` (no-op placeholder)."""

    def quantile(self, q: Any) -> Any:
        """Ignore `q` and return ``None`` (no-op placeholder)."""
        return None
def update(self, obj: Any):
def update(self, obj: Any):
    """No-op stand-in for `TDigest.update`; `obj` is ignored."""
def quantile(self, q: Any) -> Any:
def quantile(self, q: Any) -> Any:
    """No-op stand-in for `TDigest.quantile`; `q` is ignored."""
    return None
DatasetPercentilesCalculator: Type[Union[MeanPercentilesCalculator, CrickPercentilesCalculator]] = <class 'MeanPercentilesCalculator'>