bioimageio.core.stat_calculators
1from __future__ import annotations 2 3import collections 4import warnings 5from itertools import product 6from typing import ( 7 Any, 8 Collection, 9 Dict, 10 Iterable, 11 Iterator, 12 List, 13 Mapping, 14 Optional, 15 OrderedDict, 16 Sequence, 17 Set, 18 Tuple, 19 Type, 20 Union, 21) 22 23import numpy as np 24import xarray as xr 25from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID 26from loguru import logger 27from numpy.typing import NDArray 28from typing_extensions import assert_never 29 30from .axis import AxisId, PerAxis 31from .common import MemberId 32from .sample import Sample 33from .stat_measures import ( 34 DatasetMean, 35 DatasetMeasure, 36 DatasetMeasureBase, 37 DatasetPercentile, 38 DatasetStd, 39 DatasetVar, 40 Measure, 41 MeasureValue, 42 SampleMean, 43 SampleMeasure, 44 SampleQuantile, 45 SampleStd, 46 SampleVar, 47) 48from .tensor import Tensor 49 50try: 51 import crick # pyright: ignore[reportMissingTypeStubs] 52 53except Exception: 54 crick = None 55 56 class TDigest: 57 def update(self, obj: Any): 58 pass 59 60 def quantile(self, q: Any) -> Any: 61 pass 62 63else: 64 TDigest = crick.TDigest # type: ignore 65 66 67class MeanCalculator: 68 """to calculate sample and dataset mean for in-memory samples""" 69 70 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 71 super().__init__() 72 self._n: int = 0 73 self._mean: Optional[Tensor] = None 74 self._axes = None if axes is None else tuple(axes) 75 self._member_id = member_id 76 self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes) 77 self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes) 78 79 def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: 80 return {self._sample_mean: self._compute_impl(sample)} 81 82 def _compute_impl(self, sample: Sample) -> Tensor: 83 tensor = sample.members[self._member_id].astype("float64", copy=False) 84 return tensor.mean(dim=self._axes) 85 86 def update(self, sample: Sample) -> None: 
87 mean = self._compute_impl(sample) 88 self._update_impl(sample.members[self._member_id], mean) 89 90 def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: 91 mean = self._compute_impl(sample) 92 self._update_impl(sample.members[self._member_id], mean) 93 return {self._sample_mean: mean} 94 95 def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): 96 assert tensor_mean.dtype == "float64" 97 # reduced voxel count 98 n_b = int(tensor.size / tensor_mean.size) 99 100 if self._mean is None: 101 assert self._n == 0 102 self._n = n_b 103 self._mean = tensor_mean 104 else: 105 assert self._n != 0 106 n_a = self._n 107 mean_old = self._mean 108 self._n = n_a + n_b 109 self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n 110 assert self._mean.dtype == "float64" 111 112 def finalize(self) -> Dict[DatasetMean, MeasureValue]: 113 if self._mean is None: 114 return {} 115 else: 116 return {self._dataset_mean: self._mean} 117 118 119class MeanVarStdCalculator: 120 """to calculate sample and dataset mean, variance or standard deviation""" 121 122 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 123 super().__init__() 124 self._axes = None if axes is None else tuple(map(AxisId, axes)) 125 self._member_id = member_id 126 self._n: int = 0 127 self._mean: Optional[Tensor] = None 128 self._m2: Optional[Tensor] = None 129 130 def compute( 131 self, sample: Sample 132 ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: 133 tensor = sample.members[self._member_id] 134 mean = tensor.mean(dim=self._axes) 135 c = (tensor - mean).data 136 if self._axes is None: 137 n = tensor.size 138 else: 139 n = int(np.prod([tensor.sizes[d] for d in self._axes])) 140 141 if xr.__version__.startswith("2023"): 142 var = xr.dot(c, c, dims=self._axes) / n 143 else: 144 var = xr.dot(c, c, dim=self._axes) / n 145 146 assert isinstance(var, xr.DataArray) 147 std = np.sqrt(var) 148 assert isinstance(std, xr.DataArray) 149 return { 150 
SampleMean(axes=self._axes, member_id=self._member_id): mean, 151 SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 152 var 153 ), 154 SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 155 std 156 ), 157 } 158 159 def update(self, sample: Sample): 160 if self._axes is not None and BATCH_AXIS_ID not in self._axes: 161 return 162 163 tensor = sample.members[self._member_id].astype("float64", copy=False) 164 mean_b = tensor.mean(dim=self._axes) 165 assert mean_b.dtype == "float64" 166 # reduced voxel count 167 n_b = int(tensor.size / mean_b.size) 168 m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) 169 assert m2_b.dtype == "float64" 170 if self._mean is None: 171 assert self._m2 is None 172 self._n = n_b 173 self._mean = mean_b 174 self._m2 = m2_b 175 else: 176 n_a = self._n 177 mean_a = self._mean 178 m2_a = self._m2 179 self._n = n = n_a + n_b 180 self._mean = (n_a * mean_a + n_b * mean_b) / n 181 assert self._mean.dtype == "float64" 182 d = mean_b - mean_a 183 self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n 184 assert self._m2.dtype == "float64" 185 186 def finalize( 187 self, 188 ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: 189 if ( 190 self._axes is not None 191 and BATCH_AXIS_ID not in self._axes 192 or self._mean is None 193 ): 194 return {} 195 else: 196 assert self._m2 is not None 197 var = self._m2 / self._n 198 sqrt = var**0.5 199 if isinstance(sqrt, (int, float)): 200 # var and mean are scalar tensors, let's keep it consistent 201 sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) 202 203 assert isinstance(sqrt, Tensor), type(sqrt) 204 return { 205 DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, 206 DatasetVar(member_id=self._member_id, axes=self._axes): var, 207 DatasetStd(member_id=self._member_id, axes=self._axes): sqrt, 208 } 209 210 211class SamplePercentilesCalculator: 212 """to calculate sample percentiles""" 213 214 def __init__( 215 self, 216 member_id: 
MemberId, 217 axes: Optional[Sequence[AxisId]], 218 qs: Collection[float], 219 ): 220 super().__init__() 221 assert all(0.0 <= q <= 1.0 for q in qs) 222 self._qs = sorted(set(qs)) 223 self._axes = None if axes is None else tuple(axes) 224 self._member_id = member_id 225 226 def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]: 227 tensor = sample.members[self._member_id] 228 ps = tensor.quantile(self._qs, dim=self._axes) 229 return { 230 SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p 231 for q, p in zip(self._qs, ps) 232 } 233 234 235class MeanPercentilesCalculator: 236 """to calculate dataset percentiles heuristically by averaging across samples 237 **note**: the returned dataset percentiles are an estiamte and **not mathematically correct** 238 """ 239 240 def __init__( 241 self, 242 member_id: MemberId, 243 axes: Optional[Sequence[AxisId]], 244 qs: Collection[float], 245 ): 246 super().__init__() 247 assert all(0.0 <= q <= 1.0 for q in qs) 248 self._qs = sorted(set(qs)) 249 self._axes = None if axes is None else tuple(axes) 250 self._member_id = member_id 251 self._n: int = 0 252 self._estimates: Optional[Tensor] = None 253 254 def update(self, sample: Sample): 255 tensor = sample.members[self._member_id] 256 sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( 257 "float64", copy=False 258 ) 259 260 # reduced voxel count 261 n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:])) 262 263 if self._estimates is None: 264 assert self._n == 0 265 self._estimates = sample_estimates 266 else: 267 self._estimates = (self._n * self._estimates + n * sample_estimates) / ( 268 self._n + n 269 ) 270 assert self._estimates.dtype == "float64" 271 272 self._n += n 273 274 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 275 if self._estimates is None: 276 return {} 277 else: 278 warnings.warn( 279 "Computed dataset percentiles naively by averaging percentiles of samples." 
280 ) 281 return { 282 DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e 283 for q, e in zip(self._qs, self._estimates) 284 } 285 286 287class CrickPercentilesCalculator: 288 """to calculate dataset percentiles with the experimental [crick libray](https://github.com/dask/crick)""" 289 290 def __init__( 291 self, 292 member_id: MemberId, 293 axes: Optional[Sequence[AxisId]], 294 qs: Collection[float], 295 ): 296 warnings.warn( 297 "Computing dataset percentiles with experimental 'crick' library." 298 ) 299 super().__init__() 300 assert all(0.0 <= q <= 1.0 for q in qs) 301 assert axes is None or "_percentiles" not in axes 302 self._qs = sorted(set(qs)) 303 self._axes = None if axes is None else tuple(axes) 304 self._member_id = member_id 305 self._digest: Optional[List[TDigest]] = None 306 self._dims: Optional[Tuple[AxisId, ...]] = None 307 self._indices: Optional[Iterator[Tuple[int, ...]]] = None 308 self._shape: Optional[Tuple[int, ...]] = None 309 310 def _initialize(self, tensor_sizes: PerAxis[int]): 311 assert crick is not None 312 out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict( 313 _percentiles=len(self._qs) 314 ) 315 if self._axes is not None: 316 for d, s in tensor_sizes.items(): 317 if d not in self._axes: 318 out_sizes[d] = s 319 320 self._dims, self._shape = zip(*out_sizes.items()) 321 assert self._shape is not None 322 d = int(np.prod(self._shape[1:])) 323 self._digest = [TDigest() for _ in range(d)] 324 self._indices = product(*map(range, self._shape[1:])) 325 326 def update(self, part: Sample): 327 tensor = ( 328 part.members[self._member_id] 329 if isinstance(part, Sample) 330 else part.members[self._member_id].data 331 ) 332 assert "_percentiles" not in tensor.dims 333 if self._digest is None: 334 self._initialize(tensor.tagged_shape) 335 336 assert self._digest is not None 337 assert self._indices is not None 338 assert self._dims is not None 339 for i, idx in enumerate(self._indices): 340 
self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))]) 341 342 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 343 if self._digest is None: 344 return {} 345 else: 346 assert self._dims is not None 347 assert self._shape is not None 348 349 vs: NDArray[Any] = np.asarray( 350 [[d.quantile(q) for d in self._digest] for q in self._qs] 351 ).reshape(self._shape) 352 return { 353 DatasetPercentile( 354 q=q, axes=self._axes, member_id=self._member_id 355 ): Tensor(v, dims=self._dims[1:]) 356 for q, v in zip(self._qs, vs) 357 } 358 359 360if crick is None: 361 DatasetPercentilesCalculator: Type[ 362 Union[MeanPercentilesCalculator, CrickPercentilesCalculator] 363 ] = MeanPercentilesCalculator 364else: 365 DatasetPercentilesCalculator = CrickPercentilesCalculator 366 367 368class NaiveSampleMeasureCalculator: 369 """wrapper for measures to match interface of other sample measure calculators""" 370 371 def __init__(self, member_id: MemberId, measure: SampleMeasure): 372 super().__init__() 373 self.tensor_name = member_id 374 self.measure = measure 375 376 def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: 377 return {self.measure: self.measure.compute(sample)} 378 379 380SampleMeasureCalculator = Union[ 381 MeanCalculator, 382 MeanVarStdCalculator, 383 SamplePercentilesCalculator, 384 NaiveSampleMeasureCalculator, 385] 386DatasetMeasureCalculator = Union[ 387 MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator 388] 389 390 391class StatsCalculator: 392 """Estimates dataset statistics and computes sample statistics efficiently""" 393 394 def __init__( 395 self, 396 measures: Collection[Measure], 397 initial_dataset_measures: Optional[ 398 Mapping[DatasetMeasure, MeasureValue] 399 ] = None, 400 ): 401 super().__init__() 402 self.sample_count = 0 403 self.sample_calculators, self.dataset_calculators = get_measure_calculators( 404 measures 405 ) 406 if not initial_dataset_measures: 407 self._current_dataset_measures: 
Optional[ 408 Dict[DatasetMeasure, MeasureValue] 409 ] = None 410 else: 411 missing_dataset_meas = { 412 m 413 for m in measures 414 if isinstance(m, DatasetMeasureBase) 415 and m not in initial_dataset_measures 416 } 417 if missing_dataset_meas: 418 logger.debug( 419 f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}" 420 ) 421 self._current_dataset_measures = None 422 else: 423 self._current_dataset_measures = dict(initial_dataset_measures) 424 425 @property 426 def has_dataset_measures(self): 427 return self._current_dataset_measures is not None 428 429 def update( 430 self, 431 sample: Union[Sample, Iterable[Sample]], 432 ) -> None: 433 _ = self._update(sample) 434 435 def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: 436 """returns aggregated dataset statistics""" 437 if self._current_dataset_measures is None: 438 self._current_dataset_measures = {} 439 for calc in self.dataset_calculators: 440 values = calc.finalize() 441 self._current_dataset_measures.update(values.items()) 442 443 return self._current_dataset_measures 444 445 def update_and_get_all( 446 self, 447 sample: Union[Sample, Iterable[Sample]], 448 ) -> Dict[Measure, MeasureValue]: 449 """Returns sample as well as updated dataset statistics""" 450 last_sample = self._update(sample) 451 if last_sample is None: 452 raise ValueError("`sample` was not a `Sample`, nor did it yield any.") 453 454 return {**self._compute(last_sample), **self.finalize()} 455 456 def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: 457 """Returns sample as well as previously computed dataset statistics""" 458 return {**self._compute(sample), **self.finalize()} 459 460 def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: 461 ret: Dict[SampleMeasure, MeasureValue] = {} 462 for calc in self.sample_calculators: 463 values = calc.compute(sample) 464 ret.update(values.items()) 465 466 return ret 467 468 def _update(self, sample: Union[Sample, 
Iterable[Sample]]) -> Optional[Sample]: 469 self.sample_count += 1 470 samples = [sample] if isinstance(sample, Sample) else sample 471 last_sample = None 472 for el in samples: 473 last_sample = el 474 for calc in self.dataset_calculators: 475 calc.update(el) 476 477 self._current_dataset_measures = None 478 return last_sample 479 480 481def get_measure_calculators( 482 required_measures: Iterable[Measure], 483) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]: 484 """determines which calculators are needed to compute the required measures efficiently""" 485 486 sample_calculators: List[SampleMeasureCalculator] = [] 487 dataset_calculators: List[DatasetMeasureCalculator] = [] 488 489 # split required measures into groups 490 required_sample_means: Set[SampleMean] = set() 491 required_dataset_means: Set[DatasetMean] = set() 492 required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set() 493 required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = ( 494 set() 495 ) 496 required_sample_percentiles: Dict[ 497 Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float] 498 ] = {} 499 required_dataset_percentiles: Dict[ 500 Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float] 501 ] = {} 502 503 for rm in required_measures: 504 if isinstance(rm, SampleMean): 505 required_sample_means.add(rm) 506 elif isinstance(rm, DatasetMean): 507 required_dataset_means.add(rm) 508 elif isinstance(rm, (SampleVar, SampleStd)): 509 required_sample_mean_var_std.update( 510 { 511 msv(axes=rm.axes, member_id=rm.member_id) 512 for msv in (SampleMean, SampleStd, SampleVar) 513 } 514 ) 515 assert rm in required_sample_mean_var_std 516 elif isinstance(rm, (DatasetVar, DatasetStd)): 517 required_dataset_mean_var_std.update( 518 { 519 msv(axes=rm.axes, member_id=rm.member_id) 520 for msv in (DatasetMean, DatasetStd, DatasetVar) 521 } 522 ) 523 assert rm in required_dataset_mean_var_std 524 elif isinstance(rm, 
SampleQuantile): 525 required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add( 526 rm.q 527 ) 528 elif isinstance(rm, DatasetPercentile): 529 required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add( 530 rm.q 531 ) 532 else: 533 assert_never(rm) 534 535 for rm in required_sample_means: 536 if rm in required_sample_mean_var_std: 537 # computed togehter with var and std 538 continue 539 540 sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes)) 541 542 for rm in required_sample_mean_var_std: 543 sample_calculators.append( 544 MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes) 545 ) 546 547 for rm in required_dataset_means: 548 if rm in required_dataset_mean_var_std: 549 # computed togehter with var and std 550 continue 551 552 dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes)) 553 554 for rm in required_dataset_mean_var_std: 555 dataset_calculators.append( 556 MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes) 557 ) 558 559 for (tid, axes), qs in required_sample_percentiles.items(): 560 sample_calculators.append( 561 SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs) 562 ) 563 564 for (tid, axes), qs in required_dataset_percentiles.items(): 565 dataset_calculators.append( 566 DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs) 567 ) 568 569 return sample_calculators, dataset_calculators 570 571 572def compute_dataset_measures( 573 measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] 574) -> Dict[DatasetMeasure, MeasureValue]: 575 """compute all dataset `measures` for the given `dataset`""" 576 sample_calculators, calculators = get_measure_calculators(measures) 577 assert not sample_calculators 578 579 ret: Dict[DatasetMeasure, MeasureValue] = {} 580 581 for sample in dataset: 582 for calc in calculators: 583 calc.update(sample) 584 585 for calc in calculators: 586 ret.update(calc.finalize().items()) 587 588 return ret 589 590 591def 
compute_sample_measures( 592 measures: Iterable[SampleMeasure], sample: Sample 593) -> Dict[SampleMeasure, MeasureValue]: 594 """compute all sample `measures` for the given `sample`""" 595 calculators, dataset_calculators = get_measure_calculators(measures) 596 assert not dataset_calculators 597 ret: Dict[SampleMeasure, MeasureValue] = {} 598 599 for calc in calculators: 600 ret.update(calc.compute(sample).items()) 601 602 return ret 603 604 605def compute_measures( 606 measures: Iterable[Measure], dataset: Iterable[Sample] 607) -> Dict[Measure, MeasureValue]: 608 """compute all `measures` for the given `dataset` 609 sample measures are computed for the last sample in `dataset`""" 610 sample_calculators, dataset_calculators = get_measure_calculators(measures) 611 ret: Dict[Measure, MeasureValue] = {} 612 sample = None 613 for sample in dataset: 614 for calc in dataset_calculators: 615 calc.update(sample) 616 if sample is None: 617 raise ValueError("empty dataset") 618 619 for calc in dataset_calculators: 620 ret.update(calc.finalize().items()) 621 622 for calc in sample_calculators: 623 ret.update(calc.compute(sample).items()) 624 625 return ret
68class MeanCalculator: 69 """to calculate sample and dataset mean for in-memory samples""" 70 71 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 72 super().__init__() 73 self._n: int = 0 74 self._mean: Optional[Tensor] = None 75 self._axes = None if axes is None else tuple(axes) 76 self._member_id = member_id 77 self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes) 78 self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes) 79 80 def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: 81 return {self._sample_mean: self._compute_impl(sample)} 82 83 def _compute_impl(self, sample: Sample) -> Tensor: 84 tensor = sample.members[self._member_id].astype("float64", copy=False) 85 return tensor.mean(dim=self._axes) 86 87 def update(self, sample: Sample) -> None: 88 mean = self._compute_impl(sample) 89 self._update_impl(sample.members[self._member_id], mean) 90 91 def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: 92 mean = self._compute_impl(sample) 93 self._update_impl(sample.members[self._member_id], mean) 94 return {self._sample_mean: mean} 95 96 def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): 97 assert tensor_mean.dtype == "float64" 98 # reduced voxel count 99 n_b = int(tensor.size / tensor_mean.size) 100 101 if self._mean is None: 102 assert self._n == 0 103 self._n = n_b 104 self._mean = tensor_mean 105 else: 106 assert self._n != 0 107 n_a = self._n 108 mean_old = self._mean 109 self._n = n_a + n_b 110 self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n 111 assert self._mean.dtype == "float64" 112 113 def finalize(self) -> Dict[DatasetMean, MeasureValue]: 114 if self._mean is None: 115 return {} 116 else: 117 return {self._dataset_mean: self._mean}
to calculate sample and dataset mean for in-memory samples
71 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 72 super().__init__() 73 self._n: int = 0 74 self._mean: Optional[Tensor] = None 75 self._axes = None if axes is None else tuple(axes) 76 self._member_id = member_id 77 self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes) 78 self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
120class MeanVarStdCalculator: 121 """to calculate sample and dataset mean, variance or standard deviation""" 122 123 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 124 super().__init__() 125 self._axes = None if axes is None else tuple(map(AxisId, axes)) 126 self._member_id = member_id 127 self._n: int = 0 128 self._mean: Optional[Tensor] = None 129 self._m2: Optional[Tensor] = None 130 131 def compute( 132 self, sample: Sample 133 ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: 134 tensor = sample.members[self._member_id] 135 mean = tensor.mean(dim=self._axes) 136 c = (tensor - mean).data 137 if self._axes is None: 138 n = tensor.size 139 else: 140 n = int(np.prod([tensor.sizes[d] for d in self._axes])) 141 142 if xr.__version__.startswith("2023"): 143 var = xr.dot(c, c, dims=self._axes) / n 144 else: 145 var = xr.dot(c, c, dim=self._axes) / n 146 147 assert isinstance(var, xr.DataArray) 148 std = np.sqrt(var) 149 assert isinstance(std, xr.DataArray) 150 return { 151 SampleMean(axes=self._axes, member_id=self._member_id): mean, 152 SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 153 var 154 ), 155 SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 156 std 157 ), 158 } 159 160 def update(self, sample: Sample): 161 if self._axes is not None and BATCH_AXIS_ID not in self._axes: 162 return 163 164 tensor = sample.members[self._member_id].astype("float64", copy=False) 165 mean_b = tensor.mean(dim=self._axes) 166 assert mean_b.dtype == "float64" 167 # reduced voxel count 168 n_b = int(tensor.size / mean_b.size) 169 m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) 170 assert m2_b.dtype == "float64" 171 if self._mean is None: 172 assert self._m2 is None 173 self._n = n_b 174 self._mean = mean_b 175 self._m2 = m2_b 176 else: 177 n_a = self._n 178 mean_a = self._mean 179 m2_a = self._m2 180 self._n = n = n_a + n_b 181 self._mean = (n_a * mean_a + n_b * mean_b) / n 182 assert 
self._mean.dtype == "float64" 183 d = mean_b - mean_a 184 self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n 185 assert self._m2.dtype == "float64" 186 187 def finalize( 188 self, 189 ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: 190 if ( 191 self._axes is not None 192 and BATCH_AXIS_ID not in self._axes 193 or self._mean is None 194 ): 195 return {} 196 else: 197 assert self._m2 is not None 198 var = self._m2 / self._n 199 sqrt = var**0.5 200 if isinstance(sqrt, (int, float)): 201 # var and mean are scalar tensors, let's keep it consistent 202 sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) 203 204 assert isinstance(sqrt, Tensor), type(sqrt) 205 return { 206 DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, 207 DatasetVar(member_id=self._member_id, axes=self._axes): var, 208 DatasetStd(member_id=self._member_id, axes=self._axes): sqrt, 209 }
to calculate sample and dataset mean, variance or standard deviation
123 def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): 124 super().__init__() 125 self._axes = None if axes is None else tuple(map(AxisId, axes)) 126 self._member_id = member_id 127 self._n: int = 0 128 self._mean: Optional[Tensor] = None 129 self._m2: Optional[Tensor] = None
131 def compute( 132 self, sample: Sample 133 ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: 134 tensor = sample.members[self._member_id] 135 mean = tensor.mean(dim=self._axes) 136 c = (tensor - mean).data 137 if self._axes is None: 138 n = tensor.size 139 else: 140 n = int(np.prod([tensor.sizes[d] for d in self._axes])) 141 142 if xr.__version__.startswith("2023"): 143 var = xr.dot(c, c, dims=self._axes) / n 144 else: 145 var = xr.dot(c, c, dim=self._axes) / n 146 147 assert isinstance(var, xr.DataArray) 148 std = np.sqrt(var) 149 assert isinstance(std, xr.DataArray) 150 return { 151 SampleMean(axes=self._axes, member_id=self._member_id): mean, 152 SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 153 var 154 ), 155 SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( 156 std 157 ), 158 }
160 def update(self, sample: Sample): 161 if self._axes is not None and BATCH_AXIS_ID not in self._axes: 162 return 163 164 tensor = sample.members[self._member_id].astype("float64", copy=False) 165 mean_b = tensor.mean(dim=self._axes) 166 assert mean_b.dtype == "float64" 167 # reduced voxel count 168 n_b = int(tensor.size / mean_b.size) 169 m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) 170 assert m2_b.dtype == "float64" 171 if self._mean is None: 172 assert self._m2 is None 173 self._n = n_b 174 self._mean = mean_b 175 self._m2 = m2_b 176 else: 177 n_a = self._n 178 mean_a = self._mean 179 m2_a = self._m2 180 self._n = n = n_a + n_b 181 self._mean = (n_a * mean_a + n_b * mean_b) / n 182 assert self._mean.dtype == "float64" 183 d = mean_b - mean_a 184 self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n 185 assert self._m2.dtype == "float64"
187 def finalize( 188 self, 189 ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: 190 if ( 191 self._axes is not None 192 and BATCH_AXIS_ID not in self._axes 193 or self._mean is None 194 ): 195 return {} 196 else: 197 assert self._m2 is not None 198 var = self._m2 / self._n 199 sqrt = var**0.5 200 if isinstance(sqrt, (int, float)): 201 # var and mean are scalar tensors, let's keep it consistent 202 sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) 203 204 assert isinstance(sqrt, Tensor), type(sqrt) 205 return { 206 DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, 207 DatasetVar(member_id=self._member_id, axes=self._axes): var, 208 DatasetStd(member_id=self._member_id, axes=self._axes): sqrt, 209 }
212class SamplePercentilesCalculator: 213 """to calculate sample percentiles""" 214 215 def __init__( 216 self, 217 member_id: MemberId, 218 axes: Optional[Sequence[AxisId]], 219 qs: Collection[float], 220 ): 221 super().__init__() 222 assert all(0.0 <= q <= 1.0 for q in qs) 223 self._qs = sorted(set(qs)) 224 self._axes = None if axes is None else tuple(axes) 225 self._member_id = member_id 226 227 def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]: 228 tensor = sample.members[self._member_id] 229 ps = tensor.quantile(self._qs, dim=self._axes) 230 return { 231 SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p 232 for q, p in zip(self._qs, ps) 233 }
to calculate sample percentiles
215 def __init__( 216 self, 217 member_id: MemberId, 218 axes: Optional[Sequence[AxisId]], 219 qs: Collection[float], 220 ): 221 super().__init__() 222 assert all(0.0 <= q <= 1.0 for q in qs) 223 self._qs = sorted(set(qs)) 224 self._axes = None if axes is None else tuple(axes) 225 self._member_id = member_id
236class MeanPercentilesCalculator: 237 """to calculate dataset percentiles heuristically by averaging across samples 238 **note**: the returned dataset percentiles are an estiamte and **not mathematically correct** 239 """ 240 241 def __init__( 242 self, 243 member_id: MemberId, 244 axes: Optional[Sequence[AxisId]], 245 qs: Collection[float], 246 ): 247 super().__init__() 248 assert all(0.0 <= q <= 1.0 for q in qs) 249 self._qs = sorted(set(qs)) 250 self._axes = None if axes is None else tuple(axes) 251 self._member_id = member_id 252 self._n: int = 0 253 self._estimates: Optional[Tensor] = None 254 255 def update(self, sample: Sample): 256 tensor = sample.members[self._member_id] 257 sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( 258 "float64", copy=False 259 ) 260 261 # reduced voxel count 262 n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:])) 263 264 if self._estimates is None: 265 assert self._n == 0 266 self._estimates = sample_estimates 267 else: 268 self._estimates = (self._n * self._estimates + n * sample_estimates) / ( 269 self._n + n 270 ) 271 assert self._estimates.dtype == "float64" 272 273 self._n += n 274 275 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 276 if self._estimates is None: 277 return {} 278 else: 279 warnings.warn( 280 "Computed dataset percentiles naively by averaging percentiles of samples." 281 ) 282 return { 283 DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e 284 for q, e in zip(self._qs, self._estimates) 285 }
to calculate dataset percentiles heuristically by averaging across samples note: the returned dataset percentiles are an estimate and not mathematically correct
241 def __init__( 242 self, 243 member_id: MemberId, 244 axes: Optional[Sequence[AxisId]], 245 qs: Collection[float], 246 ): 247 super().__init__() 248 assert all(0.0 <= q <= 1.0 for q in qs) 249 self._qs = sorted(set(qs)) 250 self._axes = None if axes is None else tuple(axes) 251 self._member_id = member_id 252 self._n: int = 0 253 self._estimates: Optional[Tensor] = None
255 def update(self, sample: Sample): 256 tensor = sample.members[self._member_id] 257 sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( 258 "float64", copy=False 259 ) 260 261 # reduced voxel count 262 n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:])) 263 264 if self._estimates is None: 265 assert self._n == 0 266 self._estimates = sample_estimates 267 else: 268 self._estimates = (self._n * self._estimates + n * sample_estimates) / ( 269 self._n + n 270 ) 271 assert self._estimates.dtype == "float64" 272 273 self._n += n
275 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 276 if self._estimates is None: 277 return {} 278 else: 279 warnings.warn( 280 "Computed dataset percentiles naively by averaging percentiles of samples." 281 ) 282 return { 283 DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e 284 for q, e in zip(self._qs, self._estimates) 285 }
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Keeps one t-digest per element of the non-aggregated axes; each incoming
    sample is folded into these digests and `finalize` queries the requested
    quantiles from every digest.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Args:
            member_id: tensor member to compute percentiles for.
            axes: axes to aggregate over; `None` aggregates over all axes.
            qs: quantiles in [0.0, 1.0] to estimate.
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # the following state is created lazily in `_initialize`, once the
        # first sample reveals the tensor's shape
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """Create one t-digest per output element based on `tensor_sizes`."""
        assert crick is not None
        # artificial leading "_percentiles" axis holds one entry per quantile
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]

    def update(self, part: Sample):
        """Feed the member tensor of `part` into the per-element digests."""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._dims is not None
        assert self._shape is not None
        # BUG FIX: the index iterator used to be created once in `_initialize`;
        # being a one-shot `itertools.product` iterator it was exhausted after
        # the first sample, so all following samples were silently ignored.
        # Recreate the indices for every update instead.
        self._indices = product(*map(range, self._shape[1:]))
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return the estimated dataset percentiles (empty if no data seen)."""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
to calculate dataset percentiles with the experimental crick library
291 def __init__( 292 self, 293 member_id: MemberId, 294 axes: Optional[Sequence[AxisId]], 295 qs: Collection[float], 296 ): 297 warnings.warn( 298 "Computing dataset percentiles with experimental 'crick' library." 299 ) 300 super().__init__() 301 assert all(0.0 <= q <= 1.0 for q in qs) 302 assert axes is None or "_percentiles" not in axes 303 self._qs = sorted(set(qs)) 304 self._axes = None if axes is None else tuple(axes) 305 self._member_id = member_id 306 self._digest: Optional[List[TDigest]] = None 307 self._dims: Optional[Tuple[AxisId, ...]] = None 308 self._indices: Optional[Iterator[Tuple[int, ...]]] = None 309 self._shape: Optional[Tuple[int, ...]] = None
327 def update(self, part: Sample): 328 tensor = ( 329 part.members[self._member_id] 330 if isinstance(part, Sample) 331 else part.members[self._member_id].data 332 ) 333 assert "_percentiles" not in tensor.dims 334 if self._digest is None: 335 self._initialize(tensor.tagged_shape) 336 337 assert self._digest is not None 338 assert self._indices is not None 339 assert self._dims is not None 340 for i, idx in enumerate(self._indices): 341 self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
343 def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: 344 if self._digest is None: 345 return {} 346 else: 347 assert self._dims is not None 348 assert self._shape is not None 349 350 vs: NDArray[Any] = np.asarray( 351 [[d.quantile(q) for d in self._digest] for q in self._qs] 352 ).reshape(self._shape) 353 return { 354 DatasetPercentile( 355 q=q, axes=self._axes, member_id=self._member_id 356 ): Tensor(v, dims=self._dims[1:]) 357 for q, v in zip(self._qs, vs) 358 }
class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        # NOTE: attribute is named `tensor_name` (not `member_id`) — kept
        # unchanged because it is part of the public interface.
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Delegate to the wrapped measure's own `compute` method."""
        value = self.measure.compute(sample)
        return {self.measure: value}
wrapper for measures to match interface of other sample measure calculators
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """Args:
            measures: measures to compute (sample) and/or estimate (dataset).
            initial_dataset_measures: optional previously computed dataset
                measures; only used if they cover *all* required dataset
                measures, otherwise they are ignored with a debug log entry.
        """
        super().__init__()
        self.sample_count = 0  # counts `_update` calls, not individual samples
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                # BUG FIX: message used to say `initial_dataset_measure`,
                # but the argument is named `initial_dataset_measures`
                logger.debug(
                    f"ignoring `initial_dataset_measures` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        """whether aggregated dataset statistics are currently available"""
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update dataset statistics with `sample` (or each sample it yields)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # (re)aggregate and cache; the cache is invalidated by `_update`
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """compute all sample measures for `sample`"""
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        """feed `sample` to all dataset calculators; returns the last sample seen"""
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # previously aggregated dataset statistics are now outdated
        self._current_dataset_measures = None
        return last_sample
Estimates dataset statistics and computes sample statistics efficiently
395 def __init__( 396 self, 397 measures: Collection[Measure], 398 initial_dataset_measures: Optional[ 399 Mapping[DatasetMeasure, MeasureValue] 400 ] = None, 401 ): 402 super().__init__() 403 self.sample_count = 0 404 self.sample_calculators, self.dataset_calculators = get_measure_calculators( 405 measures 406 ) 407 if not initial_dataset_measures: 408 self._current_dataset_measures: Optional[ 409 Dict[DatasetMeasure, MeasureValue] 410 ] = None 411 else: 412 missing_dataset_meas = { 413 m 414 for m in measures 415 if isinstance(m, DatasetMeasureBase) 416 and m not in initial_dataset_measures 417 } 418 if missing_dataset_meas: 419 logger.debug( 420 f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}" 421 ) 422 self._current_dataset_measures = None 423 else: 424 self._current_dataset_measures = dict(initial_dataset_measures)
436 def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: 437 """returns aggregated dataset statistics""" 438 if self._current_dataset_measures is None: 439 self._current_dataset_measures = {} 440 for calc in self.dataset_calculators: 441 values = calc.finalize() 442 self._current_dataset_measures.update(values.items()) 443 444 return self._current_dataset_measures
returns aggregated dataset statistics
446 def update_and_get_all( 447 self, 448 sample: Union[Sample, Iterable[Sample]], 449 ) -> Dict[Measure, MeasureValue]: 450 """Returns sample as well as updated dataset statistics""" 451 last_sample = self._update(sample) 452 if last_sample is None: 453 raise ValueError("`sample` was not a `Sample`, nor did it yield any.") 454 455 return {**self._compute(last_sample), **self.finalize()}
Returns sample as well as updated dataset statistics
457 def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: 458 """Returns sample as well as previously computed dataset statistics""" 459 return {**self._compute(sample), **self.finalize()}
Returns sample as well as previously computed dataset statistics
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # group the required measures so shared computations are only done once
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # var/std requests imply computing mean, var and std together
            required_sample_mean_var_std.update(
                msv(axes=rm.axes, member_id=rm.member_id)
                for msv in (SampleMean, SampleStd, SampleVar)
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                msv(axes=rm.axes, member_id=rm.member_id)
                for msv in (DatasetMean, DatasetStd, DatasetVar)
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            key = (rm.member_id, rm.axes)
            required_sample_percentiles.setdefault(key, set()).add(rm.q)
        elif isinstance(rm, DatasetPercentile):
            key = (rm.member_id, rm.axes)
            required_dataset_percentiles.setdefault(key, set()).add(rm.q)
        else:
            assert_never(rm)

    for rm in required_sample_means:
        # skip means that are computed together with var and std anyway
        if rm not in required_sample_mean_var_std:
            sample_calculators.append(
                MeanCalculator(member_id=rm.member_id, axes=rm.axes)
            )

    sample_calculators.extend(
        MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        for rm in required_sample_mean_var_std
    )

    for rm in required_dataset_means:
        # skip means that are computed together with var and std anyway
        if rm not in required_dataset_mean_var_std:
            dataset_calculators.append(
                MeanCalculator(member_id=rm.member_id, axes=rm.axes)
            )

    dataset_calculators.extend(
        MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        for rm in required_dataset_mean_var_std
    )

    sample_calculators.extend(
        SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        for (tid, axes), qs in required_sample_percentiles.items()
    )

    dataset_calculators.extend(
        DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        for (tid, axes), qs in required_dataset_percentiles.items()
    )

    return sample_calculators, dataset_calculators
determines which calculators are needed to compute the required measures efficiently
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    assert not sample_calculators

    # stream all samples through every dataset calculator
    for sample in dataset:
        for calc in calculators:
            calc.update(sample)

    ret: Dict[DatasetMeasure, MeasureValue] = {}
    for calc in calculators:
        ret.update(calc.finalize())

    return ret
compute all dataset measures for the given dataset
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    calculators, dataset_calculators = get_measure_calculators(measures)
    assert not dataset_calculators

    ret: Dict[SampleMeasure, MeasureValue] = {}
    for calc in calculators:
        ret.update(calc.compute(sample))

    return ret
compute all sample measures for the given sample
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`

    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    ret: Dict[Measure, MeasureValue] = {}

    last_sample: Optional[Sample] = None
    for s in dataset:
        last_sample = s
        for calc in dataset_calculators:
            calc.update(s)
    if last_sample is None:
        raise ValueError("empty dataset")

    for calc in dataset_calculators:
        ret.update(calc.finalize())

    for calc in sample_calculators:
        ret.update(calc.compute(last_sample))

    return ret
compute all measures for the given dataset
sample measures are computed for the last sample in dataset