Coverage for bioimageio/core/stat_calculators.py: 65%

321 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import warnings 

5from itertools import product 

6from typing import ( 

7 Any, 

8 Collection, 

9 Dict, 

10 Iterable, 

11 Iterator, 

12 List, 

13 Mapping, 

14 Optional, 

15 OrderedDict, 

16 Sequence, 

17 Set, 

18 Tuple, 

19 Type, 

20 Union, 

21) 

22 

23import numpy as np 

24import xarray as xr 

25from loguru import logger 

26from numpy.typing import NDArray 

27from typing_extensions import assert_never 

28 

29from .axis import AxisId, PerAxis 

30from .common import MemberId 

31from .sample import Sample 

32from .stat_measures import ( 

33 DatasetMean, 

34 DatasetMeasure, 

35 DatasetMeasureBase, 

36 DatasetPercentile, 

37 DatasetStd, 

38 DatasetVar, 

39 Measure, 

40 MeasureValue, 

41 SampleMean, 

42 SampleMeasure, 

43 SampleQuantile, 

44 SampleStd, 

45 SampleVar, 

46) 

47from .tensor import Tensor 

48 

# Optional dependency: `crick` provides a streaming t-digest used to estimate
# dataset percentiles. Any import failure falls back to `crick = None` plus a
# no-op `TDigest` stub so this module can still be imported without crick.
try:
    import crick

except Exception:
    crick = None

    # Stub mirroring the interface of crick.TDigest; it is never used for real
    # computation because `DatasetPercentilesCalculator` (selected further down
    # in this file) falls back to `MeanPercentilesCalculator` when crick is
    # unavailable.
    class TDigest:
        def update(self, obj: Any):
            # intentional no-op placeholder
            pass

        def quantile(self, q: Any) -> Any:
            # intentional no-op placeholder (implicitly returns None)
            pass

else:
    TDigest = crick.TDigest  # type: ignore

64 

65 

class MeanCalculator:
    """to calculate sample and dataset mean for in-memory samples"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """Track the running mean of one sample member.

        Args:
            member_id: member whose mean is computed/aggregated.
            axes: axes to reduce over; ``None`` reduces over all axes.
        """
        super().__init__()
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._sample_mean = SampleMean(member_id=member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean without touching the dataset statistics"""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # cast to float64 so the running aggregation stays numerically stable
        member = sample.members[self._member_id].astype("float64", copy=False)
        return member.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """absorb `sample` into the running dataset mean"""
        self._update_impl(sample.members[self._member_id], self._compute_impl(sample))

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """compute the sample mean and absorb it into the dataset mean"""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels reduced into each mean entry
        n_new = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = n_new
            self._mean = tensor_mean
        else:
            assert self._n != 0
            n_old = self._n
            self._n = n_old + n_new
            # count-weighted combination of previous and new mean
            self._mean = (n_old * self._mean + n_new * tensor_mean) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """return the aggregated dataset mean (empty if nothing was seen yet)"""
        if self._mean is None:
            return {}

        return {self._dataset_mean: self._mean}

116 

117 

class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        """
        Args:
            member_id: member to compute statistics for.
            axes: axes to reduce over; `None` reduces over all axes.
        """
        super().__init__()
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # running reduced voxel count over all samples seen so far
        self._n: int = 0
        # running mean over all samples seen so far
        self._mean: Optional[Tensor] = None
        # running sum of squared deviations from the mean
        # (the "M2" accumulator of the Welford/Chan parallel algorithm)
        self._m2: Optional[Tensor] = None

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, variance and std for `sample` in one pass
        (population statistics: division by n, not n - 1)"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        # centered data as the underlying xarray
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # var = sum(c * c) over the reduced axes, divided by n
        var = xr.dot(c, c, dims=self._axes) / n
        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """absorb `sample` into the running dataset statistics via the
        pairwise combination of (n, mean, M2) (Chan et al.)"""
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            # parallel variance update: M2 = M2_a + M2_b + d**2 * n_a * n_b / n
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return aggregated dataset mean/var/std (empty if nothing was seen)"""
        if self._mean is None:
            return {}
        else:
            assert self._m2 is not None
            var = self._m2 / self._n
            sqrt = np.sqrt(var)
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }

197 

198 

class SamplePercentilesCalculator:
    """to calculate sample percentiles"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: member to compute quantiles for.
            axes: axes to reduce over; ``None`` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped, order is sorted.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """compute all requested quantiles for `sample` in a single call"""
        member = sample.members[self._member_id]
        values = member.quantile(self._qs, dim=self._axes)
        keys = (
            SampleQuantile(q=q, axes=self._axes, member_id=self._member_id)
            for q in self._qs
        )
        return dict(zip(keys, values))

221 

222 

class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: member to estimate quantiles for.
            axes: axes to reduce over; ``None`` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped, order is sorted.
        """
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._member_id = member_id
        self._axes = tuple(axes) if axes is not None else None
        self._qs = sorted(set(qs))
        # running reduced voxel count and count-weighted percentile estimates
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """fold `sample`'s own percentiles into the running weighted average"""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # voxels reduced into each estimate entry (first axis enumerates qs)
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            # weight previous and new estimates by their reduced voxel counts
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """return the (heuristic) dataset percentile estimates"""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): est
            for q, est in zip(self._qs, self._estimates)
        }

273 

274 

class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        """
        Args:
            member_id: member to compute dataset percentiles for.
            axes: axes to reduce over; `None` reduces over all axes.
            qs: quantiles in [0, 1]; duplicates are dropped, order is sorted.
        """
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # "_percentiles" is reserved as the leading output axis name
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # one t-digest per non-reduced output element; created lazily on first update
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # materialized list of index tuples into the non-reduced axes
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """lazily set up digests/dims/shape/indices from the first tensor's sizes"""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        d = int(np.prod(self._shape[1:]))  # type: ignore
        self._digest = [TDigest() for _ in range(d)]
        # bug fix: materialize as a list. The previous code stored the
        # `itertools.product(...)` iterator itself, which was exhausted by the
        # first `update` call; every later `update` then iterated an empty
        # iterator and silently ignored all subsequent samples.
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """feed `part`'s member tensor into the per-element t-digests"""
        # NOTE(review): the else branch is unreachable for the annotated `Sample`
        # type; it is kept for backward compatibility with non-Sample parts.
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """return the dataset percentiles estimated by the t-digests"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }

345 

346 

# Prefer the t-digest based calculator when `crick` is installed; otherwise
# fall back to the heuristic per-sample averaging (see its docstring: the
# averaged result is an estimate and not mathematically correct).
if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator

353 

354 

class NaiveSampleMeasureCalculator:
    """wrapper for measures to match interface of other sample measure calculators"""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        """
        Args:
            member_id: member the wrapped measure refers to.
            measure: sample measure that is evaluated directly on each sample.
        """
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """evaluate the wrapped measure on `sample`"""
        value = self.measure.compute(sample)
        return {self.measure: value}

365 

366 

# calculators able to provide sample measures via `compute(sample)`
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# calculators able to aggregate dataset measures via `update(sample)` + `finalize()`
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]

376 

377 

class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to provide.
            initial_dataset_measures: previously aggregated dataset measures;
                only used if they cover every required dataset measure.
        """
        super().__init__()
        # NOTE(review): this counts calls to `update`, not individual samples,
        # when an iterable of samples is passed -- confirm intended semantics.
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        self._current_dataset_measures: Optional[
            Dict[DatasetMeasure, MeasureValue]
        ] = None
        if initial_dataset_measures:
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if not missing_dataset_meas:
                self._current_dataset_measures = dict(initial_dataset_measures)
            else:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )

    @property
    def has_dataset_measures(self):
        """whether aggregated dataset measures are currently cached"""
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update dataset statistics with `sample` (or an iterable of samples)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # (re)aggregate and cache until the next `update` invalidates it
            aggregated: Dict[DatasetMeasure, MeasureValue] = {}
            for calc in self.dataset_calculators:
                aggregated.update(calc.finalize().items())

            self._current_dataset_measures = aggregated

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # evaluate every sample calculator on `sample`
        computed: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            computed.update(calc.compute(sample).items())

        return computed

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        self.sample_count += 1
        if isinstance(sample, Sample):
            samples: Iterable[Sample] = [sample]
        else:
            samples = sample

        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # any update invalidates the cached aggregated dataset measures
        self._current_dataset_measures = None
        return last_sample

466 

467 

def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # group required measures by the calculator kind that can provide them
    sample_means: Set[SampleMean] = set()
    dataset_means: Set[DatasetMean] = set()
    sample_mvs: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    dataset_mvs: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = set()
    sample_qs: Dict[Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]] = {}
    dataset_qs: Dict[Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]] = {}

    for measure in required_measures:
        if isinstance(measure, SampleMean):
            sample_means.add(measure)
        elif isinstance(measure, DatasetMean):
            dataset_means.add(measure)
        elif isinstance(measure, (SampleVar, SampleStd)):
            # mean, var and std are provided together by one calculator
            sample_mvs.update(
                kind(axes=measure.axes, member_id=measure.member_id)
                for kind in (SampleMean, SampleStd, SampleVar)
            )
            assert measure in sample_mvs
        elif isinstance(measure, (DatasetVar, DatasetStd)):
            dataset_mvs.update(
                kind(axes=measure.axes, member_id=measure.member_id)
                for kind in (DatasetMean, DatasetStd, DatasetVar)
            )
            assert measure in dataset_mvs
        elif isinstance(measure, SampleQuantile):
            # all quantiles of one (member, axes) pair share a calculator
            sample_qs.setdefault((measure.member_id, measure.axes), set()).add(
                measure.q
            )
        elif isinstance(measure, DatasetPercentile):
            dataset_qs.setdefault((measure.member_id, measure.axes), set()).add(
                measure.q
            )
        else:
            assert_never(measure)

    for sm in sample_means:
        if sm in sample_mvs:
            # already provided by a MeanVarStdCalculator below
            continue

        sample_calculators.append(MeanCalculator(member_id=sm.member_id, axes=sm.axes))

    for smvs in sample_mvs:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=smvs.member_id, axes=smvs.axes)
        )

    for dm in dataset_means:
        if dm in dataset_mvs:
            # already provided by a MeanVarStdCalculator below
            continue

        dataset_calculators.append(MeanCalculator(member_id=dm.member_id, axes=dm.axes))

    for dmvs in dataset_mvs:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=dmvs.member_id, axes=dmvs.axes)
        )

    for (member_id, axes), qs in sample_qs.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=member_id, axes=axes, qs=qs)
        )

    for (member_id, axes), qs in dataset_qs.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=member_id, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators

557 

558 

def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # dataset measures must not require any sample calculators
    assert not sample_calculators

    for sample in dataset:
        for calc in dataset_calculators:
            calc.update(sample)

    result: Dict[DatasetMeasure, MeasureValue] = {}
    for calc in dataset_calculators:
        result.update(calc.finalize().items())

    return result

576 

577 

def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures must not require any dataset calculators
    assert not dataset_calculators

    result: Dict[SampleMeasure, MeasureValue] = {}
    for calc in sample_calculators:
        result.update(calc.compute(sample).items())

    return result

590 

591 

def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    ret: Dict[Measure, MeasureValue] = {}

    last_sample = None
    for last_sample in dataset:
        for calc in dataset_calculators:
            calc.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    for calc in dataset_calculators:
        ret.update(calc.finalize().items())

    # sample measures refer to the last sample only
    for calc in sample_calculators:
        ret.update(calc.compute(last_sample).items())

    return ret