bioimageio.core.stat_calculators

  1from __future__ import annotations
  2
  3import collections
  4import warnings
  5from itertools import product
  6from typing import (
  7    Any,
  8    Collection,
  9    Dict,
 10    Iterable,
 11    Iterator,
 12    List,
 13    Mapping,
 14    Optional,
 15    OrderedDict,
 16    Sequence,
 17    Set,
 18    Tuple,
 19    Type,
 20    Union,
 21)
 22
 23import numpy as np
 24import xarray as xr
 25from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID
 26from loguru import logger
 27from numpy.typing import NDArray
 28from typing_extensions import assert_never
 29
 30from .axis import AxisId, PerAxis
 31from .common import MemberId
 32from .sample import Sample
 33from .stat_measures import (
 34    DatasetMean,
 35    DatasetMeasure,
 36    DatasetMeasureBase,
 37    DatasetPercentile,
 38    DatasetStd,
 39    DatasetVar,
 40    Measure,
 41    MeasureValue,
 42    SampleMean,
 43    SampleMeasure,
 44    SampleQuantile,
 45    SampleStd,
 46    SampleVar,
 47)
 48from .tensor import Tensor
 49
 50try:
 51    import crick  # pyright: ignore[reportMissingTypeStubs]
 52
 53except Exception:
 54    crick = None
 55
 56    class TDigest:
 57        def update(self, obj: Any):
 58            pass
 59
 60        def quantile(self, q: Any) -> Any:
 61            pass
 62
 63else:
 64    TDigest = crick.TDigest  # type: ignore
 65
 66
 67class MeanCalculator:
 68    """to calculate sample and dataset mean for in-memory samples"""
 69
 70    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
 71        super().__init__()
 72        self._n: int = 0
 73        self._mean: Optional[Tensor] = None
 74        self._axes = None if axes is None else tuple(axes)
 75        self._member_id = member_id
 76        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
 77        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
 78
 79    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 80        return {self._sample_mean: self._compute_impl(sample)}
 81
 82    def _compute_impl(self, sample: Sample) -> Tensor:
 83        tensor = sample.members[self._member_id].astype("float64", copy=False)
 84        return tensor.mean(dim=self._axes)
 85
 86    def update(self, sample: Sample) -> None:
 87        mean = self._compute_impl(sample)
 88        self._update_impl(sample.members[self._member_id], mean)
 89
 90    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 91        mean = self._compute_impl(sample)
 92        self._update_impl(sample.members[self._member_id], mean)
 93        return {self._sample_mean: mean}
 94
 95    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
 96        assert tensor_mean.dtype == "float64"
 97        # reduced voxel count
 98        n_b = int(tensor.size / tensor_mean.size)
 99
100        if self._mean is None:
101            assert self._n == 0
102            self._n = n_b
103            self._mean = tensor_mean
104        else:
105            assert self._n != 0
106            n_a = self._n
107            mean_old = self._mean
108            self._n = n_a + n_b
109            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
110            assert self._mean.dtype == "float64"
111
112    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
113        if self._mean is None:
114            return {}
115        else:
116            return {self._dataset_mean: self._mean}
117
118
class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation

    Dataset statistics are aggregated with a pairwise update of the sum of
    squared deviations (Chan et al. parallel variance scheme).
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        self._n: int = 0  # reduced voxel count aggregated so far
        self._mean: Optional[Tensor] = None  # running mean
        self._m2: Optional[Tensor] = None  # running sum of squared deviations

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """Compute mean, variance and standard deviation of a single sample."""
        member = sample.members[self._member_id]
        mean = member.mean(dim=self._axes)
        centered = (member - mean).data
        if self._axes is None:
            n = member.size
        else:
            n = int(np.prod([member.sizes[a] for a in self._axes]))

        # `xr.dot` renamed its `dims` argument to `dim` after the 2023 releases
        if xr.__version__.startswith("2023"):
            var = xr.dot(centered, centered, dims=self._axes) / n
        else:
            var = xr.dot(centered, centered, dim=self._axes) / n

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """Fold one sample into the running dataset statistics.

        Skipped when the aggregation axes exclude the batch axis: such
        per-sample statistics cannot be merged across samples.
        """
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        member = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = member.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # number of voxels reduced into each statistic entry
        n_b = int(member.size / mean_b.size)
        m2_b = ((member - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n, self._mean, self._m2 = n_b, mean_b, m2_b
        else:
            # pairwise combination of two partial (n, mean, M2) aggregates
            n_a, mean_a, m2_a = self._n, self._mean, self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            delta = mean_b - mean_a
            self._m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """Return aggregated dataset mean/var/std; empty if nothing was aggregated."""
        if (
            self._axes is not None
            and BATCH_AXIS_ID not in self._axes
            or self._mean is None
        ):
            return {}

        assert self._m2 is not None
        var = self._m2 / self._n
        sqrt = var**0.5
        if isinstance(sqrt, (int, float)):
            # var and mean are scalar tensors; keep the return type consistent
            sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

        assert isinstance(sqrt, Tensor), type(sqrt)
        return {
            DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
            DatasetVar(member_id=self._member_id, axes=self._axes): var,
            DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
        }
209
210
class SamplePercentilesCalculator:
    """to calculate sample percentiles"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        # quantiles must lie in [0, 1]; duplicates are dropped, order fixed
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one `SampleQuantile` measure per requested quantile."""
        member = sample.members[self._member_id]
        quantiles = member.quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): value
            for q, value in zip(self._qs, quantiles)
        }
233
234
class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples

    **note**: the returned dataset percentiles are an estimate and
    **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        # quantiles must lie in [0, 1]; duplicates are dropped, order fixed
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # reduced voxel count aggregated so far
        self._estimates: Optional[Tensor] = None  # running averaged estimates

    def update(self, sample: Sample):
        """Fold this sample's percentiles into the running average."""
        member = sample.members[self._member_id]
        sample_estimates = member.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # voxels reduced per estimate entry (axis 0 is the quantile axis)
        n_new = int(member.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            # count-weighted running average of per-sample estimates
            self._estimates = (
                self._n * self._estimates + n_new * sample_estimates
            ) / (self._n + n_new)
            assert self._estimates.dtype == "float64"

        self._n += n_new

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return the averaged percentile estimates; empty if no update happened."""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): est
            for q, est in zip(self._qs, self._estimates)
        }
285
286
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Maintains one streaming t-digest per element of the (non-reduced) output
    shape and queries each digest for the requested quantiles in `finalize`.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # "_percentiles" is used as a synthetic leading axis in the output
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None  # one digest per output element
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # indices into the non-quantile output dimensions; stored as a list so
        # it can be traversed once per sample (see fix in `_initialize`)
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """Lazily set up digests and output layout from the first tensor seen."""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            # keep every axis that is not reduced over
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # bugfix: `itertools.product` returns a single-use iterator; storing it
        # directly meant `update` exhausted it on the first sample and all
        # subsequent samples were silently ignored. Materialize it instead.
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """Stream one sample (or sample part) into the per-element digests."""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Query the digests and return one measure per requested quantile."""
        if self._digest is None:
            return {}

        assert self._dims is not None
        assert self._shape is not None

        # shape: (n_quantiles, *non-reduced axes)
        vs: NDArray[Any] = np.asarray(
            [[d.quantile(q) for d in self._digest] for q in self._qs]
        ).reshape(self._shape)
        return {
            DatasetPercentile(
                q=q, axes=self._axes, member_id=self._member_id
            ): Tensor(v, dims=self._dims[1:])
            for q, v in zip(self._qs, vs)
        }
358
359
# Select the dataset-percentile implementation based on the optional `crick`
# dependency (see the import guard at the top of the file): the t-digest
# based calculator when available, otherwise the naive averaging heuristic.
if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator
366
367
class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        # NOTE(review): attribute is named `tensor_name` but stores a member id
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Delegate computation to the wrapped measure itself."""
        measure = self.measure
        return {measure: measure.compute(sample)}
378
379
# union of calculator types that can compute per-sample measures
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# union of calculator types that can aggregate dataset measures
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]
389
390
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        super().__init__()
        # NOTE(review): incremented once per `update()` call, even when the
        # argument is an iterable of several samples — confirm intended semantics
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # only accept initial values covering every required dataset measure
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        """Whether aggregated dataset measures are currently available."""
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """Fold `sample` (or an iterable of samples) into the dataset statistics."""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # lazily (re)aggregate from the dataset calculators
            collected: Dict[DatasetMeasure, MeasureValue] = {}
            for calc in self.dataset_calculators:
                collected.update(calc.finalize().items())

            self._current_dataset_measures = collected

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # run every sample calculator against this sample and merge the results
        computed: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            computed.update(calc.compute(sample).items())

        return computed

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        self.sample_count += 1
        last_sample: Optional[Sample] = None
        for el in [sample] if isinstance(sample, Sample) else sample:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate the cached aggregate; recomputed lazily in `finalize`
        self._current_dataset_measures = None
        return last_sample
479
480
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Returns a pair (sample calculators, dataset calculators) covering all
    `required_measures` with shared work computed only once.
    """

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # var/std require the mean anyway; request all three together
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # fix: the set holds three measure instances (mean/std/var) per
    # (member_id, axes) pair; iterating it directly appended three identical
    # calculators, tripling the work. One calculator per pair suffices.
    for member_id, axes in {(m.member_id, m.axes) for m in required_sample_mean_var_std}:
        sample_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for member_id, axes in {
        (m.member_id, m.axes) for m in required_dataset_mean_var_std
    }:
        dataset_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
570
571
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # dataset measures must not require any sample calculators
    assert not sample_calculators

    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)

    ret: Dict[DatasetMeasure, MeasureValue] = {}
    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    return ret
589
590
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures must not require any dataset calculators
    assert not dataset_calculators

    ret: Dict[SampleMeasure, MeasureValue] = {}
    for calc in sample_calculators:
        ret.update(calc.compute(sample).items())

    return ret
603
604
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for calc in dataset_calculators:
            calc.update(last_sample)
    if last_sample is None:
        raise ValueError("empty dataset")

    ret: Dict[Measure, MeasureValue] = {}
    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    # sample measures are evaluated on the final sample only
    for calc in sample_calculators:
        ret.update(calc.compute(last_sample).items())

    return ret
class MeanCalculator:
 68class MeanCalculator:
 69    """to calculate sample and dataset mean for in-memory samples"""
 70
 71    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
 72        super().__init__()
 73        self._n: int = 0
 74        self._mean: Optional[Tensor] = None
 75        self._axes = None if axes is None else tuple(axes)
 76        self._member_id = member_id
 77        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
 78        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
 79
 80    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 81        return {self._sample_mean: self._compute_impl(sample)}
 82
 83    def _compute_impl(self, sample: Sample) -> Tensor:
 84        tensor = sample.members[self._member_id].astype("float64", copy=False)
 85        return tensor.mean(dim=self._axes)
 86
 87    def update(self, sample: Sample) -> None:
 88        mean = self._compute_impl(sample)
 89        self._update_impl(sample.members[self._member_id], mean)
 90
 91    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
 92        mean = self._compute_impl(sample)
 93        self._update_impl(sample.members[self._member_id], mean)
 94        return {self._sample_mean: mean}
 95
 96    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
 97        assert tensor_mean.dtype == "float64"
 98        # reduced voxel count
 99        n_b = int(tensor.size / tensor_mean.size)
100
101        if self._mean is None:
102            assert self._n == 0
103            self._n = n_b
104            self._mean = tensor_mean
105        else:
106            assert self._n != 0
107            n_a = self._n
108            mean_old = self._mean
109            self._n = n_a + n_b
110            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
111            assert self._mean.dtype == "float64"
112
113    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
114        if self._mean is None:
115            return {}
116        else:
117            return {self._dataset_mean: self._mean}

to calculate sample and dataset mean for in-memory samples

MeanCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]])
71    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
72        super().__init__()
73        self._n: int = 0
74        self._mean: Optional[Tensor] = None
75        self._axes = None if axes is None else tuple(axes)
76        self._member_id = member_id
77        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
78        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
def compute( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleMean, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
80    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
81        return {self._sample_mean: self._compute_impl(sample)}
def update(self, sample: bioimageio.core.Sample) -> None:
87    def update(self, sample: Sample) -> None:
88        mean = self._compute_impl(sample)
89        self._update_impl(sample.members[self._member_id], mean)
def compute_and_update( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleMean, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
91    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
92        mean = self._compute_impl(sample)
93        self._update_impl(sample.members[self._member_id], mean)
94        return {self._sample_mean: mean}
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetMean, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
113    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
114        if self._mean is None:
115            return {}
116        else:
117            return {self._dataset_mean: self._mean}
class MeanVarStdCalculator:
120class MeanVarStdCalculator:
121    """to calculate sample and dataset mean, variance or standard deviation"""
122
123    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
124        super().__init__()
125        self._axes = None if axes is None else tuple(map(AxisId, axes))
126        self._member_id = member_id
127        self._n: int = 0
128        self._mean: Optional[Tensor] = None
129        self._m2: Optional[Tensor] = None
130
131    def compute(
132        self, sample: Sample
133    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
134        tensor = sample.members[self._member_id]
135        mean = tensor.mean(dim=self._axes)
136        c = (tensor - mean).data
137        if self._axes is None:
138            n = tensor.size
139        else:
140            n = int(np.prod([tensor.sizes[d] for d in self._axes]))
141
142        if xr.__version__.startswith("2023"):
143            var = xr.dot(c, c, dims=self._axes) / n
144        else:
145            var = xr.dot(c, c, dim=self._axes) / n
146
147        assert isinstance(var, xr.DataArray)
148        std = np.sqrt(var)
149        assert isinstance(std, xr.DataArray)
150        return {
151            SampleMean(axes=self._axes, member_id=self._member_id): mean,
152            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
153                var
154            ),
155            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
156                std
157            ),
158        }
159
160    def update(self, sample: Sample):
161        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
162            return
163
164        tensor = sample.members[self._member_id].astype("float64", copy=False)
165        mean_b = tensor.mean(dim=self._axes)
166        assert mean_b.dtype == "float64"
167        # reduced voxel count
168        n_b = int(tensor.size / mean_b.size)
169        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
170        assert m2_b.dtype == "float64"
171        if self._mean is None:
172            assert self._m2 is None
173            self._n = n_b
174            self._mean = mean_b
175            self._m2 = m2_b
176        else:
177            n_a = self._n
178            mean_a = self._mean
179            m2_a = self._m2
180            self._n = n = n_a + n_b
181            self._mean = (n_a * mean_a + n_b * mean_b) / n
182            assert self._mean.dtype == "float64"
183            d = mean_b - mean_a
184            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
185            assert self._m2.dtype == "float64"
186
187    def finalize(
188        self,
189    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
190        if (
191            self._axes is not None
192            and BATCH_AXIS_ID not in self._axes
193            or self._mean is None
194        ):
195            return {}
196        else:
197            assert self._m2 is not None
198            var = self._m2 / self._n
199            sqrt = var**0.5
200            if isinstance(sqrt, (int, float)):
201                # var and mean are scalar tensors, let's keep it consistent
202                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))
203
204            assert isinstance(sqrt, Tensor), type(sqrt)
205            return {
206                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
207                DatasetVar(member_id=self._member_id, axes=self._axes): var,
208                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
209            }

to calculate sample and dataset mean, variance or standard deviation

MeanVarStdCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]])
123    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
124        super().__init__()
125        self._axes = None if axes is None else tuple(map(AxisId, axes))
126        self._member_id = member_id
127        self._n: int = 0
128        self._mean: Optional[Tensor] = None
129        self._m2: Optional[Tensor] = None
def compute( self, sample: bioimageio.core.Sample) -> Dict[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleStd], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
131    def compute(
132        self, sample: Sample
133    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
134        tensor = sample.members[self._member_id]
135        mean = tensor.mean(dim=self._axes)
136        c = (tensor - mean).data
137        if self._axes is None:
138            n = tensor.size
139        else:
140            n = int(np.prod([tensor.sizes[d] for d in self._axes]))
141
142        if xr.__version__.startswith("2023"):
143            var = xr.dot(c, c, dims=self._axes) / n
144        else:
145            var = xr.dot(c, c, dim=self._axes) / n
146
147        assert isinstance(var, xr.DataArray)
148        std = np.sqrt(var)
149        assert isinstance(std, xr.DataArray)
150        return {
151            SampleMean(axes=self._axes, member_id=self._member_id): mean,
152            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
153                var
154            ),
155            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
156                std
157            ),
158        }
def update(self, sample: bioimageio.core.Sample):
160    def update(self, sample: Sample):
161        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
162            return
163
164        tensor = sample.members[self._member_id].astype("float64", copy=False)
165        mean_b = tensor.mean(dim=self._axes)
166        assert mean_b.dtype == "float64"
167        # reduced voxel count
168        n_b = int(tensor.size / mean_b.size)
169        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
170        assert m2_b.dtype == "float64"
171        if self._mean is None:
172            assert self._m2 is None
173            self._n = n_b
174            self._mean = mean_b
175            self._m2 = m2_b
176        else:
177            n_a = self._n
178            mean_a = self._mean
179            m2_a = self._m2
180            self._n = n = n_a + n_b
181            self._mean = (n_a * mean_a + n_b * mean_b) / n
182            assert self._mean.dtype == "float64"
183            d = mean_b - mean_a
184            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
185            assert self._m2.dtype == "float64"
def finalize( self) -> Dict[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetStd], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
187    def finalize(
188        self,
189    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
190        if (
191            self._axes is not None
192            and BATCH_AXIS_ID not in self._axes
193            or self._mean is None
194        ):
195            return {}
196        else:
197            assert self._m2 is not None
198            var = self._m2 / self._n
199            sqrt = var**0.5
200            if isinstance(sqrt, (int, float)):
201                # var and mean are scalar tensors, let's keep it consistent
202                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))
203
204            assert isinstance(sqrt, Tensor), type(sqrt)
205            return {
206                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
207                DatasetVar(member_id=self._member_id, axes=self._axes): var,
208                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
209            }
class SamplePercentilesCalculator:
212class SamplePercentilesCalculator:
213    """to calculate sample percentiles"""
214
215    def __init__(
216        self,
217        member_id: MemberId,
218        axes: Optional[Sequence[AxisId]],
219        qs: Collection[float],
220    ):
221        super().__init__()
222        assert all(0.0 <= q <= 1.0 for q in qs)
223        self._qs = sorted(set(qs))
224        self._axes = None if axes is None else tuple(axes)
225        self._member_id = member_id
226
227    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
228        tensor = sample.members[self._member_id]
229        ps = tensor.quantile(self._qs, dim=self._axes)
230        return {
231            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
232            for q, p in zip(self._qs, ps)
233        }

to calculate sample percentiles

SamplePercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
215    def __init__(
216        self,
217        member_id: MemberId,
218        axes: Optional[Sequence[AxisId]],
219        qs: Collection[float],
220    ):
221        super().__init__()
222        assert all(0.0 <= q <= 1.0 for q in qs)
223        self._qs = sorted(set(qs))
224        self._axes = None if axes is None else tuple(axes)
225        self._member_id = member_id
def compute( self, sample: bioimageio.core.Sample) -> Dict[bioimageio.core.stat_measures.SampleQuantile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
227    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
228        tensor = sample.members[self._member_id]
229        ps = tensor.quantile(self._qs, dim=self._axes)
230        return {
231            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
232            for q, p in zip(self._qs, ps)
233        }
class MeanPercentilesCalculator:
236class MeanPercentilesCalculator:
237    """to calculate dataset percentiles heuristically by averaging across samples
240    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
239    """
240
241    def __init__(
242        self,
243        member_id: MemberId,
244        axes: Optional[Sequence[AxisId]],
245        qs: Collection[float],
246    ):
247        super().__init__()
248        assert all(0.0 <= q <= 1.0 for q in qs)
249        self._qs = sorted(set(qs))
250        self._axes = None if axes is None else tuple(axes)
251        self._member_id = member_id
252        self._n: int = 0
253        self._estimates: Optional[Tensor] = None
254
255    def update(self, sample: Sample):
256        tensor = sample.members[self._member_id]
257        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
258            "float64", copy=False
259        )
260
261        # reduced voxel count
262        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))
263
264        if self._estimates is None:
265            assert self._n == 0
266            self._estimates = sample_estimates
267        else:
268            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
269                self._n + n
270            )
271            assert self._estimates.dtype == "float64"
272
273        self._n += n
274
275    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
276        if self._estimates is None:
277            return {}
278        else:
279            warnings.warn(
280                "Computed dataset percentiles naively by averaging percentiles of samples."
281            )
282            return {
283                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
284                for q, e in zip(self._qs, self._estimates)
285            }

to calculate dataset percentiles heuristically by averaging across samples note: the returned dataset percentiles are an estimate and not mathematically correct

MeanPercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
241    def __init__(
242        self,
243        member_id: MemberId,
244        axes: Optional[Sequence[AxisId]],
245        qs: Collection[float],
246    ):
247        super().__init__()
248        assert all(0.0 <= q <= 1.0 for q in qs)
249        self._qs = sorted(set(qs))
250        self._axes = None if axes is None else tuple(axes)
251        self._member_id = member_id
252        self._n: int = 0
253        self._estimates: Optional[Tensor] = None
def update(self, sample: bioimageio.core.Sample):
255    def update(self, sample: Sample):
256        tensor = sample.members[self._member_id]
257        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
258            "float64", copy=False
259        )
260
261        # reduced voxel count
262        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))
263
264        if self._estimates is None:
265            assert self._n == 0
266            self._estimates = sample_estimates
267        else:
268            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
269                self._n + n
270            )
271            assert self._estimates.dtype == "float64"
272
273        self._n += n
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetPercentile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
275    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
276        if self._estimates is None:
277            return {}
278        else:
279            warnings.warn(
280                "Computed dataset percentiles naively by averaging percentiles of samples."
281            )
282            return {
283                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
284                for q, e in zip(self._qs, self._estimates)
285            }
class CrickPercentilesCalculator:
288class CrickPercentilesCalculator:
289    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""
290
291    def __init__(
292        self,
293        member_id: MemberId,
294        axes: Optional[Sequence[AxisId]],
295        qs: Collection[float],
296    ):
297        warnings.warn(
298            "Computing dataset percentiles with experimental 'crick' library."
299        )
300        super().__init__()
301        assert all(0.0 <= q <= 1.0 for q in qs)
302        assert axes is None or "_percentiles" not in axes
303        self._qs = sorted(set(qs))
304        self._axes = None if axes is None else tuple(axes)
305        self._member_id = member_id
306        self._digest: Optional[List[TDigest]] = None
307        self._dims: Optional[Tuple[AxisId, ...]] = None
308        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
309        self._shape: Optional[Tuple[int, ...]] = None
310
311    def _initialize(self, tensor_sizes: PerAxis[int]):
312        assert crick is not None
313        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
314            _percentiles=len(self._qs)
315        )
316        if self._axes is not None:
317            for d, s in tensor_sizes.items():
318                if d not in self._axes:
319                    out_sizes[d] = s
320
321        self._dims, self._shape = zip(*out_sizes.items())
322        assert self._shape is not None
323        d = int(np.prod(self._shape[1:]))
324        self._digest = [TDigest() for _ in range(d)]
325        self._indices = product(*map(range, self._shape[1:]))
326
327    def update(self, part: Sample):
328        tensor = (
329            part.members[self._member_id]
330            if isinstance(part, Sample)
331            else part.members[self._member_id].data
332        )
333        assert "_percentiles" not in tensor.dims
334        if self._digest is None:
335            self._initialize(tensor.tagged_shape)
336
337        assert self._digest is not None
338        assert self._indices is not None
339        assert self._dims is not None
340        for i, idx in enumerate(self._indices):
341            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
342
343    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
344        if self._digest is None:
345            return {}
346        else:
347            assert self._dims is not None
348            assert self._shape is not None
349
350            vs: NDArray[Any] = np.asarray(
351                [[d.quantile(q) for d in self._digest] for q in self._qs]
352            ).reshape(self._shape)
353            return {
354                DatasetPercentile(
355                    q=q, axes=self._axes, member_id=self._member_id
356                ): Tensor(v, dims=self._dims[1:])
357                for q, v in zip(self._qs, vs)
358            }

to calculate dataset percentiles with the experimental crick library

CrickPercentilesCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, axes: Optional[Sequence[bioimageio.spec.model.v0_5.AxisId]], qs: Collection[float])
291    def __init__(
292        self,
293        member_id: MemberId,
294        axes: Optional[Sequence[AxisId]],
295        qs: Collection[float],
296    ):
297        warnings.warn(
298            "Computing dataset percentiles with experimental 'crick' library."
299        )
300        super().__init__()
301        assert all(0.0 <= q <= 1.0 for q in qs)
302        assert axes is None or "_percentiles" not in axes
303        self._qs = sorted(set(qs))
304        self._axes = None if axes is None else tuple(axes)
305        self._member_id = member_id
306        self._digest: Optional[List[TDigest]] = None
307        self._dims: Optional[Tuple[AxisId, ...]] = None
308        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
309        self._shape: Optional[Tuple[int, ...]] = None
def update(self, part: bioimageio.core.Sample):
327    def update(self, part: Sample):
328        tensor = (
329            part.members[self._member_id]
330            if isinstance(part, Sample)
331            else part.members[self._member_id].data
332        )
333        assert "_percentiles" not in tensor.dims
334        if self._digest is None:
335            self._initialize(tensor.tagged_shape)
336
337        assert self._digest is not None
338        assert self._indices is not None
339        assert self._dims is not None
340        for i, idx in enumerate(self._indices):
341            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
def finalize( self) -> Dict[bioimageio.core.stat_measures.DatasetPercentile, Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
343    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
344        if self._digest is None:
345            return {}
346        else:
347            assert self._dims is not None
348            assert self._shape is not None
349
350            vs: NDArray[Any] = np.asarray(
351                [[d.quantile(q) for d in self._digest] for q in self._qs]
352            ).reshape(self._shape)
353            return {
354                DatasetPercentile(
355                    q=q, axes=self._axes, member_id=self._member_id
356                ): Tensor(v, dims=self._dims[1:])
357                for q, v in zip(self._qs, vs)
358            }
class NaiveSampleMeasureCalculator:
369class NaiveSampleMeasureCalculator:
370    """wrapper for measures to match interface of other sample measure calculators"""
371
372    def __init__(self, member_id: MemberId, measure: SampleMeasure):
373        super().__init__()
374        self.tensor_name = member_id
375        self.measure = measure
376
377    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
378        return {self.measure: self.measure.compute(sample)}

wrapper for measures to match interface of other sample measure calculators

NaiveSampleMeasureCalculator( member_id: bioimageio.spec.model.v0_5.TensorId, measure: Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)])
372    def __init__(self, member_id: MemberId, measure: SampleMeasure):
373        super().__init__()
374        self.tensor_name = member_id
375        self.measure = measure
tensor_name
measure
def compute( self, sample: bioimageio.core.Sample) -> Dict[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
377    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
378        return {self.measure: self.measure.compute(sample)}
DatasetMeasureCalculator = typing.Union[MeanCalculator, MeanVarStdCalculator, CrickPercentilesCalculator]
class StatsCalculator:
392class StatsCalculator:
393    """Estimates dataset statistics and computes sample statistics efficiently"""
394
395    def __init__(
396        self,
397        measures: Collection[Measure],
398        initial_dataset_measures: Optional[
399            Mapping[DatasetMeasure, MeasureValue]
400        ] = None,
401    ):
402        super().__init__()
403        self.sample_count = 0
404        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
405            measures
406        )
407        if not initial_dataset_measures:
408            self._current_dataset_measures: Optional[
409                Dict[DatasetMeasure, MeasureValue]
410            ] = None
411        else:
412            missing_dataset_meas = {
413                m
414                for m in measures
415                if isinstance(m, DatasetMeasureBase)
416                and m not in initial_dataset_measures
417            }
418            if missing_dataset_meas:
419                logger.debug(
420                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
421                )
422                self._current_dataset_measures = None
423            else:
424                self._current_dataset_measures = dict(initial_dataset_measures)
425
426    @property
427    def has_dataset_measures(self):
428        return self._current_dataset_measures is not None
429
430    def update(
431        self,
432        sample: Union[Sample, Iterable[Sample]],
433    ) -> None:
434        _ = self._update(sample)
435
436    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
437        """returns aggregated dataset statistics"""
438        if self._current_dataset_measures is None:
439            self._current_dataset_measures = {}
440            for calc in self.dataset_calculators:
441                values = calc.finalize()
442                self._current_dataset_measures.update(values.items())
443
444        return self._current_dataset_measures
445
446    def update_and_get_all(
447        self,
448        sample: Union[Sample, Iterable[Sample]],
449    ) -> Dict[Measure, MeasureValue]:
450        """Returns sample as well as updated dataset statistics"""
451        last_sample = self._update(sample)
452        if last_sample is None:
453            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")
454
455        return {**self._compute(last_sample), **self.finalize()}
456
457    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
458        """Returns sample as well as previously computed dataset statistics"""
459        return {**self._compute(sample), **self.finalize()}
460
461    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
462        ret: Dict[SampleMeasure, MeasureValue] = {}
463        for calc in self.sample_calculators:
464            values = calc.compute(sample)
465            ret.update(values.items())
466
467        return ret
468
469    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
470        self.sample_count += 1
471        samples = [sample] if isinstance(sample, Sample) else sample
472        last_sample = None
473        for el in samples:
474            last_sample = el
475            for calc in self.dataset_calculators:
476                calc.update(el)
477
478        self._current_dataset_measures = None
479        return last_sample

Estimates dataset statistics and computes sample statistics efficiently

StatsCalculator( measures: Collection[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], initial_dataset_measures: Optional[Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None)
395    def __init__(
396        self,
397        measures: Collection[Measure],
398        initial_dataset_measures: Optional[
399            Mapping[DatasetMeasure, MeasureValue]
400        ] = None,
401    ):
402        super().__init__()
403        self.sample_count = 0
404        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
405            measures
406        )
407        if not initial_dataset_measures:
408            self._current_dataset_measures: Optional[
409                Dict[DatasetMeasure, MeasureValue]
410            ] = None
411        else:
412            missing_dataset_meas = {
413                m
414                for m in measures
415                if isinstance(m, DatasetMeasureBase)
416                and m not in initial_dataset_measures
417            }
418            if missing_dataset_meas:
419                logger.debug(
420                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
421                )
422                self._current_dataset_measures = None
423            else:
424                self._current_dataset_measures = dict(initial_dataset_measures)
sample_count
has_dataset_measures
426    @property
427    def has_dataset_measures(self):
428        return self._current_dataset_measures is not None
def update( self, sample: Union[bioimageio.core.Sample, Iterable[bioimageio.core.Sample]]) -> None:
430    def update(
431        self,
432        sample: Union[Sample, Iterable[Sample]],
433    ) -> None:
434        _ = self._update(sample)
def finalize( self) -> Dict[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
436    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
437        """returns aggregated dataset statistics"""
438        if self._current_dataset_measures is None:
439            self._current_dataset_measures = {}
440            for calc in self.dataset_calculators:
441                values = calc.finalize()
442                self._current_dataset_measures.update(values.items())
443
444        return self._current_dataset_measures

returns aggregated dataset statistics

def update_and_get_all( self, sample: Union[bioimageio.core.Sample, Iterable[bioimageio.core.Sample]]) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
446    def update_and_get_all(
447        self,
448        sample: Union[Sample, Iterable[Sample]],
449    ) -> Dict[Measure, MeasureValue]:
450        """Returns sample as well as updated dataset statistics"""
451        last_sample = self._update(sample)
452        if last_sample is None:
453            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")
454
455        return {**self._compute(last_sample), **self.finalize()}

Returns sample as well as updated dataset statistics

def skip_update_and_get_all( self, sample: bioimageio.core.Sample) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
457    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
458        """Returns sample as well as previously computed dataset statistics"""
459        return {**self._compute(sample), **self.finalize()}

Returns sample as well as previously computed dataset statistics

def get_measure_calculators( required_measures: Iterable[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]) -> Tuple[List[Union[MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaiveSampleMeasureCalculator]], List[Union[MeanCalculator, MeanVarStdCalculator, CrickPercentilesCalculator]]]:
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # group required measures by how they can be computed together
    means_sample: Set[SampleMean] = set()
    means_dataset: Set[DatasetMean] = set()
    mvs_sample: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    mvs_dataset: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = set()
    pct_sample: Dict[Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]] = {}
    pct_dataset: Dict[Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]] = {}

    for measure in required_measures:
        if isinstance(measure, SampleMean):
            means_sample.add(measure)
        elif isinstance(measure, DatasetMean):
            means_dataset.add(measure)
        elif isinstance(measure, (SampleVar, SampleStd)):
            # var/std need the mean anyway, so request all three together
            mvs_sample.update(
                m(axes=measure.axes, member_id=measure.member_id)
                for m in (SampleMean, SampleStd, SampleVar)
            )
            assert measure in mvs_sample
        elif isinstance(measure, (DatasetVar, DatasetStd)):
            # var/std need the mean anyway, so request all three together
            mvs_dataset.update(
                m(axes=measure.axes, member_id=measure.member_id)
                for m in (DatasetMean, DatasetStd, DatasetVar)
            )
            assert measure in mvs_dataset
        elif isinstance(measure, SampleQuantile):
            pct_sample.setdefault((measure.member_id, measure.axes), set()).add(
                measure.q
            )
        elif isinstance(measure, DatasetPercentile):
            pct_dataset.setdefault((measure.member_id, measure.axes), set()).add(
                measure.q
            )
        else:
            assert_never(measure)

    # plain means, unless already covered by a mean/var/std calculator
    sample_calculators.extend(
        MeanCalculator(member_id=m.member_id, axes=m.axes)
        for m in means_sample
        if m not in mvs_sample
    )
    sample_calculators.extend(
        MeanVarStdCalculator(member_id=m.member_id, axes=m.axes) for m in mvs_sample
    )
    dataset_calculators.extend(
        MeanCalculator(member_id=m.member_id, axes=m.axes)
        for m in means_dataset
        if m not in mvs_dataset
    )
    dataset_calculators.extend(
        MeanVarStdCalculator(member_id=m.member_id, axes=m.axes) for m in mvs_dataset
    )

    # one percentile calculator per (member, axes) group
    sample_calculators.extend(
        SamplePercentilesCalculator(member_id=member_id, axes=axes, qs=qs)
        for (member_id, axes), qs in pct_sample.items()
    )
    dataset_calculators.extend(
        DatasetPercentilesCalculator(member_id=member_id, axes=axes, qs=qs)
        for (member_id, axes), qs in pct_dataset.items()
    )

    return sample_calculators, dataset_calculators

determines which calculators are needed to compute the required measures efficiently

def compute_dataset_measures( measures: Iterable[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], dataset: Iterable[bioimageio.core.Sample]) -> Dict[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # dataset measures never require sample-scoped calculators
    assert not sample_calculators

    for sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(sample)

    result: Dict[DatasetMeasure, MeasureValue] = {}
    for calculator in dataset_calculators:
        result.update(calculator.finalize())

    return result

compute all dataset measures for the given dataset

def compute_sample_measures( measures: Iterable[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], sample: bioimageio.core.Sample) -> Dict[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures never require dataset-scoped calculators
    assert not dataset_calculators

    result: Dict[SampleMeasure, MeasureValue] = {}
    for calculator in sample_calculators:
        result.update(calculator.compute(sample))

    return result

compute all sample measures for the given sample

def compute_measures( measures: Iterable[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], dataset: Iterable[bioimageio.core.Sample]) -> Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator at 0x7fc58e9393a0>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer at 0x7fc58e939580>, return_type=PydanticUndefined, when_used='always')]]]:
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(last_sample)
    if last_sample is None:
        raise ValueError("empty dataset")

    result: Dict[Measure, MeasureValue] = {}
    for calculator in dataset_calculators:
        result.update(calculator.finalize())
    # sample-scoped measures are evaluated on the last sample seen
    for calculator in sample_calculators:
        result.update(calculator.compute(last_sample))

    return result

compute all measures for the given dataset sample measures are computed for the last sample in dataset

DatasetPercentilesCalculator = <class 'CrickPercentilesCalculator'>