Coverage for src/bioimageio/core/stat_calculators.py: 75% (327 statements) — coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1from __future__ import annotations
3import collections
4import warnings
5from itertools import product
6from typing import (
7 Any,
8 Collection,
9 Dict,
10 Iterable,
11 Iterator,
12 List,
13 Mapping,
14 Optional,
15 OrderedDict,
16 Sequence,
17 Set,
18 Tuple,
19 Type,
20 Union,
21)
23import numpy as np
24import xarray as xr
25from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID
26from loguru import logger
27from numpy.typing import NDArray
28from typing_extensions import assert_never
30from .axis import AxisId, PerAxis
31from .common import MemberId
32from .sample import Sample
33from .stat_measures import (
34 DatasetMean,
35 DatasetMeasure,
36 DatasetMeasureBase,
37 DatasetPercentile,
38 DatasetStd,
39 DatasetVar,
40 Measure,
41 MeasureValue,
42 SampleMean,
43 SampleMeasure,
44 SampleQuantile,
45 SampleStd,
46 SampleVar,
47)
48from .tensor import Tensor
# Optional dependency: `crick` provides a streaming t-digest used to estimate
# dataset percentiles without holding all samples in memory.
try:
    import crick  # pyright: ignore[reportMissingTypeStubs]

except Exception:
    # `crick` is unavailable (any import failure is treated the same);
    # `crick` is set to None so code below can select a fallback calculator,
    # and a no-op `TDigest` stub keeps the name resolvable for annotations.
    # NOTE: the stub's `quantile` returns None -- presumably it is never
    # exercised because the `crick is None` guard selects the non-crick path.
    crick = None

    class TDigest:
        def update(self, obj: Any):
            pass

        def quantile(self, q: Any) -> Any:
            pass

else:
    # import succeeded: use the real t-digest implementation
    TDigest = crick.TDigest  # type: ignore
class MeanCalculator:
    """Computes per-sample means and accumulates a running dataset mean
    for in-memory samples."""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._member_id = member_id
        self._axes = tuple(axes) if axes is not None else None
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Return the mean of this sample's member tensor (reduced over `axes`)."""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # compute in float64 for numerical stability
        data = sample.members[self._member_id].astype("float64", copy=False)
        return data.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """Fold this sample into the running dataset mean."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute the sample mean and fold it into the dataset mean in one go."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels that were reduced into each mean entry
        n_new = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            # first sample: adopt its statistics directly
            assert self._n == 0
            self._n = n_new
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_old = self._n
            previous_mean = self._mean
            self._n = n_old + n_new
            # count-weighted combination of old and new means
            self._mean = (n_old * previous_mean + n_new * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """Return the accumulated dataset mean; empty if no sample was seen."""
        if self._mean is None:
            return {}

        return {self._dataset_mean: self._mean}
class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        # axes to reduce over; None means reduce over all axes
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        self._n: int = 0  # accumulated reduced voxel count
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._m2: Optional[Tensor] = None  # running sum of squared deviations

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, (population) variance and standard deviation of the
        sample's member tensor, reduced over `axes`"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # xarray renamed `xr.dot`'s `dims` argument to `dim` after 2023 releases
        if xr.__version__.startswith("2023"):
            var = xr.dot(c, c, dims=self._axes) / n
        else:
            var = xr.dot(c, c, dim=self._axes) / n

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """merge a sample's statistics into the running dataset statistics
        using the pairwise (Chan et al.) variance combination"""
        # only accumulate dataset statistics when reducing over the batch axis
        # (or over all axes)
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            # first sample: adopt its statistics directly
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            # combine the two partitions' sums of squared deviations
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return accumulated dataset mean/var/std; empty if nothing was
        accumulated (no samples, or axes excluded the batch axis)"""
        if (
            self._axes is not None
            and BATCH_AXIS_ID not in self._axes
            or self._mean is None
        ):
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = var**0.5
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }
class SamplePercentilesCalculator:
    """Computes a fixed set of quantiles for individual samples."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # deduplicate and sort the requested quantile points
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one measure value per requested quantile of the member tensor."""
        member = sample.members[self._member_id]
        values = member.quantile(self._qs, dim=self._axes)
        measures = (
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id)
            for q in self._qs
        )
        return dict(zip(measures, values))
class MeanPercentilesCalculator:
    """Estimates dataset percentiles heuristically by averaging per-sample
    percentiles.

    Note: the returned dataset percentiles are an estimate and
    **not mathematically correct**.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """Fold this sample's percentile estimates into the running average."""
        member = sample.members[self._member_id]
        estimates_b = member.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # weight: voxels reduced into each estimate entry
        # (the first axis of the estimates enumerates the quantile points)
        weight = int(member.size / np.prod(estimates_b.shape_tuple[1:]))

        if self._estimates is None:
            # first sample: adopt its estimates directly
            assert self._n == 0
            self._estimates = estimates_b
        else:
            combined = self._n * self._estimates + weight * estimates_b
            self._estimates = combined / (self._n + weight)
            assert self._estimates.dtype == "float64"

        self._n += weight

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return averaged estimates; warns that the aggregation is naive."""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): est
            for q, est in zip(self._qs, self._estimates)
        }
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick libray](https://github.com/dask/crick)

    One t-digest is kept per element of the non-reduced axes; each `update`
    streams the sample's values into the digests and `finalize` queries them
    for the requested quantiles.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # materialized list of index tuples into the non-reduced axes;
        # must be a list (not an iterator) as it is re-iterated per update
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """lazily set up one t-digest per retained (non-reduced) element,
        based on the first seen tensor's sizes"""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # bugfix: materialize the cartesian product; the previous
        # `product(...)` iterator was exhausted by the first `update` call,
        # so all subsequent samples were silently ignored by the digests
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """stream a sample into the per-element t-digests"""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """query every digest for each requested quantile; empty if no
        samples were seen"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
# Select the dataset-percentile implementation: prefer the streaming t-digest
# calculator when `crick` imported successfully, otherwise fall back to the
# naive mean-of-sample-percentiles estimate.
if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator
class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """delegate to the wrapped measure's own `compute`"""
        wrapped = self.measure
        return {wrapped: wrapped.compute(sample)}
# union of all calculator types that produce per-sample measures
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# union of all calculator types that accumulate dataset measures
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to compute/estimate.
            initial_dataset_measures: previously computed dataset statistics;
                only used when it covers *all* required dataset measures,
                otherwise it is ignored (logged at debug level).
        """
        super().__init__()
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            # no cached dataset measures yet; `finalize` will compute them
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        # True when dataset statistics are available (supplied initially or
        # produced by a previous `finalize` and not invalidated since)
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update dataset statistics with a new sample (or iterable thereof)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # (re)compute from the dataset calculators; cached until the next
            # `_update` invalidates it
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        # sample statistics are computed for the last sample only
        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # compute all required sample measures for `sample`
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        # NOTE(review): `sample_count` increments once per call even when an
        # iterable of several samples is passed -- confirm this is intended
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate the cached dataset measures; `finalize` recomputes them
        self._current_dataset_measures = None
        return last_sample
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Mean/var/std requests for the same (member, axes) pair share a single
    `MeanVarStdCalculator`; quantile requests for the same pair share a single
    percentile calculator.
    """

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # var/std need the mean anyway: request all three together
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # bugfix: deduplicate by (member_id, axes) -- previously one (identical)
    # MeanVarStdCalculator was appended per Mean/Var/Std measure, tripling work
    for member_id, axes in {(m.member_id, m.axes) for m in required_sample_mean_var_std}:
        sample_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for member_id, axes in {
        (m.member_id, m.axes) for m in required_dataset_mean_var_std
    }:
        dataset_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # dataset measures must not require any sample calculators
    assert not sample_calculators

    # stream every sample through each dataset calculator
    for sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(sample)

    result: Dict[DatasetMeasure, MeasureValue] = {}
    for calculator in dataset_calculators:
        result.update(calculator.finalize().items())

    return result
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures must not require any dataset calculators
    assert not dataset_calculators

    result: Dict[SampleMeasure, MeasureValue] = {}
    for calculator in sample_calculators:
        result.update(calculator.compute(sample).items())

    return result
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`

    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    result: Dict[Measure, MeasureValue] = {}

    # stream the dataset through the dataset calculators, remembering the
    # final sample for the sample-level measures
    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    for calculator in dataset_calculators:
        result.update(calculator.finalize().items())

    for calculator in sample_calculators:
        result.update(calculator.compute(last_sample).items())

    return result