Coverage for bioimageio/core/stat_calculators.py: 65%
321 statements
coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
from __future__ import annotations

import collections.abc
import warnings
from itertools import product
from typing import (
    Any,
    Collection,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    OrderedDict,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import assert_never

from .axis import AxisId, PerAxis
from .common import MemberId
from .sample import Sample
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetMeasureBase,
    DatasetPercentile,
    DatasetStd,
    DatasetVar,
    Measure,
    MeasureValue,
    SampleMean,
    SampleMeasure,
    SampleQuantile,
    SampleStd,
    SampleVar,
)
from .tensor import Tensor

try:
    import crick

except Exception:
    crick = None

    class TDigest:
        def update(self, obj: Any):
            pass

        def quantile(self, q: Any) -> Any:
            pass

else:
    TDigest = crick.TDigest  # type: ignore
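
# When `crick` is unavailable, the dummy `TDigest` above only keeps the name
# defined so that the module still imports; dataset percentiles are then
# computed by `MeanPercentilesCalculator` instead of `CrickPercentilesCalculator`
# (see the `DatasetPercentilesCalculator` selection further below).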


class MeanCalculator:
    """to calculate sample and dataset mean for in-memory samples"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        return tensor.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)
        return {self._sample_mean: mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = n_b
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_a = self._n
            mean_old = self._mean
            self._n = n_a + n_b
            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        if self._mean is None:
            return {}
        else:
            return {self._dataset_mean: self._mean}
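
# The running dataset mean maintained in `_update_impl` above is a weighted
# merge of the previous estimate with the new sample's reduced mean:
#
#     mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b)
#
# where n_a counts the values already folded in and n_b the values reduced into
# `tensor_mean` for the current sample. Illustrative numbers only: merging a
# mean of 2.0 over 4 values with a mean of 5.0 over 2 values gives
# (4 * 2.0 + 2 * 5.0) / 6 = 3.0.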


class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        var = xr.dot(c, c, dims=self._axes) / n
        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        if self._mean is None:
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = np.sqrt(var)
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }
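
# `MeanVarStdCalculator.update` above merges per-sample statistics with the
# pairwise update for the sum of squared deviations (Chan et al. style):
#
#     d  = mean_b - mean_a
#     M2 = M2_a + M2_b + d**2 * n_a * n_b / (n_a + n_b)
#
# which lets `finalize` report the dataset variance as M2 / n and the standard
# deviation as its square root without keeping all samples in memory.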


class SamplePercentilesCalculator:
    """to calculate sample percentiles"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        tensor = sample.members[self._member_id]
        ps = tensor.quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
            for q, p in zip(self._qs, ps)
        }
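
# `Tensor.quantile` evaluated with a sequence of qs returns a tensor whose
# leading axis enumerates the requested quantiles (the same layout is relied on
# in MeanPercentilesCalculator via `shape_tuple[1:]`), so zipping it with
# `self._qs` yields one value, or one reduced tensor, per requested q.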


class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # reduced voxel count
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        if self._estimates is None:
            return {}
        else:
            warnings.warn(
                "Computed dataset percentiles naively by averaging percentiles of samples."
            )
            return {
                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
                for q, e in zip(self._qs, self._estimates)
            }
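
# Why this is only a heuristic: the voxel-count-weighted average of per-sample
# percentiles is generally not the percentile of the pooled data. Illustrative
# example (hypothetical values): the medians of [0, 0, 0] and [1, 2, 3] are 0
# and 2, so the averaged estimate is 1, while the median of the pooled values
# [0, 0, 0, 1, 2, 3] is 0.5 with linear interpolation. Hence the warning
# emitted in `finalize`.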


class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        d = int(np.prod(self._shape[1:]))  # type: ignore
        self._digest = [TDigest() for _ in range(d)]
        self._indices = product(*map(range, self._shape[1:]))

    def update(self, part: Sample):
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
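
# `crick.TDigest` implements the t-digest sketch: each digest ingests a stream
# of values and can later be queried for approximate quantiles with bounded
# memory. The calculator above keeps one digest per retained (non-reduced)
# index, feeds it in `update`, and reads all requested quantiles in `finalize`.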


if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator


class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        return {self.measure: self.measure.compute(sample)}


SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]


class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        super().__init__()
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measures` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample statistics as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample statistics as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        self._current_dataset_measures = None
        return last_sample
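
# Minimal usage sketch (illustrative only; constructing the `Sample` objects is
# outside the scope of this module):
#
#     stats = StatsCalculator(measures=required_measures)
#     for sample in samples:  # any iterable of `Sample`s
#         measures_so_far = stats.update_and_get_all(sample)
#     dataset_only = stats.finalize()  # Dict[DatasetMeasure, MeasureValue]
#
# `update` and `update_and_get_all` invalidate the cached dataset measures, so
# the next `finalize` call re-aggregates them from the dataset calculators.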


def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
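
# Example of the grouping above (measure arguments are illustrative): requiring
# SampleStd(member_id=m, axes=a) pulls in SampleMean and SampleVar for the same
# member and axes, so all three are served by a MeanVarStdCalculator; a plain
# mean that is not accompanied by a var/std requirement gets a dedicated
# MeanCalculator instead, and percentile requirements are grouped per
# (member_id, axes) pair.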


def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    assert not sample_calculators

    ret: Dict[DatasetMeasure, MeasureValue] = {}

    for sample in dataset:
        for calc in calculators:
            calc.update(sample)

    for calc in calculators:
        ret.update(calc.finalize().items())

    return ret


def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    calculators, dataset_calculators = get_measure_calculators(measures)
    assert not dataset_calculators
    ret: Dict[SampleMeasure, MeasureValue] = {}

    for calc in calculators:
        ret.update(calc.compute(sample).items())

    return ret


def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    ret: Dict[Measure, MeasureValue] = {}
    sample = None
    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)
    if sample is None:
        raise ValueError("empty dataset")

    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    for calc in sample_calculators:
        ret.update(calc.compute(sample).items())

    return ret
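
# Illustrative sketch of the three entry points (the `samples` list and the
# measure collections are assumptions, not part of this module):
#
#     dataset_stats = compute_dataset_measures(dataset_measures, samples)
#     sample_stats = compute_sample_measures(sample_measures, samples[-1])
#     all_stats = compute_measures(all_measures, samples)
#
# `compute_measures` aggregates dataset measures over every sample and
# evaluates the sample measures on the last sample only, as its docstring
# notes.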