bioimageio.core.stat_calculators

  1from __future__ import annotations
  2
  3import collections.abc
  4import warnings
  5from itertools import product
  6from typing import (
  7    Any,
  8    Collection,
  9    Dict,
 10    Iterable,
 11    Iterator,
 12    List,
 13    Mapping,
 14    Optional,
 15    OrderedDict,
 16    Sequence,
 17    Set,
 18    Tuple,
 19    Type,
 20    Union,
 21)
 22
 23import numpy as np
 24import xarray as xr
 25from loguru import logger
 26from numpy.typing import NDArray
 27from typing_extensions import assert_never
 28
 29from .axis import AxisId, PerAxis
 30from .common import MemberId
 31from .sample import Sample
 32from .stat_measures import (
 33    DatasetMean,
 34    DatasetMeasure,
 35    DatasetMeasureBase,
 36    DatasetPercentile,
 37    DatasetStd,
 38    DatasetVar,
 39    Measure,
 40    MeasureValue,
 41    SampleMean,
 42    SampleMeasure,
 43    SampleQuantile,
 44    SampleStd,
 45    SampleVar,
 46)
 47from .tensor import Tensor
 48
 49try:
 50    import crick
 51
 52except Exception:
 53    crick = None
 54
 55    class TDigest:
 56        def update(self, obj: Any):
 57            pass
 58
 59        def quantile(self, q: Any) -> Any:
 60            pass
 61
 62else:
 63    TDigest = crick.TDigest  # type: ignore
 64
 65
 66class MeanCalculator:
 67    """to calculate sample and dataset mean for in-memory samples"""
 68
 69    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
 70        super().__init__()
 71        self._n: int = 0
 72        self._mean: Optional[Tensor] = None
 73        self._axes = None if axes is None else tuple(axes)
 74        self._member_id = member_id
 75        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
 76        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
 77
 78    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 79        return {self._sample_mean: self._compute_impl(sample)}
 80
 81    def _compute_impl(self, sample: Sample) -> Tensor:
 82        tensor = sample.members[self._member_id].astype("float64", copy=False)
 83        return tensor.mean(dim=self._axes)
 84
 85    def update(self, sample: Sample) -> None:
 86        mean = self._compute_impl(sample)
 87        self._update_impl(sample.members[self._member_id], mean)
 88
 89    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 90        mean = self._compute_impl(sample)
 91        self._update_impl(sample.members[self._member_id], mean)
 92        return {self._sample_mean: mean}
 93
 94    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
 95        assert tensor_mean.dtype == "float64"
 96        # reduced voxel count
 97        n_b = int(tensor.size / tensor_mean.size)
 98
 99        if self._mean is None:
100            assert self._n == 0
101            self._n = n_b
102            self._mean = tensor_mean
103        else:
104            assert self._n != 0
105            n_a = self._n
106            mean_old = self._mean
107            self._n = n_a + n_b
108            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
109            assert self._mean.dtype == "float64"
110
111    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
112        if self._mean is None:
113            return {}
114        else:
115            return {self._dataset_mean: self._mean}
116
117
class MeanVarStdCalculator:
    """Running sample/dataset mean, variance and standard deviation.

    `compute` yields the statistics of a single sample; `update`/`finalize`
    merge per-sample statistics with the pairwise (Chan et al.) combination
    of (n, mean, M2).
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # reduced voxel count aggregated so far
        self._mean: Optional[Tensor] = None  # running mean (float64)
        self._m2: Optional[Tensor] = None  # running sum of squared deviations

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """Return mean, variance and std of `sample` (running state untouched)."""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        centered = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # population variance via a contraction of the centered data with itself
        var = xr.dot(centered, centered, dims=self._axes) / n
        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """Fold `sample` into the running mean/var/std statistics."""
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # count of voxels reduced into each statistic entry
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n, self._mean, self._m2 = n_b, mean_b, m2_b
        else:
            # pairwise combination of (n, mean, M2) statistics
            n_a, mean_a, m2_a = self._n, self._mean, self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            delta = mean_b - mean_a
            self._m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """Return aggregated dataset mean/var/std; empty if no sample was seen."""
        if self._mean is None:
            return {}

        assert self._m2 is not None
        var = self._m2 / self._n
        sqrt = np.sqrt(var)
        if isinstance(sqrt, (int, float)):
            # var and mean are scalar tensors, let's keep it consistent
            sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

        assert isinstance(sqrt, Tensor), type(sqrt)
        return {
            DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
            DatasetVar(member_id=self._member_id, axes=self._axes): var,
            DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
        }
197
198
class SamplePercentilesCalculator:
    """Computes a fixed set of quantiles for individual samples."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))  # deduplicated, ascending
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one `SampleQuantile` measure value per requested q."""
        quantiles = sample.members[self._member_id].quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): value
            for q, value in zip(self._qs, quantiles)
        }
221
222
class MeanPercentilesCalculator:
    """Estimates dataset percentiles heuristically by averaging across samples.

    **note**: the returned dataset percentiles are an estimate and
    **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))  # deduplicated, ascending
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # reduced voxel count aggregated so far
        self._estimates: Optional[Tensor] = None  # running percentile estimates

    def update(self, sample: Sample):
        """Fold `sample`'s percentiles into the running weighted average."""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # count of voxels reduced into each percentile entry
        # (first axis of the estimates enumerates the qs)
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return the (approximate) dataset percentiles; empty if nothing seen."""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
            for q, e in zip(self._qs, self._estimates)
        }
273
274
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Maintains one t-digest per entry of the non-aggregated output shape and
    streams every sample's values into them; `finalize` queries the digests
    for the requested quantiles.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))  # deduplicated, ascending
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # one index tuple per digest, into the non-aggregated output shape;
        # materialized as a list (see bug fix note in `_initialize`)
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """Lazily set up digests, dims and shape from the first tensor seen."""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        d = int(np.prod(self._shape[1:]))  # type: ignore
        self._digest = [TDigest() for _ in range(d)]
        # bug fix: `product(...)` returns a one-shot iterator that was
        # exhausted after the first `update` call, so all subsequent samples
        # were silently ignored; keep a reusable list instead
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """Feed `part`'s member tensor values into the per-entry digests."""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Query the digests for all requested quantiles; empty if nothing seen."""
        if self._digest is None:
            return {}

        assert self._dims is not None
        assert self._shape is not None

        vs: NDArray[Any] = np.asarray(
            [[d.quantile(q) for d in self._digest] for q in self._qs]
        ).reshape(self._shape)
        return {
            DatasetPercentile(
                q=q, axes=self._axes, member_id=self._member_id
            ): Tensor(v, dims=self._dims[1:])
            for q, v in zip(self._qs, vs)
        }
345
346
# prefer the crick-based estimator when the library is available; otherwise
# fall back to the naive mean-of-sample-percentiles heuristic
if crick is not None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = CrickPercentilesCalculator
else:
    DatasetPercentilesCalculator = MeanPercentilesCalculator
353
354
class NaiveSampleMeasureCalculator:
    """Adapter giving a plain `SampleMeasure` the sample-calculator interface."""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Delegate directly to the wrapped measure's own `compute`."""
        return {self.measure: self.measure.compute(sample)}
365
366
# any calculator that can `compute` sample measures
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# any calculator that can `update` on samples and `finalize` dataset measures
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]
376
377
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        super().__init__()
        self.sample_count = 0  # number of `update`/`_update` calls seen
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the provided values only if they cover every required
            # dataset measure; otherwise start aggregation from scratch
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        """Whether aggregated dataset statistics are currently available."""
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """Update the running dataset statistics with the given sample(s)."""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # (re)aggregate and cache until invalidated by the next update
            aggregated: Dict[DatasetMeasure, MeasureValue] = {}
            for calc in self.dataset_calculators:
                aggregated.update(calc.finalize().items())

            self._current_dataset_measures = aggregated

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # gather all sample measures from every sample calculator
        computed: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            computed.update(calc.compute(sample).items())

        return computed

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate the cache; `finalize` will re-aggregate on demand
        self._current_dataset_measures = None
        return last_sample
466
467
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Groups the required measures so that each returned calculator serves one
    distinct (member_id, axes) group and no statistic is computed twice.
    """

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # a var or std requirement implies mean/var/std are computed together
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # bug fix: previously one MeanVarStdCalculator was appended per *measure*
    # (i.e. three identical calculators per group, tripling the work);
    # create exactly one calculator per (member_id, axes) group instead
    for member_id, axes in {
        (m.member_id, m.axes) for m in required_sample_mean_var_std
    }:
        sample_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # see bug fix note above: one calculator per (member_id, axes) group
    for member_id, axes in {
        (m.member_id, m.axes) for m in required_dataset_mean_var_std
    }:
        dataset_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
557
558
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # dataset measures must not require any sample calculators
    assert not sample_calculators

    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)

    result: Dict[DatasetMeasure, MeasureValue] = {}
    for calc in dataset_calculators:
        result.update(calc.finalize().items())

    return result
576
577
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures must not require any dataset calculators
    assert not dataset_calculators

    result: Dict[SampleMeasure, MeasureValue] = {}
    for calc in sample_calculators:
        result.update(calc.compute(sample).items())

    return result
590
591
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample = None
    for last_sample in dataset:
        for calc in dataset_calculators:
            calc.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    result: Dict[Measure, MeasureValue] = {}
    for calc in dataset_calculators:
        result.update(calc.finalize().items())

    # sample measures refer to the last sample yielded by `dataset`
    for calc in sample_calculators:
        result.update(calc.compute(last_sample).items())

    return result
class MeanCalculator:
 67class MeanCalculator:
 68    """to calculate sample and dataset mean for in-memory samples"""
 69
 70    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
 71        super().__init__()
 72        self._n: int = 0
 73        self._mean: Optional[Tensor] = None
 74        self._axes = None if axes is None else tuple(axes)
 75        self._member_id = member_id
 76        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
 77        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
 78
 79    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 80        return {self._sample_mean: self._compute_impl(sample)}
 81
 82    def _compute_impl(self, sample: Sample) -> Tensor:
 83        tensor = sample.members[self._member_id].astype("float64", copy=False)
 84        return tensor.mean(dim=self._axes)
 85
 86    def update(self, sample: Sample) -> None:
 87        mean = self._compute_impl(sample)
 88        self._update_impl(sample.members[self._member_id], mean)
 89
 90    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 91        mean = self._compute_impl(sample)
 92        self._update_impl(sample.members[self._member_id], mean)
 93        return {self._sample_mean: mean}
 94
 95    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
 96        assert tensor_mean.dtype == "float64"
 97        # reduced voxel count
 98        n_b = int(tensor.size / tensor_mean.size)
 99
100        if self._mean is None:
101            assert self._n == 0
102            self._n = n_b
103            self._mean = tensor_mean
104        else:
105            assert self._n != 0
106            n_a = self._n
107            mean_old = self._mean
108            self._n = n_a + n_b
109            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
110            assert self._mean.dtype == "float64"
111
112    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
113        if self._mean is None:
114            return {}
115        else:
116            return {self._dataset_mean: self._mean}

to calculate sample and dataset mean for in-memory samples

MeanCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]])
70    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
71        super().__init__()
72        self._n: int = 0
73        self._mean: Optional[Tensor] = None
74        self._axes = None if axes is None else tuple(axes)
75        self._member_id = member_id
76        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
77        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
def compute( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleMean, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
79    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
80        return {self._sample_mean: self._compute_impl(sample)}
def update(self, sample: bioimageio.core.Sample) -> None:
86    def update(self, sample: Sample) -> None:
87        mean = self._compute_impl(sample)
88        self._update_impl(sample.members[self._member_id], mean)
def compute_and_update( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleMean, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
90    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
91        mean = self._compute_impl(sample)
92        self._update_impl(sample.members[self._member_id], mean)
93        return {self._sample_mean: mean}
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetMean, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
112    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
113        if self._mean is None:
114            return {}
115        else:
116            return {self._dataset_mean: self._mean}
class MeanVarStdCalculator:
119class MeanVarStdCalculator:
120    """to calculate sample and dataset mean, variance or standard deviation"""
121
122    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
123        super().__init__()
124        self._axes = None if axes is None else tuple(axes)
125        self._member_id = member_id
126        self._n: int = 0
127        self._mean: Optional[Tensor] = None
128        self._m2: Optional[Tensor] = None
129
130    def compute(
131        self, sample: Sample
132    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
133        tensor = sample.members[self._member_id]
134        mean = tensor.mean(dim=self._axes)
135        c = (tensor - mean).data
136        if self._axes is None:
137            n = tensor.size
138        else:
139            n = int(np.prod([tensor.sizes[d] for d in self._axes]))
140
141        var = xr.dot(c, c, dims=self._axes) / n
142        assert isinstance(var, xr.DataArray)
143        std = np.sqrt(var)
144        assert isinstance(std, xr.DataArray)
145        return {
146            SampleMean(axes=self._axes, member_id=self._member_id): mean,
147            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
148                var
149            ),
150            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
151                std
152            ),
153        }
154
155    def update(self, sample: Sample):
156        tensor = sample.members[self._member_id].astype("float64", copy=False)
157        mean_b = tensor.mean(dim=self._axes)
158        assert mean_b.dtype == "float64"
159        # reduced voxel count
160        n_b = int(tensor.size / mean_b.size)
161        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
162        assert m2_b.dtype == "float64"
163        if self._mean is None:
164            assert self._m2 is None
165            self._n = n_b
166            self._mean = mean_b
167            self._m2 = m2_b
168        else:
169            n_a = self._n
170            mean_a = self._mean
171            m2_a = self._m2
172            self._n = n = n_a + n_b
173            self._mean = (n_a * mean_a + n_b * mean_b) / n
174            assert self._mean.dtype == "float64"
175            d = mean_b - mean_a
176            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
177            assert self._m2.dtype == "float64"
178
179    def finalize(
180        self,
181    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
182        if self._mean is None:
183            return {}
184        else:
185            assert self._m2 is not None
186            var = self._m2 / self._n
187            sqrt = np.sqrt(var)
188            if isinstance(sqrt, (int, float)):
189                # var and mean are scalar tensors, let's keep it consistent
190                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))
191
192            assert isinstance(sqrt, Tensor), type(sqrt)
193            return {
194                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
195                DatasetVar(member_id=self._member_id, axes=self._axes): var,
196                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
197            }

to calculate sample and dataset mean, variance or standard deviation

MeanVarStdCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]])
122    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
123        super().__init__()
124        self._axes = None if axes is None else tuple(axes)
125        self._member_id = member_id
126        self._n: int = 0
127        self._mean: Optional[Tensor] = None
128        self._m2: Optional[Tensor] = None
def compute( self, sample: bioimageio.core.Sample) -> Dict[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleStd], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
130    def compute(
131        self, sample: Sample
132    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
133        tensor = sample.members[self._member_id]
134        mean = tensor.mean(dim=self._axes)
135        c = (tensor - mean).data
136        if self._axes is None:
137            n = tensor.size
138        else:
139            n = int(np.prod([tensor.sizes[d] for d in self._axes]))
140
141        var = xr.dot(c, c, dims=self._axes) / n
142        assert isinstance(var, xr.DataArray)
143        std = np.sqrt(var)
144        assert isinstance(std, xr.DataArray)
145        return {
146            SampleMean(axes=self._axes, member_id=self._member_id): mean,
147            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
148                var
149            ),
150            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
151                std
152            ),
153        }
def update(self, sample: bioimageio.core.Sample):
155    def update(self, sample: Sample):
156        tensor = sample.members[self._member_id].astype("float64", copy=False)
157        mean_b = tensor.mean(dim=self._axes)
158        assert mean_b.dtype == "float64"
159        # reduced voxel count
160        n_b = int(tensor.size / mean_b.size)
161        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
162        assert m2_b.dtype == "float64"
163        if self._mean is None:
164            assert self._m2 is None
165            self._n = n_b
166            self._mean = mean_b
167            self._m2 = m2_b
168        else:
169            n_a = self._n
170            mean_a = self._mean
171            m2_a = self._m2
172            self._n = n = n_a + n_b
173            self._mean = (n_a * mean_a + n_b * mean_b) / n
174            assert self._mean.dtype == "float64"
175            d = mean_b - mean_a
176            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
177            assert self._m2.dtype == "float64"
def finalize( self) -> Dict[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetStd], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
179    def finalize(
180        self,
181    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
182        if self._mean is None:
183            return {}
184        else:
185            assert self._m2 is not None
186            var = self._m2 / self._n
187            sqrt = np.sqrt(var)
188            if isinstance(sqrt, (int, float)):
189                # var and mean are scalar tensors, let's keep it consistent
190                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))
191
192            assert isinstance(sqrt, Tensor), type(sqrt)
193            return {
194                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
195                DatasetVar(member_id=self._member_id, axes=self._axes): var,
196                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
197            }
class SamplePercentilesCalculator:
200class SamplePercentilesCalculator:
201    """to calculate sample percentiles"""
202
203    def __init__(
204        self,
205        member_id: MemberId,
206        axes: Optional[Sequence[AxisId]],
207        qs: Collection[float],
208    ):
209        super().__init__()
210        assert all(0.0 <= q <= 1.0 for q in qs)
211        self._qs = sorted(set(qs))
212        self._axes = None if axes is None else tuple(axes)
213        self._member_id = member_id
214
215    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
216        tensor = sample.members[self._member_id]
217        ps = tensor.quantile(self._qs, dim=self._axes)
218        return {
219            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
220            for q, p in zip(self._qs, ps)
221        }

to calculate sample percentiles

SamplePercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
203    def __init__(
204        self,
205        member_id: MemberId,
206        axes: Optional[Sequence[AxisId]],
207        qs: Collection[float],
208    ):
209        super().__init__()
210        assert all(0.0 <= q <= 1.0 for q in qs)
211        self._qs = sorted(set(qs))
212        self._axes = None if axes is None else tuple(axes)
213        self._member_id = member_id
def compute( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleQuantile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
215    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
216        tensor = sample.members[self._member_id]
217        ps = tensor.quantile(self._qs, dim=self._axes)
218        return {
219            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
220            for q, p in zip(self._qs, ps)
221        }
class MeanPercentilesCalculator:
224class MeanPercentilesCalculator:
225    """to calculate dataset percentiles heuristically by averaging across samples
226    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
227    """
228
229    def __init__(
230        self,
231        member_id: MemberId,
232        axes: Optional[Sequence[AxisId]],
233        qs: Collection[float],
234    ):
235        super().__init__()
236        assert all(0.0 <= q <= 1.0 for q in qs)
237        self._qs = sorted(set(qs))
238        self._axes = None if axes is None else tuple(axes)
239        self._member_id = member_id
240        self._n: int = 0
241        self._estimates: Optional[Tensor] = None
242
243    def update(self, sample: Sample):
244        tensor = sample.members[self._member_id]
245        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
246            "float64", copy=False
247        )
248
249        # reduced voxel count
250        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))
251
252        if self._estimates is None:
253            assert self._n == 0
254            self._estimates = sample_estimates
255        else:
256            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
257                self._n + n
258            )
259            assert self._estimates.dtype == "float64"
260
261        self._n += n
262
263    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
264        if self._estimates is None:
265            return {}
266        else:
267            warnings.warn(
268                "Computed dataset percentiles naively by averaging percentiles of samples."
269            )
270            return {
271                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
272                for q, e in zip(self._qs, self._estimates)
273            }

to calculate dataset percentiles heuristically by averaging across samples. Note: the returned dataset percentiles are an estimate and not mathematically correct

MeanPercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
229    def __init__(
230        self,
231        member_id: MemberId,
232        axes: Optional[Sequence[AxisId]],
233        qs: Collection[float],
234    ):
235        super().__init__()
236        assert all(0.0 <= q <= 1.0 for q in qs)
237        self._qs = sorted(set(qs))
238        self._axes = None if axes is None else tuple(axes)
239        self._member_id = member_id
240        self._n: int = 0
241        self._estimates: Optional[Tensor] = None
def update(self, sample: bioimageio.core.Sample):
243    def update(self, sample: Sample):
244        tensor = sample.members[self._member_id]
245        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
246            "float64", copy=False
247        )
248
249        # reduced voxel count
250        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))
251
252        if self._estimates is None:
253            assert self._n == 0
254            self._estimates = sample_estimates
255        else:
256            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
257                self._n + n
258            )
259            assert self._estimates.dtype == "float64"
260
261        self._n += n
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetPercentile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
263    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
264        if self._estimates is None:
265            return {}
266        else:
267            warnings.warn(
268                "Computed dataset percentiles naively by averaging percentiles of samples."
269            )
270            return {
271                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
272                for q, e in zip(self._qs, self._estimates)
273            }
class CrickPercentilesCalculator:
276class CrickPercentilesCalculator:
277    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""
278
279    def __init__(
280        self,
281        member_id: MemberId,
282        axes: Optional[Sequence[AxisId]],
283        qs: Collection[float],
284    ):
285        warnings.warn(
286            "Computing dataset percentiles with experimental 'crick' library."
287        )
288        super().__init__()
289        assert all(0.0 <= q <= 1.0 for q in qs)
290        assert axes is None or "_percentiles" not in axes
291        self._qs = sorted(set(qs))
292        self._axes = None if axes is None else tuple(axes)
293        self._member_id = member_id
294        self._digest: Optional[List[TDigest]] = None
295        self._dims: Optional[Tuple[AxisId, ...]] = None
296        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
297        self._shape: Optional[Tuple[int, ...]] = None
298
299    def _initialize(self, tensor_sizes: PerAxis[int]):
300        assert crick is not None
301        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
302            _percentiles=len(self._qs)
303        )
304        if self._axes is not None:
305            for d, s in tensor_sizes.items():
306                if d not in self._axes:
307                    out_sizes[d] = s
308
309        self._dims, self._shape = zip(*out_sizes.items())
310        d = int(np.prod(self._shape[1:]))  # type: ignore
311        self._digest = [TDigest() for _ in range(d)]
312        self._indices = product(*map(range, self._shape[1:]))
313
314    def update(self, part: Sample):
315        tensor = (
316            part.members[self._member_id]
317            if isinstance(part, Sample)
318            else part.members[self._member_id].data
319        )
320        assert "_percentiles" not in tensor.dims
321        if self._digest is None:
322            self._initialize(tensor.tagged_shape)
323
324        assert self._digest is not None
325        assert self._indices is not None
326        assert self._dims is not None
327        for i, idx in enumerate(self._indices):
328            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
329
330    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
331        if self._digest is None:
332            return {}
333        else:
334            assert self._dims is not None
335            assert self._shape is not None
336
337            vs: NDArray[Any] = np.asarray(
338                [[d.quantile(q) for d in self._digest] for q in self._qs]
339            ).reshape(self._shape)
340            return {
341                DatasetPercentile(
342                    q=q, axes=self._axes, member_id=self._member_id
343                ): Tensor(v, dims=self._dims[1:])
344                for q, v in zip(self._qs, vs)
345            }

to calculate dataset percentiles with the experimental crick library

CrickPercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
279    def __init__(
280        self,
281        member_id: MemberId,
282        axes: Optional[Sequence[AxisId]],
283        qs: Collection[float],
284    ):
285        warnings.warn(
286            "Computing dataset percentiles with experimental 'crick' library."
287        )
288        super().__init__()
289        assert all(0.0 <= q <= 1.0 for q in qs)
290        assert axes is None or "_percentiles" not in axes
291        self._qs = sorted(set(qs))
292        self._axes = None if axes is None else tuple(axes)
293        self._member_id = member_id
294        self._digest: Optional[List[TDigest]] = None
295        self._dims: Optional[Tuple[AxisId, ...]] = None
296        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
297        self._shape: Optional[Tuple[int, ...]] = None
def update(self, part: bioimageio.core.Sample):
314    def update(self, part: Sample):
315        tensor = (
316            part.members[self._member_id]
317            if isinstance(part, Sample)
318            else part.members[self._member_id].data
319        )
320        assert "_percentiles" not in tensor.dims
321        if self._digest is None:
322            self._initialize(tensor.tagged_shape)
323
324        assert self._digest is not None
325        assert self._indices is not None
326        assert self._dims is not None
327        for i, idx in enumerate(self._indices):
328            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetPercentile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
330    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
331        if self._digest is None:
332            return {}
333        else:
334            assert self._dims is not None
335            assert self._shape is not None
336
337            vs: NDArray[Any] = np.asarray(
338                [[d.quantile(q) for d in self._digest] for q in self._qs]
339            ).reshape(self._shape)
340            return {
341                DatasetPercentile(
342                    q=q, axes=self._axes, member_id=self._member_id
343                ): Tensor(v, dims=self._dims[1:])
344                for q, v in zip(self._qs, vs)
345            }
class NaiveSampleMeasureCalculator:
356class NaiveSampleMeasureCalculator:
357    """wrapper for measures to match interface of other sample measure calculators"""
358
359    def __init__(self, member_id: MemberId, measure: SampleMeasure):
360        super().__init__()
361        self.tensor_name = member_id
362        self.measure = measure
363
364    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
365        return {self.measure: self.measure.compute(sample)}

wrapper for measures to match interface of other sample measure calculators

NaiveSampleMeasureCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, measure: Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)])
359    def __init__(self, member_id: MemberId, measure: SampleMeasure):
360        super().__init__()
361        self.tensor_name = member_id
362        self.measure = measure
tensor_name
measure
def compute( self, sample: bioimageio.core.Sample) -> Dict[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
364    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
365        return {self.measure: self.measure.compute(sample)}
DatasetMeasureCalculator = typing.Union[MeanCalculator, MeanVarStdCalculator, MeanPercentilesCalculator]
class StatsCalculator:
379class StatsCalculator:
380    """Estimates dataset statistics and computes sample statistics efficiently"""
381
382    def __init__(
383        self,
384        measures: Collection[Measure],
385        initial_dataset_measures: Optional[
386            Mapping[DatasetMeasure, MeasureValue]
387        ] = None,
388    ):
389        super().__init__()
390        self.sample_count = 0
391        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
392            measures
393        )
394        if not initial_dataset_measures:
395            self._current_dataset_measures: Optional[
396                Dict[DatasetMeasure, MeasureValue]
397            ] = None
398        else:
399            missing_dataset_meas = {
400                m
401                for m in measures
402                if isinstance(m, DatasetMeasureBase)
403                and m not in initial_dataset_measures
404            }
405            if missing_dataset_meas:
406                logger.debug(
407                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
408                )
409                self._current_dataset_measures = None
410            else:
411                self._current_dataset_measures = dict(initial_dataset_measures)
412
413    @property
414    def has_dataset_measures(self):
415        return self._current_dataset_measures is not None
416
417    def update(
418        self,
419        sample: Union[Sample, Iterable[Sample]],
420    ) -> None:
421        _ = self._update(sample)
422
423    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
424        """returns aggregated dataset statistics"""
425        if self._current_dataset_measures is None:
426            self._current_dataset_measures = {}
427            for calc in self.dataset_calculators:
428                values = calc.finalize()
429                self._current_dataset_measures.update(values.items())
430
431        return self._current_dataset_measures
432
433    def update_and_get_all(
434        self,
435        sample: Union[Sample, Iterable[Sample]],
436    ) -> Dict[Measure, MeasureValue]:
437        """Returns sample as well as updated dataset statistics"""
438        last_sample = self._update(sample)
439        if last_sample is None:
440            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")
441
442        return {**self._compute(last_sample), **self.finalize()}
443
444    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
445        """Returns sample as well as previously computed dataset statistics"""
446        return {**self._compute(sample), **self.finalize()}
447
448    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
449        ret: Dict[SampleMeasure, MeasureValue] = {}
450        for calc in self.sample_calculators:
451            values = calc.compute(sample)
452            ret.update(values.items())
453
454        return ret
455
456    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
457        self.sample_count += 1
458        samples = [sample] if isinstance(sample, Sample) else sample
459        last_sample = None
460        for el in samples:
461            last_sample = el
462            for calc in self.dataset_calculators:
463                calc.update(el)
464
465        self._current_dataset_measures = None
466        return last_sample

Estimates dataset statistics and computes sample statistics efficiently

StatsCalculator( measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], initial_dataset_measures: Optional[Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None)
382    def __init__(
383        self,
384        measures: Collection[Measure],
385        initial_dataset_measures: Optional[
386            Mapping[DatasetMeasure, MeasureValue]
387        ] = None,
388    ):
389        super().__init__()
390        self.sample_count = 0
391        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
392            measures
393        )
394        if not initial_dataset_measures:
395            self._current_dataset_measures: Optional[
396                Dict[DatasetMeasure, MeasureValue]
397            ] = None
398        else:
399            missing_dataset_meas = {
400                m
401                for m in measures
402                if isinstance(m, DatasetMeasureBase)
403                and m not in initial_dataset_measures
404            }
405            if missing_dataset_meas:
406                logger.debug(
407                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
408                )
409                self._current_dataset_measures = None
410            else:
411                self._current_dataset_measures = dict(initial_dataset_measures)
sample_count
has_dataset_measures
413    @property
414    def has_dataset_measures(self):
415        return self._current_dataset_measures is not None
def update( self, sample: Union[bioimageio.core.Sample, Iterable[bioimageio.core.Sample]]) -> None:
417    def update(
418        self,
419        sample: Union[Sample, Iterable[Sample]],
420    ) -> None:
421        _ = self._update(sample)
def finalize( self) -> Dict[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
423    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
424        """returns aggregated dataset statistics"""
425        if self._current_dataset_measures is None:
426            self._current_dataset_measures = {}
427            for calc in self.dataset_calculators:
428                values = calc.finalize()
429                self._current_dataset_measures.update(values.items())
430
431        return self._current_dataset_measures

returns aggregated dataset statistics

def update_and_get_all( self, sample: Union[bioimageio.core.Sample, Iterable[bioimageio.core.Sample]]) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
433    def update_and_get_all(
434        self,
435        sample: Union[Sample, Iterable[Sample]],
436    ) -> Dict[Measure, MeasureValue]:
437        """Returns sample as well as updated dataset statistics"""
438        last_sample = self._update(sample)
439        if last_sample is None:
440            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")
441
442        return {**self._compute(last_sample), **self.finalize()}

Returns sample as well as updated dataset statistics

def skip_update_and_get_all( self, sample: bioimageio.core.Sample) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
444    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
445        """Returns sample as well as previously computed dataset statistics"""
446        return {**self._compute(sample), **self.finalize()}

Returns sample as well as previously computed dataset statistics

def get_measure_calculators( required_measures: Iterable[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]) -> Tuple[List[Union[MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaiveSampleMeasureCalculator]], List[Union[MeanCalculator, MeanVarStdCalculator, MeanPercentilesCalculator]]]:
469def get_measure_calculators(
470    required_measures: Iterable[Measure],
471) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
472    """determines which calculators are needed to compute the required measures efficiently"""
473
474    sample_calculators: List[SampleMeasureCalculator] = []
475    dataset_calculators: List[DatasetMeasureCalculator] = []
476
477    # split required measures into groups
478    required_sample_means: Set[SampleMean] = set()
479    required_dataset_means: Set[DatasetMean] = set()
480    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
481    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
482        set()
483    )
484    required_sample_percentiles: Dict[
485        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
486    ] = {}
487    required_dataset_percentiles: Dict[
488        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
489    ] = {}
490
491    for rm in required_measures:
492        if isinstance(rm, SampleMean):
493            required_sample_means.add(rm)
494        elif isinstance(rm, DatasetMean):
495            required_dataset_means.add(rm)
496        elif isinstance(rm, (SampleVar, SampleStd)):
497            required_sample_mean_var_std.update(
498                {
499                    msv(axes=rm.axes, member_id=rm.member_id)
500                    for msv in (SampleMean, SampleStd, SampleVar)
501                }
502            )
503            assert rm in required_sample_mean_var_std
504        elif isinstance(rm, (DatasetVar, DatasetStd)):
505            required_dataset_mean_var_std.update(
506                {
507                    msv(axes=rm.axes, member_id=rm.member_id)
508                    for msv in (DatasetMean, DatasetStd, DatasetVar)
509                }
510            )
511            assert rm in required_dataset_mean_var_std
512        elif isinstance(rm, SampleQuantile):
513            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
514                rm.q
515            )
516        elif isinstance(rm, DatasetPercentile):
517            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
518                rm.q
519            )
520        else:
521            assert_never(rm)
522
523    for rm in required_sample_means:
524        if rm in required_sample_mean_var_std:
525    # computed together with var and std
526            continue
527
528        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))
529
530    for rm in required_sample_mean_var_std:
531        sample_calculators.append(
532            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
533        )
534
535    for rm in required_dataset_means:
536        if rm in required_dataset_mean_var_std:
537    # computed together with var and std
538            continue
539
540        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))
541
542    for rm in required_dataset_mean_var_std:
543        dataset_calculators.append(
544            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
545        )
546
547    for (tid, axes), qs in required_sample_percentiles.items():
548        sample_calculators.append(
549            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
550        )
551
552    for (tid, axes), qs in required_dataset_percentiles.items():
553        dataset_calculators.append(
554            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
555        )
556
557    return sample_calculators, dataset_calculators

determines which calculators are needed to compute the required measures efficiently

def compute_dataset_measures( measures: Iterable[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], dataset: Iterable[bioimageio.core.Sample]) -> Dict[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """Compute all dataset `measures` for the given `dataset`.

    Feeds every sample through the dataset-scoped calculators, then
    finalizes each calculator to obtain the aggregated measure values.
    """
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # only dataset-scoped measures are expected here
    assert not sample_calculators

    computed: Dict[DatasetMeasure, MeasureValue] = {}

    for s in dataset:
        for calculator in dataset_calculators:
            calculator.update(s)

    for calculator in dataset_calculators:
        computed.update(calculator.finalize())

    return computed

Compute all dataset measures for the given dataset.

def compute_sample_measures( measures: Iterable[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], sample: bioimageio.core.Sample) -> Dict[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """Compute all sample `measures` for the given `sample`.

    Sample measures are computed directly from the single sample; no
    dataset-scoped calculator may be required by `measures`.
    """
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # only sample-scoped measures are expected here
    assert not dataset_calculators

    computed: Dict[SampleMeasure, MeasureValue] = {}
    for calculator in sample_calculators:
        computed.update(calculator.compute(sample))

    return computed

Compute all sample measures for the given sample.

def compute_measures( measures: Iterable[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], dataset: Iterable[bioimageio.core.Sample]) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7f9a7099e840>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7f9a7099ea20>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    computed: Dict[Measure, MeasureValue] = {}

    # stream the dataset once, updating dataset-scoped calculators;
    # `last_sample` keeps the final sample for sample-scoped measures
    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    for calculator in dataset_calculators:
        computed.update(calculator.finalize())

    for calculator in sample_calculators:
        computed.update(calculator.compute(last_sample))

    return computed

Compute all measures for the given dataset; sample measures are computed for the last sample in the dataset.

class TDigest:
    """No-op stand-in for `crick.TDigest` used when `crick` cannot be imported.

    Matches the interface the percentile calculators rely on, but discards
    all input and produces no quantile estimates (both methods return None).
    """

    def update(self, obj: Any):
        """Accept and ignore *obj* (a real digest would fold it into the sketch)."""
        pass

    def quantile(self, q: Any) -> Any:
        """Return None (a real digest would return the estimated *q*-quantile)."""
        pass
DatasetPercentilesCalculator: Type[Union[MeanPercentilesCalculator, CrickPercentilesCalculator]] = <class 'MeanPercentilesCalculator'>