bioimageio.core.stat_calculators
"""Calculators to compute sample statistics and to estimate dataset statistics
incrementally over an iterable dataset of samples."""

from __future__ import annotations

import collections
import warnings
from itertools import product
from typing import (
    Any,
    Collection,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    OrderedDict,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import assert_never

from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID

from .axis import AxisId, PerAxis
from .common import MemberId
from .sample import Sample
from .stat_measures import (
    DatasetMean,
    DatasetMeasure,
    DatasetMeasureBase,
    DatasetPercentile,
    DatasetStd,
    DatasetVar,
    Measure,
    MeasureValue,
    SampleMean,
    SampleMeasure,
    SampleQuantile,
    SampleStd,
    SampleVar,
)
from .tensor import Tensor

try:
    import crick  # pyright: ignore[reportMissingImports]

except Exception:
    crick = None

    class TDigest:
        """No-op stand-in keeping annotations resolvable when `crick` is unavailable."""

        def update(self, obj: Any):
            pass

        def quantile(self, q: Any) -> Any:
            pass

else:
    TDigest = crick.TDigest  # type: ignore


class MeanCalculator:
    """to calculate sample and dataset mean for in-memory samples"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """Args:
        member_id: tensor member to compute the mean for.
        axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute the sample mean without updating the running dataset mean."""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        return tensor.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """Fold `sample` into the running dataset mean."""
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute the sample mean and fold it into the running dataset mean."""
        mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], mean)
        return {self._sample_mean: mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # reduced voxel count: number of elements folded into each mean entry
        n_b = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = n_b
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_a = self._n
            mean_old = self._mean
            self._n = n_a + n_b
            # weighted running mean
            self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """Return the aggregated dataset mean; empty if no sample was seen."""
        if self._mean is None:
            return {}
        else:
            return {self._dataset_mean: self._mean}


class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """Args:
        member_id: tensor member to compute statistics for.
        axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """Return mean, variance and standard deviation of this sample."""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # xarray renamed `xr.dot`'s `dims` argument to `dim` after the 2023 releases
        if xr.__version__.startswith("2023"):
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dims=self._axes) / n
            )
        else:
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dim=self._axes) / n
            )

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """Fold `sample` into the running dataset statistics.

        Uses the parallel (pairwise) variance update; no-op unless the batch
        axis is among the reduced axes (or all axes are reduced).
        """
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """Return aggregated dataset mean/var/std; empty if not applicable."""
        # parenthesized for clarity; same precedence as the bare and/or chain
        if (
            self._axes is not None and BATCH_AXIS_ID not in self._axes
        ) or self._mean is None:
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = var**0.5
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }


class SamplePercentilesCalculator:
    """to calculate sample percentiles"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Args:
        member_id: tensor member to compute quantiles for.
        axes: axes to reduce over; `None` reduces over all axes.
        qs: quantiles in [0, 1]; duplicates are dropped, order normalized.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one measure value per requested quantile."""
        tensor = sample.members[self._member_id]
        ps = tensor.quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p
            for q, p in zip(self._qs, ps)
        }


class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Args:
        member_id: tensor member to estimate quantiles for.
        axes: axes to reduce over; `None` reduces over all axes.
        qs: quantiles in [0, 1]; duplicates are dropped, order normalized.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """Fold this sample's quantiles into the running weighted average."""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # reduced voxel count (axis 0 of the estimates enumerates quantiles)
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return estimated dataset percentiles; empty if no sample was seen."""
        if self._estimates is None:
            return {}
        else:
            warnings.warn(
                "Computed dataset percentiles naively by averaging percentiles of samples."
            )
            return {
                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
                for q, e in zip(self._qs, self._estimates)
            }


class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Args:
        member_id: tensor member to compute quantiles for.
        axes: axes to reduce over; `None` reduces over all axes.
        qs: quantiles in [0, 1]; duplicates are dropped, order normalized.
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """Lazily create one t-digest per non-reduced tensor position."""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # materialize the index tuples: a one-shot `product(...)` iterator
        # would be exhausted by the first `update` call, silently skipping
        # digest updates for every subsequent sample
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """Stream `part` into the per-position digests."""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Query the digests and return dataset percentile tensors."""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }


if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator


class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Delegate to the wrapped measure's own `compute`."""
        return {self.measure: self.measure.compute(sample)}


SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]


class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """Args:
        measures: sample and dataset measures to compute/estimate.
        initial_dataset_measures: previously computed dataset measures;
            only used if they cover all required dataset measures.
        """
        super().__init__()
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                # message matches the actual keyword argument name
                logger.debug(
                    f"ignoring `initial_dataset_measures` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        """Whether finalized (or provided) dataset measures are available."""
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """Update the running dataset statistics with one or more samples."""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            # count per sample: an iterable of k samples adds k, not 1
            self.sample_count += 1
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate cached dataset measures; recomputed on next `finalize`
        self._current_dataset_measures = None
        return last_sample


def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators


def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    assert not sample_calculators

    ret: Dict[DatasetMeasure, MeasureValue] = {}

    for sample in dataset:
        for calc in calculators:
            calc.update(sample)

    for calc in calculators:
        ret.update(calc.finalize().items())

    return ret


def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    calculators, dataset_calculators = get_measure_calculators(measures)
    assert not dataset_calculators
    ret: Dict[SampleMeasure, MeasureValue] = {}

    for calc in calculators:
        ret.update(calc.compute(sample).items())

    return ret


def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    ret: Dict[Measure, MeasureValue] = {}
    sample = None
    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)
    if sample is None:
        raise ValueError("empty dataset")

    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    for calc in sample_calculators:
        ret.update(calc.compute(sample).items())

    return ret
class MeanCalculator:
    """Running mean of one tensor member, per sample and across the dataset."""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """Remember the member/axes and prepare the measure keys to fill."""
        super().__init__()
        self._member_id = member_id
        self._axes = tuple(axes) if axes is not None else None
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute this sample's mean; the running dataset mean is untouched."""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        data = sample.members[self._member_id].astype("float64", copy=False)
        return data.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """Fold `sample` into the running dataset mean."""
        self._update_impl(sample.members[self._member_id], self._compute_impl(sample))

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute the sample mean and fold it into the running dataset mean."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels that were reduced into each mean entry
        count = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = count
            self._mean = tensor_mean
        else:
            assert self._n != 0
            prev_n = self._n
            prev_mean = self._mean
            self._n = prev_n + count
            self._mean = (prev_n * prev_mean + count * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """Return the aggregated dataset mean; empty if nothing was seen."""
        if self._mean is None:
            return {}
        return {self._dataset_mean: self._mean}
to calculate sample and dataset mean for in-memory samples
def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
    """Remember the member/axes and prepare the sample/dataset measure keys."""
    super().__init__()
    self._member_id = member_id
    self._axes = tuple(axes) if axes is not None else None
    self._n: int = 0
    self._mean: Optional[Tensor] = None
    self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
    self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
class MeanVarStdCalculator:
    """Mean, variance and standard deviation of one tensor member, computed
    per sample and aggregated across the dataset via a pairwise update."""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """Record member id / axes and reset the running accumulators."""
        super().__init__()
        self._member_id = member_id
        self._axes = tuple(map(AxisId, axes)) if axes is not None else None
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """Return mean/var/std of this sample over the configured axes."""
        data = sample.members[self._member_id]
        mean = data.mean(dim=self._axes)
        centered = (data - mean).data
        if self._axes is None:
            count = data.size
        else:
            count = int(np.prod([data.sizes[a] for a in self._axes]))

        # `xr.dot` takes `dims` in the 2023 releases and `dim` afterwards
        if xr.__version__.startswith("2023"):
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(centered, centered, dims=self._axes) / count
            )
        else:
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(centered, centered, dim=self._axes) / count
            )

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """Fold `sample` into the running dataset statistics; no-op unless the
        batch axis is reduced (or all axes are reduced)."""
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        data = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = data.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # voxels reduced into each entry
        n_b = int(data.size / mean_b.size)
        m2_b = ((data - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a, mean_a, m2_a = self._n, self._mean, self._m2
            self._n = total = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / total
            assert self._mean.dtype == "float64"
            delta = mean_b - mean_a
            self._m2 = m2_a + m2_b + delta**2 * n_a * n_b / total
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """Return aggregated dataset mean/var/std, or {} when not applicable."""
        if self._mean is None or (
            self._axes is not None and BATCH_AXIS_ID not in self._axes
        ):
            return {}

        assert self._m2 is not None
        var = self._m2 / self._n
        std = var**0.5
        if isinstance(std, (int, float)):
            # var and mean are scalar tensors, let's keep it consistent
            std = Tensor.from_xarray(xr.DataArray(std))

        assert isinstance(std, Tensor), type(std)
        return {
            DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
            DatasetVar(member_id=self._member_id, axes=self._axes): var,
            DatasetStd(member_id=self._member_id, axes=self._axes): std,
        }
to calculate sample and dataset mean, variance or standard deviation
def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
    """Record member id / axes and reset the running mean/M2 accumulators."""
    super().__init__()
    self._member_id = member_id
    self._axes = tuple(map(AxisId, axes)) if axes is not None else None
    self._n: int = 0
    self._mean: Optional[Tensor] = None
    self._m2: Optional[Tensor] = None
def compute(
    self, sample: Sample
) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
    """Return mean/var/std of this sample over the configured axes."""
    data = sample.members[self._member_id]
    mean = data.mean(dim=self._axes)
    centered = (data - mean).data
    if self._axes is None:
        count = data.size
    else:
        count = int(np.prod([data.sizes[a] for a in self._axes]))

    # `xr.dot` takes `dims` in the 2023 releases and `dim` afterwards
    if xr.__version__.startswith("2023"):
        var = (  # pyright: ignore[reportUnknownVariableType]
            xr.dot(centered, centered, dims=self._axes) / count
        )
    else:
        var = (  # pyright: ignore[reportUnknownVariableType]
            xr.dot(centered, centered, dim=self._axes) / count
        )

    assert isinstance(var, xr.DataArray)
    std = np.sqrt(var)
    assert isinstance(std, xr.DataArray)
    return {
        SampleMean(axes=self._axes, member_id=self._member_id): mean,
        SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(var),
        SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(std),
    }
def update(self, sample: Sample):
    """Fold `sample` into the running dataset mean/M2 (pairwise update);
    no-op unless the batch axis is among the reduced axes."""
    if self._axes is not None and BATCH_AXIS_ID not in self._axes:
        return

    data = sample.members[self._member_id].astype("float64", copy=False)
    mean_b = data.mean(dim=self._axes)
    assert mean_b.dtype == "float64"
    # voxels reduced into each entry
    n_b = int(data.size / mean_b.size)
    m2_b = ((data - mean_b) ** 2).sum(dim=self._axes)
    assert m2_b.dtype == "float64"
    if self._mean is None:
        assert self._m2 is None
        self._n = n_b
        self._mean = mean_b
        self._m2 = m2_b
    else:
        n_a, mean_a, m2_a = self._n, self._mean, self._m2
        self._n = total = n_a + n_b
        self._mean = (n_a * mean_a + n_b * mean_b) / total
        assert self._mean.dtype == "float64"
        delta = mean_b - mean_a
        self._m2 = m2_a + m2_b + delta**2 * n_a * n_b / total
        assert self._m2.dtype == "float64"
def finalize(
    self,
) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
    """Return aggregated dataset mean/var/std, or {} when no data was seen
    or the batch axis is not among the reduced axes."""
    if self._mean is None or (
        self._axes is not None and BATCH_AXIS_ID not in self._axes
    ):
        return {}

    assert self._m2 is not None
    var = self._m2 / self._n
    std = var**0.5
    if isinstance(std, (int, float)):
        # var and mean are scalar tensors, let's keep it consistent
        std = Tensor.from_xarray(xr.DataArray(std))

    assert isinstance(std, Tensor), type(std)
    return {
        DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
        DatasetVar(member_id=self._member_id, axes=self._axes): var,
        DatasetStd(member_id=self._member_id, axes=self._axes): std,
    }
class SamplePercentilesCalculator:
    """Compute a fixed set of quantiles of one tensor member per sample."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Validate and normalize the requested quantiles (deduplicated, sorted)."""
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one measure entry per requested quantile."""
        data = sample.members[self._member_id]
        values = data.quantile(self._qs, dim=self._axes)
        return {
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): v
            for q, v in zip(self._qs, values)
        }
to calculate sample percentiles
def __init__(
    self,
    member_id: MemberId,
    axes: Optional[Sequence[AxisId]],
    qs: Collection[float],
):
    """Validate quantiles are in [0, 1] and store them deduplicated and sorted."""
    super().__init__()
    assert all(0.0 <= q <= 1.0 for q in qs)
    self._qs = sorted(set(qs))
    self._axes = tuple(axes) if axes is not None else None
    self._member_id = member_id
class MeanPercentilesCalculator:
    """Estimate dataset percentiles by a voxel-count-weighted average of
    per-sample percentiles — a heuristic, not a mathematically exact result."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Validate/normalize quantiles and reset the running estimate."""
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """Fold this sample's quantiles into the running weighted average."""
        data = sample.members[self._member_id]
        per_sample = data.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # voxels reduced per entry (axis 0 of `per_sample` enumerates quantiles)
        weight = int(data.size / np.prod(per_sample.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = per_sample
        else:
            self._estimates = (self._n * self._estimates + weight * per_sample) / (
                self._n + weight
            )
            assert self._estimates.dtype == "float64"

        self._n += weight

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return the estimated dataset percentiles; {} if nothing was seen."""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
            for q, e in zip(self._qs, self._estimates)
        }
to calculate dataset percentiles heuristically by averaging across samples. Note: the returned dataset percentiles are an estimate and not mathematically correct
def __init__(
    self,
    member_id: MemberId,
    axes: Optional[Sequence[AxisId]],
    qs: Collection[float],
):
    """Validate/normalize quantiles and reset the running estimate state."""
    super().__init__()
    assert all(0.0 <= q <= 1.0 for q in qs)
    self._qs = sorted(set(qs))
    self._axes = tuple(axes) if axes is not None else None
    self._member_id = member_id
    self._n: int = 0
    self._estimates: Optional[Tensor] = None
def update(self, sample: Sample):
    """Fold this sample's quantiles into the running weighted average."""
    data = sample.members[self._member_id]
    per_sample = data.quantile(self._qs, dim=self._axes).astype(
        "float64", copy=False
    )

    # voxels reduced per entry (axis 0 of `per_sample` enumerates quantiles)
    weight = int(data.size / np.prod(per_sample.shape_tuple[1:]))

    if self._estimates is None:
        assert self._n == 0
        self._estimates = per_sample
    else:
        self._estimates = (self._n * self._estimates + weight * per_sample) / (
            self._n + weight
        )
        assert self._estimates.dtype == "float64"

    self._n += weight
def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
    """Return the estimated dataset percentiles; {} if nothing was seen."""
    if self._estimates is None:
        return {}

    warnings.warn(
        "Computed dataset percentiles naively by averaging percentiles of samples."
    )
    return {
        DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
        for q, e in zip(self._qs, self._estimates)
    }
class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """Args:
        member_id: tensor member to compute quantiles for.
        axes: axes to reduce over; `None` reduces over all axes.
        qs: quantiles in [0, 1]; duplicates are dropped, order normalized.
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """Lazily create one t-digest per non-reduced tensor position."""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # materialize the index tuples: a one-shot `product(...)` iterator
        # would be exhausted by the first `update` call, silently skipping
        # digest updates for every subsequent sample
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """Stream `part` into the per-position digests."""
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Query the digests and return dataset percentile tensors ({} if no data)."""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }
to calculate dataset percentiles with the experimental crick library
def __init__(
    self,
    member_id: MemberId,
    axes: Optional[Sequence[AxisId]],
    qs: Collection[float],
):
    """Record configuration; digests are allocated lazily on first update."""
    warnings.warn(
        "Computing dataset percentiles with experimental 'crick' library."
    )
    super().__init__()
    assert all(0.0 <= q <= 1.0 for q in qs)
    # "_percentiles" is reserved for the leading output axis
    assert axes is None or "_percentiles" not in axes
    self._member_id = member_id
    self._axes = tuple(axes) if axes is not None else None
    self._qs = sorted(set(qs))
    # lazily initialized state (set up by the first `update`)
    self._digest: Optional[List[TDigest]] = None
    self._dims: Optional[Tuple[AxisId, ...]] = None
    self._indices: Optional[Iterator[Tuple[int, ...]]] = None
    self._shape: Optional[Tuple[int, ...]] = None
def update(self, part: Sample):
    """Feed one sample into the per-position t-digests.

    Initializes the digests lazily from the first sample's tagged shape.
    """
    # NOTE(review): the else-branch presumably handles a sample-block type
    # wrapping its members' data — confirm against callers
    tensor = (
        part.members[self._member_id]
        if isinstance(part, Sample)
        else part.members[self._member_id].data
    )
    assert "_percentiles" not in tensor.dims
    if self._digest is None:
        self._initialize(tensor.tagged_shape)

    assert self._digest is not None
    assert self._dims is not None
    assert self._shape is not None
    # BUG FIX: `self._indices` holds a one-shot `itertools.product` iterator;
    # iterating it here exhausted it, so every `update` after the first was a
    # silent no-op and dataset percentiles reflected only the first sample.
    # Recompute the index tuples for each update instead.
    for i, idx in enumerate(product(*map(range, self._shape[1:]))):
        self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])
def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
    """Extract the requested quantiles from the accumulated digests.

    Returns an empty dict when `update` was never called.
    """
    if self._digest is None:
        return {}

    assert self._dims is not None
    assert self._shape is not None

    # one row per quantile, reshaped to (len(qs), *retained axis sizes)
    raw: NDArray[Any] = np.asarray(
        [[dig.quantile(q) for dig in self._digest] for q in self._qs]
    ).reshape(self._shape)
    return {
        DatasetPercentile(
            q=q, axes=self._axes, member_id=self._member_id
        ): Tensor(arr, dims=self._dims[1:])
        for q, arr in zip(self._qs, raw)
    }
class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        """Wrap `measure` so it can be used like any sample measure calculator."""
        super().__init__()
        # note: attribute is named `tensor_name` for historic reasons,
        # kept as-is for backward compatibility
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Compute the wrapped measure on `sample`."""
        measure = self.measure
        return {measure: measure.compute(sample)}
wrapper for measures to match interface of other sample measure calculators
class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently.

    Dataset statistics are aggregated incrementally via `update` and cached by
    `finalize`; any `update` invalidates the cache.
    """

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to compute.
            initial_dataset_measures: previously computed dataset statistics;
                only used when they cover *every* dataset measure in
                `measures`, otherwise ignored with a debug log message.
        """
        super().__init__()
        # counts calls to `_update`; NOTE(review): an iterable of several
        # samples passed to one `update` still increments this by 1 — confirm
        # whether callers expect a per-sample count instead
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            # nothing precomputed; `finalize` will aggregate from calculators
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the provided values only if no required dataset measure
            # is missing from them
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self) -> bool:
        # True when dataset statistics are currently cached (provided
        # initially or computed by `finalize` since the last update)
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """Update dataset statistics with `sample` (or each sample it yields)."""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # aggregate and cache; invalidated again by the next `_update`
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics.

        Raises:
            ValueError: if `sample` is neither a `Sample` nor yields any.
        """
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        # sample measures are computed for the last sample only
        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # compute all sample measures for `sample`
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        # feed `sample` (or all samples it yields) to the dataset calculators;
        # returns the last sample seen (None if the iterable was empty)
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # cached dataset statistics are stale now
        self._current_dataset_measures = None
        return last_sample
Estimates dataset statistics and computes sample statistics efficiently
def __init__(
    self,
    measures: Collection[Measure],
    initial_dataset_measures: Optional[
        Mapping[DatasetMeasure, MeasureValue]
    ] = None,
):
    """Set up sample/dataset calculators and optionally seed dataset stats."""
    super().__init__()
    self.sample_count = 0
    self.sample_calculators, self.dataset_calculators = get_measure_calculators(
        measures
    )
    # default: no cached dataset statistics yet
    self._current_dataset_measures: Optional[
        Dict[DatasetMeasure, MeasureValue]
    ] = None
    if initial_dataset_measures:
        # accept provided values only if they cover every required
        # dataset measure
        missing_dataset_meas = {
            m
            for m in measures
            if isinstance(m, DatasetMeasureBase)
            and m not in initial_dataset_measures
        }
        if missing_dataset_meas:
            logger.debug(
                f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
            )
        else:
            self._current_dataset_measures = dict(initial_dataset_measures)
def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
    """returns aggregated dataset statistics

    Aggregates from the dataset calculators on first call (or after the
    cache was invalidated by an update) and caches the result.
    """
    if self._current_dataset_measures is None:
        # BUG FIX: previously the empty dict was assigned to the cache
        # attribute *before* being filled; if a calculator's finalize()
        # raised mid-loop, later calls would silently return the partial
        # result. Build locally and assign only on success.
        aggregated: Dict[DatasetMeasure, MeasureValue] = {}
        for calc in self.dataset_calculators:
            aggregated.update(calc.finalize())
        self._current_dataset_measures = aggregated

    return self._current_dataset_measures
returns aggregated dataset statistics
def update_and_get_all(
    self,
    sample: Union[Sample, Iterable[Sample]],
) -> Dict[Measure, MeasureValue]:
    """Returns sample as well as updated dataset statistics.

    Sample measures are computed for the last sample seen.

    Raises:
        ValueError: if `sample` is neither a `Sample` nor yields any.
    """
    last_sample = self._update(sample)
    if last_sample is None:
        raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

    combined: Dict[Measure, MeasureValue] = dict(self._compute(last_sample))
    combined.update(self.finalize())
    return combined
Returns sample as well as updated dataset statistics
def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
    """Returns sample as well as previously computed dataset statistics.

    Unlike `update_and_get_all`, this does not feed `sample` into the
    dataset calculators.
    """
    combined: Dict[Measure, MeasureValue] = dict(self._compute(sample))
    combined.update(self.finalize())
    return combined
Returns sample as well as previously computed dataset statistics
def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # bucket the requested measures so one calculator can serve several
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for measure in required_measures:
        if isinstance(measure, SampleMean):
            required_sample_means.add(measure)
        elif isinstance(measure, DatasetMean):
            required_dataset_means.add(measure)
        elif isinstance(measure, (SampleVar, SampleStd)):
            # var/std are computed together with the matching mean
            required_sample_mean_var_std.update(
                {
                    msv(axes=measure.axes, member_id=measure.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert measure in required_sample_mean_var_std
        elif isinstance(measure, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=measure.axes, member_id=measure.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert measure in required_dataset_mean_var_std
        elif isinstance(measure, SampleQuantile):
            key = (measure.member_id, measure.axes)
            required_sample_percentiles.setdefault(key, set()).add(measure.q)
        elif isinstance(measure, DatasetPercentile):
            key = (measure.member_id, measure.axes)
            required_dataset_percentiles.setdefault(key, set()).add(measure.q)
        else:
            assert_never(measure)

    for sample_mean in required_sample_means:
        if sample_mean in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(
            MeanCalculator(member_id=sample_mean.member_id, axes=sample_mean.axes)
        )

    for smvs in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=smvs.member_id, axes=smvs.axes)
        )

    for dataset_mean in required_dataset_means:
        if dataset_mean in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(
            MeanCalculator(member_id=dataset_mean.member_id, axes=dataset_mean.axes)
        )

    for dmvs in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=dmvs.member_id, axes=dmvs.axes)
        )

    for (member, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=member, axes=axes, qs=qs)
        )

    for (member, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=member, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators
determines which calculators are needed to compute the required measures efficiently
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # dataset measures only: no sample calculators may be required
    assert not sample_calculators

    for sample in dataset:
        for dataset_calc in dataset_calculators:
            dataset_calc.update(sample)

    result: Dict[DatasetMeasure, MeasureValue] = {}
    for dataset_calc in dataset_calculators:
        result.update(dataset_calc.finalize())

    return result
compute all dataset measures
for the given dataset
def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures only: no dataset calculators may be required
    assert not dataset_calculators

    result: Dict[SampleMeasure, MeasureValue] = {}
    for sample_calc in sample_calculators:
        result.update(sample_calc.compute(sample))

    return result
compute all sample measures
for the given sample
def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for dataset_calc in dataset_calculators:
            dataset_calc.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    result: Dict[Measure, MeasureValue] = {}
    for dataset_calc in dataset_calculators:
        result.update(dataset_calc.finalize())
    for sample_calc in sample_calculators:
        result.update(sample_calc.compute(last_sample))

    return result
compute all measures
for the given dataset
sample measures are computed for the last sample in dataset