bioimageio.core.stat_calculators
"""Calculators for sample and dataset statistics (mean, var, std, percentiles)."""

from __future__ import annotations

import collections.abc
import warnings
from itertools import product
from typing import (
    Any,
    Collection,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    OrderedDict,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import assert_never

from .axis import AxisId, PerAxis
from .common import MemberId
from .sample import Sample
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetMeasureBase,
    DatasetPercentile,
    DatasetStd,
    DatasetVar,
    Measure,
    MeasureValue,
    SampleMean,
    SampleMeasure,
    SampleQuantile,
    SampleStd,
    SampleVar,
)
from .tensor import Tensor

try:
    import crick

except Exception:
    crick = None

    class TDigest:
        """no-op stand-in so annotations resolve when `crick` is unavailable"""

        def update(self, obj: Any):
            pass

        def quantile(self, q: Any) -> Any:
            pass

else:
    TDigest = crick.TDigest  # type: ignore


class MeanCalculator:
    """to calculate sample and dataset mean for in-memory samples

    The dataset mean is aggregated incrementally across `update` calls as a
    voxel-count-weighted running mean in float64.
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """
        Args:
            member_id: tensor member to compute the mean of.
            axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._n: int = 0  # aggregated (reduced) voxel count
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean without touching the dataset statistics"""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # cast to float64 to limit accumulation error
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        return tensor.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """update the running dataset mean with `sample`"""
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean and fold it into the running dataset mean"""
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)
        return {self._sample_mean: mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = n_b
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_a = self._n
            mean_old = self._mean
            self._n = n_a + n_b
            # voxel-count-weighted running mean
            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """return the aggregated dataset mean (empty if no sample was seen)"""
        if self._mean is None:
            return {}
        else:
            return {self._dataset_mean: self._mean}


class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation

    Dataset statistics are aggregated with Chan et al.'s parallel variance
    update (combining per-sample mean and M2 sums).
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """
        Args:
            member_id: tensor member to compute statistics of.
            axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # aggregated (reduced) voxel count
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._m2: Optional[Tensor] = None  # running sum of squared deviations

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, (population) variance and std of `sample`"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # sum of squared deviations via dot product over the reduced axes
        var = xr.dot(c, c, dims=self._axes) / n
        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """fold `sample` into the running dataset mean/M2 aggregates"""
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            # parallel (Chan) combination of the two M2 aggregates
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return aggregated dataset mean, variance and std (empty if unseen)"""
        if self._mean is None:
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = np.sqrt(var)
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }


class SamplePercentilesCalculator:
    """to calculate sample percentiles"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: tensor member to compute percentiles of.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """compute all requested quantiles of `sample`"""
        tensor = sample.members[self._member_id]
        ps = tensor.quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
            for q, p in zip(self._qs, ps)
        }


class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: tensor member to estimate percentiles of.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # aggregated (reduced) voxel count
        self._estimates: Optional[Tensor] = None  # running percentile estimates

    def update(self, sample: Sample):
        """fold `sample`'s percentiles into the running weighted average"""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # reduced voxel count
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            # voxel-count-weighted average of percentile estimates
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """return the (heuristic) dataset percentiles (empty if unseen)"""
        if self._estimates is None:
            return {}
        else:
            warnings.warn(
                "Computed dataset percentiles naively by averaging percentiles of samples."
            )
            return {
                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
                for q, e in zip(self._qs, self._estimates)
            }


class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Maintains one t-digest per non-reduced (output) element and feeds every
    sample's values into it.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: tensor member to compute percentiles of.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped.
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None  # one digest per output element
        self._dims: Optional[Tuple[AxisId, ...]] = None  # output dims, "_percentiles" first
        self._shape: Optional[Tuple[int, ...]] = None  # output shape matching `_dims`

    def _initialize(self, tensor_sizes: PerAxis[int]):
        # lazily initialized from the first sample's tensor sizes
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        d = int(np.prod(self._shape[1:]))  # type: ignore
        self._digest = [TDigest() for _ in range(d)]

    def update(self, part: Sample):
        """feed `part`'s values into the per-element t-digests"""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._dims is not None
        assert self._shape is not None
        # regenerate the index iterator for every sample: `itertools.product`
        # yields a single-use iterator; storing it across calls (as done
        # previously) left it exhausted after the first update, silently
        # skipping all subsequent samples
        for i, idx in enumerate(product(*map(range, self._shape[1:]))):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """query all digests and assemble the dataset percentile tensors"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }


if crick is None:
    # fall back to the naive averaging estimate when crick is unavailable
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator


class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        """
        Args:
            member_id: tensor member the measure refers to.
            measure: sample measure that computes itself.
        """
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """delegate computation to the wrapped measure"""
        return {self.measure: self.measure.compute(sample)}


SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]


class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to compute.
            initial_dataset_measures: previously computed dataset measures;
                only used when it covers *all* required dataset measures.
        """
        super().__init__()
        self.sample_count = 0  # number of samples seen by `update`
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        """whether finalized (or provided initial) dataset measures are cached"""
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update dataset statistics with one sample or an iterable of samples"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            # count per sample; previously the counter was incremented once
            # per call even when an iterable yielded many samples
            self.sample_count += 1
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # cached dataset measures are stale after new data
        self._current_dataset_measures = None
        return last_sample


def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # var/std imply mean; request all three from one calculator
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators


def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    assert not sample_calculators

    ret: Dict[DatasetMeasure, MeasureValue] = {}

    for sample in dataset:
        for calc in calculators:
            calc.update(sample)

    for calc in calculators:
        ret.update(calc.finalize().items())

    return ret


def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    calculators, dataset_calculators = get_measure_calculators(measures)
    assert not dataset_calculators
    ret: Dict[SampleMeasure, MeasureValue] = {}

    for calc in calculators:
        ret.update(calc.compute(sample).items())

    return ret


def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    ret: Dict[Measure, MeasureValue] = {}
    sample = None
    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)
    if sample is None:
        raise ValueError("empty dataset")

    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    for calc in sample_calculators:
        ret.update(calc.compute(sample).items())

    return ret
class MeanCalculator:
    """to calculate sample and dataset mean for in-memory samples

    The dataset mean is aggregated incrementally across `update` calls as a
    voxel-count-weighted running mean in float64.
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """
        Args:
            member_id: tensor member to compute the mean of.
            axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._n: int = 0  # aggregated (reduced) voxel count
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean without touching the dataset statistics"""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # cast to float64 to limit accumulation error
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        return tensor.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """update the running dataset mean with `sample`"""
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean and fold it into the running dataset mean"""
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)
        return {self._sample_mean: mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = n_b
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_a = self._n
            mean_old = self._mean
            self._n = n_a + n_b
            # voxel-count-weighted running mean
            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """return the aggregated dataset mean (empty if no sample was seen)"""
        if self._mean is None:
            return {}
        else:
            return {self._dataset_mean: self._mean}
to calculate sample and dataset mean for in-memory samples
70 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 71 super().__init__() 72 self._n: int = 0 73 self._mean: Optional[Tensor] = None 74 self._axes = None if axes is None else tuple(axes) 75 self._member_id = member_id 76 self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes) 77 self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation

    Dataset statistics are aggregated incrementally by combining per-sample
    mean and sum-of-squared-deviations (M2) aggregates.
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """
        Args:
            member_id: tensor member to compute statistics of.
            axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # aggregated (reduced) voxel count
        self._mean: Optional[Tensor] = None  # running dataset mean
        self._m2: Optional[Tensor] = None  # running sum of squared deviations

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, (population) variance and std of `sample`"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # sum of squared deviations via dot product over the reduced axes
        var = xr.dot(c, c, dims=self._axes) / n
        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """fold `sample` into the running dataset mean/M2 aggregates"""
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            # parallel combination of the two M2 aggregates
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return aggregated dataset mean, variance and std (empty if unseen)"""
        if self._mean is None:
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = np.sqrt(var)
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }
to calculate sample and dataset mean, variance or standard deviation
130 def compute( 131 self, sample: Sample 132 ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: 133 tensor = sample.members[self._member_id] 134 mean = tensor.mean(dim=self._axes) 135 c = (tensor - mean).data 136 if self._axes is None: 137 n = tensor.size 138 else: 139 n = int(np.prod([tensor.sizes[d] for d in self._axes])) 140 141 var = xr.dot(c, c, dims=self._axes) / n 142 assert isinstance(var, xr.DataArray) 143 std = np.sqrt(var) 144 assert isinstance(std, xr.DataArray) 145 return { 146 SampleMean(axes=self._axes, member_id=self._member_id): mean, 147 SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 148 var 149 ), 150 SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 151 std 152 ), 153 }
155 def update(self, sample: Sample): 156 tensor = sample.members[self._member_id].astype("float64", copy=False) 157 mean_b = tensor.mean(dim=self._axes) 158 assert mean_b.dtype == "float64" 159 # reduced voxel count 160 n_b = int(tensor.size / mean_b.size) 161 m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) 162 assert m2_b.dtype == "float64" 163 if self._mean is None: 164 assert self._m2 is None 165 self._n = n_b 166 self._mean = mean_b 167 self._m2 = m2_b 168 else: 169 n_a = self._n 170 mean_a = self._mean 171 m2_a = self._m2 172 self._n = n = n_a + n_b 173 self._mean = (n_a * mean_a + n_b * mean_b) / n 174 assert self._mean.dtype == "float64" 175 d = mean_b - mean_a 176 self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n 177 assert self._m2.dtype == "float64"
179 def finalize( 180 self, 181 ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: 182 if self._mean is None: 183 return {} 184 else: 185 assert self._m2 is not None 186 var = self._m2 / self._n 187 sqrt = np.sqrt(var) 188 if isinstance(sqrt, (int, float)): 189 # var and mean are scalar tensors, let's keep it consistent 190 sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) 191 192 assert isinstance(sqrt, Tensor), type(sqrt) 193 return { 194 DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, 195 DatasetVar(member_id=self._member_id, axes=self._axes): var, 196 DatasetStd(member_id=self._member_id, axes=self._axes): sqrt, 197 }
class SamplePercentilesCalculator:
    """to calculate sample percentiles"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: tensor member to compute percentiles of.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """compute all requested quantiles of `sample`"""
        tensor = sample.members[self._member_id]
        ps = tensor.quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
            for q, p in zip(self._qs, ps)
        }
to calculate sample percentiles
203 def __init__( 204 self, 205 member_id: MemberId, 206 axes: Optional[Sequence[AxisId]], 207 qs: Collection[float], 208 ): 209 super().__init__() 210 assert all(0.0 <= q <= 1.0 for q in qs) 211 self._qs = sorted(set(qs)) 212 self._axes = None if axes is None else tuple(axes) 213 self._member_id = member_id
class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: tensor member to estimate percentiles of.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0  # aggregated (reduced) voxel count
        self._estimates: Optional[Tensor] = None  # running percentile estimates

    def update(self, sample: Sample):
        """fold `sample`'s percentiles into the running weighted average"""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # reduced voxel count
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            # voxel-count-weighted average of percentile estimates
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """return the (heuristic) dataset percentiles (empty if unseen)"""
        if self._estimates is None:
            return {}
        else:
            warnings.warn(
                "Computed dataset percentiles naively by averaging percentiles of samples."
            )
            return {
                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
                for q, e in zip(self._qs, self._estimates)
            }
to calculate dataset percentiles heuristically by averaging across samples note: the returned dataset percentiles are an estimate and not mathematically correct
229 def __init__( 230 self, 231 member_id: MemberId, 232 axes: Optional[Sequence[AxisId]], 233 qs: Collection[float], 234 ): 235 super().__init__() 236 assert all(0.0 <= q <= 1.0 for q in qs) 237 self._qs = sorted(set(qs)) 238 self._axes = None if axes is None else tuple(axes) 239 self._member_id = member_id 240 self._n: int = 0 241 self._estimates: Optional[Tensor] = None
243 def update(self, sample: Sample): 244 tensor = sample.members[self._member_id] 245 sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( 246 "float64", copy=False 247 ) 248 249 # reduced voxel count 250 n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:])) 251 252 if self._estimates is None: 253 assert self._n == 0 254 self._estimates = sample_estimates 255 else: 256 self._estimates = (self._n * self._estimates + n * sample_estimates) / ( 257 self._n + n 258 ) 259 assert self._estimates.dtype == "float64" 260 261 self._n += n
263 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 264 if self._estimates is None: 265 return {} 266 else: 267 warnings.warn( 268 "Computed dataset percentiles naively by averaging percentiles of samples." 269 ) 270 return { 271 DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e 272 for q, e in zip(self._qs, self._estimates) 273 }
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Maintains one t-digest per non-reduced (output) element and feeds every
    sample's values into it.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: tensor member to compute percentiles of.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped.
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None  # one digest per output element
        self._dims: Optional[Tuple[AxisId, ...]] = None  # output dims, "_percentiles" first
        self._indices: Optional[Iterator[Tuple[int, ...]]] = None  # index iterator over output elements
        self._shape: Optional[Tuple[int, ...]] = None  # output shape matching `_dims`

    def _initialize(self, tensor_sizes: PerAxis[int]):
        # lazily initialized from the first sample's tensor sizes
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        d = int(np.prod(self._shape[1:]))  # type: ignore
        self._digest = [TDigest() for _ in range(d)]
        self._indices = product(*map(range, self._shape[1:]))

    def update(self, part: Sample):
        """feed `part`'s values into the per-element t-digests"""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        # NOTE(review): `self._indices` is a single-use `itertools.product`
        # iterator created once in `_initialize`; it appears exhausted after
        # the first update, so later samples would be skipped — confirm and
        # regenerate the iterator per call if so.
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """query all digests and assemble the dataset percentile tensors"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
to calculate dataset percentiles with the experimental crick library
    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: sample member to compute percentiles for
            axes: axes to reduce over (None reduces over all axes)
            qs: quantiles as fractions in [0, 1]
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        # quantiles must be fractions in [0, 1]
        assert all(0.0 <= q <= 1.0 for q in qs)
        # "_percentiles" is reserved as the leading axis of the output
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # the following are lazily initialized on the first `update` call
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None
314 def update(self, part: Sample): 315 tensor = ( 316 part.members[self._member_id] 317 if isinstance(part, Sample) 318 else part.members[self._member_id].data 319 ) 320 assert "_percentiles" not in tensor.dims 321 if self._digest is None: 322 self._initialize(tensor.tagged_shape) 323 324 assert self._digest is not None 325 assert self._indices is not None 326 assert self._dims is not None 327 for i, idx in enumerate(self._indices): 328 self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
330 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 331 if self._digest is None: 332 return {} 333 else: 334 assert self._dims is not None 335 assert self._shape is not None 336 337 vs: NDArray[Any] = np.asarray( 338 [[d.quantile(q) for d in self._digest] for q in self._qs] 339 ).reshape(self._shape) 340 return { 341 DatasetPercentile( 342 q=q, axes=self._axes, member_id=self._member_id 343 ): Tensor(v, dims=self._dims[1:]) 344 for q, v in zip(self._qs, vs) 345 }
class NaiveSampleMeasureCalculator:
    """Adapter that gives a plain `SampleMeasure` the common calculator interface."""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Compute the wrapped measure for `sample`, keyed by the measure itself."""
        value = self.measure.compute(sample)
        return {self.measure: value}
wrapper for measures to match the interface of other sample measure calculators
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to compute
            initial_dataset_measures: previously computed dataset measures;
                only used if they cover all required dataset measures
        """
        super().__init__()
        # NOTE(review): incremented once per `_update` call, even when an
        # iterable of several samples is passed — confirm this is intended
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the initial values only if no required dataset measure
            # is missing; otherwise recompute everything from scratch
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                # NOTE(review): message says `initial_dataset_measure`
                # (singular) while the parameter is plural
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        # True once dataset measures are cached (from initial values or `finalize`)
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update dataset statistics with `sample` (or each sample it yields)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # aggregate and cache; the cache is invalidated by `_update`
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        # sample measures are computed for the last sample only
        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # compute all per-sample measures for `sample`
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        # feed `sample` (or each sample it yields) to all dataset calculators;
        # returns the last sample seen (None for an empty iterable)
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # dataset statistics changed: invalidate the cache
        self._current_dataset_measures = None
        return last_sample
Estimates dataset statistics and computes sample statistics efficiently
    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to compute
            initial_dataset_measures: previously computed dataset measures;
                only used if they cover all required dataset measures
        """
        super().__init__()
        # counts `_update` calls; presumably not individual samples — verify
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the initial values only if no required dataset measure
            # is missing; otherwise recompute everything from scratch
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)
423 def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: 424 """returns aggregated dataset statistics""" 425 if self._current_dataset_measures is None: 426 self._current_dataset_measures = {} 427 for calc in self.dataset_calculators: 428 values = calc.finalize() 429 self._current_dataset_measures.update(values.items()) 430 431 return self._current_dataset_measures
returns aggregated dataset statistics
433 def update_and_get_all( 434 self, 435 sample: Union[Sample, Iterable[Sample]], 436 ) -> Dict[Measure, MeasureValue]: 437 """Returns sample as well as updated dataset statistics""" 438 last_sample = self._update(sample) 439 if last_sample is None: 440 raise ValueError("`sample` was not a `Sample`, nor did it yield any.") 441 442 return {**self._compute(last_sample), **self.finalize()}
Returns sample as well as updated dataset statistics
444 def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: 445 """Returns sample as well as previously computed dataset statistics""" 446 return {**self._compute(sample), **self.finalize()}
Returns sample as well as previously computed dataset statistics
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    # percentiles are grouped by (member, axes) so that a single calculator
    # can compute all requested quantiles for that member/axes combination
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # var/std require the mean anyway, so compute all three jointly
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
determines which calculators are needed to compute the required measures efficiently
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """Compute all dataset `measures` over the samples of `dataset`."""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # only dataset measures were requested
    assert not sample_calculators

    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)

    results: Dict[DatasetMeasure, MeasureValue] = {}
    for calc in dataset_calculators:
        results.update(calc.finalize().items())

    return results
compute all dataset `measures` for the given `dataset`
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """Compute all sample `measures` for the given `sample`."""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # only sample measures were requested
    assert not dataset_calculators

    computed: Dict[SampleMeasure, MeasureValue] = {}
    for calc in sample_calculators:
        computed.update(calc.compute(sample).items())

    return computed
compute all sample `measures` for the given `sample`
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """Compute all `measures` for the given `dataset`.

    Sample measures are computed for the last sample in `dataset`.
    """
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for calc in dataset_calculators:
            calc.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    ret: Dict[Measure, MeasureValue] = {}
    for calc in dataset_calculators:
        ret.update(calc.finalize().items())
    for calc in sample_calculators:
        ret.update(calc.compute(last_sample).items())

    return ret
compute all `measures` for the given `dataset`; sample measures are computed for the last sample in `dataset`