Coverage for bioimageio/core/stat_calculators.py: 68%
327 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from __future__ import annotations
3import collections
4import warnings
5from itertools import product
6from typing import (
7 Any,
8 Collection,
9 Dict,
10 Iterable,
11 Iterator,
12 List,
13 Mapping,
14 Optional,
15 OrderedDict,
16 Sequence,
17 Set,
18 Tuple,
19 Type,
20 Union,
21)
23import numpy as np
24import xarray as xr
25from loguru import logger
26from numpy.typing import NDArray
27from typing_extensions import assert_never
29from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID
31from .axis import AxisId, PerAxis
32from .common import MemberId
33from .sample import Sample
34from .stat_measures import (
35 DatasetMean,
36 DatasetMeasure,
37 DatasetMeasureBase,
38 DatasetPercentile,
39 DatasetStd,
40 DatasetVar,
41 Measure,
42 MeasureValue,
43 SampleMean,
44 SampleMeasure,
45 SampleQuantile,
46 SampleStd,
47 SampleVar,
48)
49from .tensor import Tensor
# `crick` is an optional dependency providing t-digests for streaming
# percentile estimation. If the import fails for any reason, fall back to
# `crick = None` plus a no-op `TDigest` stub so that type references in this
# module stay valid; the crick-based calculator is then not selected
# (see `DatasetPercentilesCalculator` below).
try:
    import crick  # pyright: ignore[reportMissingImports]

except Exception:
    crick = None

    class TDigest:
        # stub: accepts updates but records nothing
        def update(self, obj: Any):
            pass

        # stub: always returns None
        def quantile(self, q: Any) -> Any:
            pass

else:
    # real implementation from the installed crick package
    TDigest = crick.TDigest  # type: ignore
class MeanCalculator:
    """Running mean calculator: provides per-sample means and an aggregated
    dataset mean for in-memory samples."""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._n: int = 0  # reduced voxel count aggregated so far
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean without touching the dataset statistics"""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # compute in float64 for numerically stable aggregation
        data = sample.members[self._member_id].astype("float64", copy=False)
        return data.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """fold `sample` into the running dataset mean"""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean and fold it into the running dataset mean"""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels reduced into each mean entry
        n_new = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = n_new
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_old = self._n
            self._n = n_old + n_new
            # count-weighted combination of old and new mean
            self._mean = (n_old * self._mean + n_new * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """return the aggregated dataset mean (empty if no samples were seen)"""
        if self._mean is None:
            return {}

        return {self._dataset_mean: self._mean}
class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        # axes to reduce over; `None` means reduce over all axes
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        self._n: int = 0  # reduced voxel count aggregated so far
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._m2: Optional[Tensor] = None  # running sum of squared deviations from the mean

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, (population) variance and standard deviation for a single sample"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # xarray renamed `xr.dot`'s `dims` argument to `dim` after the 2023 releases
        if xr.__version__.startswith("2023"):
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dims=self._axes) / n
            )
        else:
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dim=self._axes) / n
            )

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """fold `sample` into the running dataset statistics

        no-op unless aggregating across the batch axis (or across all axes),
        as other statistics do not grow with more samples
        """
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            # pairwise/online combination of (n, mean, M2); cf. Chan et al.'s
            # parallel variance algorithm
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return aggregated dataset mean, (population) variance and standard deviation

        empty if nothing was aggregated (see `update` for the batch-axis condition)
        """
        if (
            self._axes is not None
            and BATCH_AXIS_ID not in self._axes
            or self._mean is None
        ):
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = var**0.5
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }
class SamplePercentilesCalculator:
    """Computes a set of quantiles for a single sample."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # deduplicate and sort requested quantile points
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """return one measure entry per requested quantile"""
        data = sample.members[self._member_id]
        quantile_values = data.quantile(self._qs, dim=self._axes)
        result: Dict[SampleQuantile, MeasureValue] = {}
        for q, value in zip(self._qs, quantile_values):
            key = SampleQuantile(q=q, axes=self._axes, member_id=self._member_id)
            result[key] = value

        return result
class MeanPercentilesCalculator:
    """Heuristic dataset percentiles via count-weighted averaging of per-sample
    percentiles.

    **note**: the returned dataset percentiles are an estimate and
    **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # deduplicate and sort requested quantile points
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._n: int = 0  # total reduced voxel count seen so far
        self._estimates: Optional[Tensor] = None  # running percentile estimates

    def update(self, sample: Sample):
        """average the sample's percentiles into the running estimates"""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # number of voxels reduced into each percentile entry
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            # count-weighted average of old and new estimates
            weighted_sum = self._n * self._estimates + n * sample_estimates
            self._estimates = weighted_sum / (self._n + n)
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """return the averaged percentile estimates (empty if no samples were seen)"""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): est
            for q, est in zip(self._qs, self._estimates)
        }
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Maintains one t-digest per tensor entry remaining after reduction over
    `axes`; each sample's values are streamed into the digests and the
    requested quantiles are extracted in `finalize`.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # "_percentiles" is used as a synthetic leading axis in `_initialize`
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None  # one digest per kept entry
        self._dims: Optional[Tuple[AxisId, ...]] = None
        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """allocate one t-digest per entry that remains after reducing `self._axes`"""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]

    def update(self, part: Sample):
        """stream the values of `part` into the t-digests"""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._dims is not None
        assert self._shape is not None
        # Regenerate the index iterator for every update: `itertools.product`
        # returns a single-use iterator, so an iterator created once at
        # initialization would be exhausted after the first sample and all
        # following samples would be silently ignored.
        self._indices = product(*map(range, self._shape[1:]))
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """extract the requested quantiles (empty if no samples were seen)"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
# Select the dataset percentile implementation: the naive averaging heuristic
# when `crick` is not installed, otherwise the t-digest based calculator.
if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator
class NaiveSampleMeasureCalculator:
    """Adapter giving a single `SampleMeasure` the same `compute` interface as
    the other sample measure calculators."""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """delegate to the wrapped measure and wrap the result in a dict"""
        value = self.measure.compute(sample)
        return {self.measure: value}
# union of all calculator types that provide per-sample measures
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# union of all calculator types that aggregate dataset measures
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to provide
            initial_dataset_measures: previously computed dataset measure values;
                only used if they cover *all* required dataset measures,
                otherwise they are ignored (logged at debug level).
        """
        super().__init__()
        # NOTE(review): incremented once per `_update` call, not once per sample
        # yielded by an iterable argument -- confirm intended semantics
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            # cache of aggregated dataset measures; None means "needs (re)compute"
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the provided initial values only if none are missing
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        # True while cached dataset measures are valid (set initially or by `finalize`,
        # invalidated by `_update`)
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update dataset statistics with `sample` (or every sample it yields)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # (re)compute from the dataset calculators and cache the result
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics

        Sample statistics are computed for the last sample only.

        Raises:
            ValueError: if `sample` is neither a `Sample` nor yields one.
        """
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # compute all required sample measures for `sample`
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        # feed `sample` (or every sample it yields) to all dataset calculators;
        # returns the last sample seen (None if an iterable yielded nothing)
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate cached dataset measures; recomputed lazily in `finalize`
        self._current_dataset_measures = None
        return last_sample
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Args:
        required_measures: sample and dataset measures to provide

    Returns:
        lists of sample and dataset measure calculators that together
        provide all `required_measures`
    """
    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups, keyed by the calculator able to provide them
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # mean, var and std are computed jointly by one MeanVarStdCalculator
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    # one calculator per unique (member_id, axes) pair; iterating the measure set
    # directly would create three identical calculators per pair (one each for
    # the mean, var and std entries) and triple the work per sample
    for member_id, axes in {
        (m.member_id, m.axes) for m in required_sample_mean_var_std
    }:
        sample_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for member_id, axes in {
        (m.member_id, m.axes) for m in required_dataset_mean_var_std
    }:
        dataset_calculators.append(MeanVarStdCalculator(member_id=member_id, axes=axes))

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    # dataset measures must not require any sample calculators
    assert not sample_calculators

    # stream every sample through every dataset calculator ...
    for sample in dataset:
        for calculator in calculators:
            calculator.update(sample)

    # ... then collect the aggregated results
    result: Dict[DatasetMeasure, MeasureValue] = {}
    for calculator in calculators:
        for measure, value in calculator.finalize().items():
            result[measure] = value

    return result
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures must not require any dataset calculators
    assert not dataset_calculators

    result: Dict[SampleMeasure, MeasureValue] = {}
    for calculator in calculators:
        for measure, value in calculator.compute(sample).items():
            result[measure] = value

    return result
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    ret: Dict[Measure, MeasureValue] = {}

    # stream every sample through the dataset calculators, remembering the last one
    sample = None
    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)

    if sample is None:
        raise ValueError("empty dataset")

    # aggregated dataset statistics
    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    # sample statistics for the final sample
    for calc in sample_calculators:
        ret.update(calc.compute(sample).items())

    return ret