Coverage for bioimageio/core/stat_calculators.py: 68%

327 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from __future__ import annotations 

2 

3import collections 

4import warnings 

5from itertools import product 

6from typing import ( 

7 Any, 

8 Collection, 

9 Dict, 

10 Iterable, 

11 Iterator, 

12 List, 

13 Mapping, 

14 Optional, 

15 OrderedDict, 

16 Sequence, 

17 Set, 

18 Tuple, 

19 Type, 

20 Union, 

21) 

22 

23import numpy as np 

24import xarray as xr 

25from loguru import logger 

26from numpy.typing import NDArray 

27from typing_extensions import assert_never 

28 

29from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID 

30 

31from .axis import AxisId, PerAxis 

32from .common import MemberId 

33from .sample import Sample 

34from .stat_measures import ( 

35 DatasetMean, 

36 DatasetMeasure, 

37 DatasetMeasureBase, 

38 DatasetPercentile, 

39 DatasetStd, 

40 DatasetVar, 

41 Measure, 

42 MeasureValue, 

43 SampleMean, 

44 SampleMeasure, 

45 SampleQuantile, 

46 SampleStd, 

47 SampleVar, 

48) 

49from .tensor import Tensor 

50 

# Optional dependency: `crick` provides a streaming T-Digest implementation
# used for dataset percentile estimation. If it cannot be imported, fall back
# to a do-nothing stub so the rest of this module still imports cleanly.
try:
    import crick  # pyright: ignore[reportMissingImports]
except Exception:
    crick = None
else:
    TDigest = crick.TDigest  # type: ignore

if crick is None:

    class TDigest:
        """No-op placeholder used when `crick` is not installed."""

        def update(self, obj: Any):
            pass

        def quantile(self, q: Any) -> Any:
            pass

66 

67 

class MeanCalculator:
    """Compute sample means and aggregate a running dataset mean in memory."""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._member_id = member_id
        self._axes = tuple(axes) if axes is not None else None
        # running reduced-voxel count and running mean over all seen samples
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._sample_mean = SampleMean(member_id=member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Return the sample mean without touching the dataset aggregate."""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # work in float64 to keep the running aggregate numerically stable
        data = sample.members[self._member_id].astype("float64", copy=False)
        return data.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """Fold `sample` into the running dataset mean."""
        self._update_impl(sample.members[self._member_id], self._compute_impl(sample))

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute the sample mean and also fold it into the dataset aggregate."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels that were reduced into each mean entry
        count_new = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            assert self._n == 0
            self._n = count_new
            self._mean = tensor_mean
        else:
            assert self._n != 0
            count_old = self._n
            previous_mean = self._mean
            self._n = count_old + count_new
            # weighted combination of old aggregate and new sample mean
            self._mean = (
                count_old * previous_mean + count_new * tensor_mean
            ) / self._n
            assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """Return the aggregated dataset mean (empty if no update happened)."""
        if self._mean is None:
            return {}

        return {self._dataset_mean: self._mean}

118 

119 

class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation"""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        # running reduced-voxel count, running mean, and running sum of
        # squared deviations from the mean (the "M2" accumulator)
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, (population) variance and standard deviation of a single sample"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        # deviations from the mean as the underlying xr.DataArray
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            # number of elements reduced over the selected axes
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # xarray renamed `xr.dot`'s `dims` argument to `dim` after the 2023.x
        # releases; branch to support both
        if xr.__version__.startswith("2023"):
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dims=self._axes) / n
            )
        else:
            var = (  # pyright: ignore[reportUnknownVariableType]
                xr.dot(c, c, dim=self._axes) / n
            )

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """update the running dataset statistics with `sample`"""
        # dataset statistics are only aggregated when reducing over the batch
        # axis (or over all axes, i.e. axes is None); otherwise per-sample
        # values cannot be merged meaningfully (see matching guard in finalize)
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        # float64 keeps the incremental accumulation numerically stable
        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            # pairwise combination of counts, means and M2 accumulators
            # (Chan et al. parallel variance update)
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return aggregated dataset mean, variance and standard deviation
        (empty dict if nothing was aggregated, see `update`)"""
        if (
            self._axes is not None
            and BATCH_AXIS_ID not in self._axes
            or self._mean is None
        ):
            return {}
        else:
            assert self._m2 is not None
            # population variance from the accumulated M2
            var = self._m2 / self._n
            sqrt = var**0.5
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }

214 

215 

class SamplePercentilesCalculator:
    """Compute per-sample quantiles for one tensor member."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        # quantiles must lie in the closed unit interval
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one `SampleQuantile` measure value per requested quantile."""
        values = sample.members[self._member_id].quantile(self._qs, dim=self._axes)
        result: Dict[SampleQuantile, MeasureValue] = {}
        for q, value in zip(self._qs, values):
            key = SampleQuantile(q=q, axes=self._axes, member_id=self._member_id)
            result[key] = value

        return result

238 

239 

class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples
    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        # quantiles must lie in the closed unit interval
        assert all(0.0 <= q <= 1.0 for q in qs)
        self._member_id = member_id
        self._axes = tuple(axes) if axes is not None else None
        self._qs = sorted(set(qs))
        # running reduced-voxel count and running weighted-average estimates
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """Fold the percentile estimates of `sample` into the running average."""
        tensor = sample.members[self._member_id]
        estimates_b = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # reduced voxel count; the first axis of the quantile result
        # enumerates the requested quantiles and is excluded
        n_b = int(tensor.size / np.prod(estimates_b.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = estimates_b
        else:
            weighted_sum = self._n * self._estimates + n_b * estimates_b
            self._estimates = weighted_sum / (self._n + n_b)
            assert self._estimates.dtype == "float64"

        self._n += n_b

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """Return the averaged percentile estimates (empty if no samples seen)."""
        if self._estimates is None:
            return {}

        warnings.warn(
            "Computed dataset percentiles naively by averaging percentiles of samples."
        )
        return {
            DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
            for q, e in zip(self._qs, self._estimates)
        }

290 

291 

class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # "_percentiles" is reserved for the leading quantile axis (see _initialize)
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # one T-Digest per element of the (non-reduced) output shape;
        # lazily created from the first sample's tensor in `_initialize`
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # all index tuples into the non-reduced output axes; a *list* so it
        # can be traversed once per `update` call (see bug-fix note below)
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """set up digests, dims and output shape from the first tensor's axis sizes"""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            # keep every axis that is not reduced over
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # bug fix: materialize the index tuples as a list. The previous
        # `product(...)` iterator was exhausted by the first `update` call,
        # so all subsequent samples were silently ignored.
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """feed every element slice of the sample's tensor into its digest"""
        # NOTE(review): the else-branch looks unreachable given the `Sample`
        # annotation of `part` -- kept for backward compatibility; confirm.
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """query each digest for all quantiles and package per-quantile tensors
        (empty dict if no sample was seen)"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }

363 

364 

# Select the dataset-percentile implementation: prefer the T-Digest based
# calculator when `crick` is importable, otherwise fall back to the naive
# averaging estimate.
DatasetPercentilesCalculator: Type[
    Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
] = CrickPercentilesCalculator if crick is not None else MeanPercentilesCalculator

371 

372 

class NaiveSampleMeasureCalculator:
    """Adapter giving arbitrary sample measures the calculator interface."""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Evaluate the wrapped measure directly on `sample`."""
        value = self.measure.compute(sample)
        return {self.measure: value}

383 

384 

# calculators that produce per-sample statistics via `.compute(sample)`
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# calculators that aggregate dataset statistics via `.update(sample)` and `.finalize()`
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]

394 

395 

class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to be computed.
            initial_dataset_measures: optional pre-computed dataset measure
                values; only used if they cover *every* required dataset
                measure, otherwise they are ignored (with a debug log entry).
        """
        super().__init__()
        # number of calls to `update`/`update_and_get_all` so far
        # NOTE(review): incremented once per call even when an iterable of
        # several samples is passed -- confirm that is intended
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            # cache of aggregated dataset measures; None means "not computed yet"
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the initial values only if no required dataset measure is missing
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        # True if aggregated dataset measures are currently cached
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update the dataset statistics with `sample` (or each sample it yields)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        if self._current_dataset_measures is None:
            # (re)build the cache from all dataset calculators
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        # sample statistics are computed for the last sample only
        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # evaluate all sample calculators on `sample`
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        # feed `sample` (or each sample it yields) to all dataset calculators
        # and return the last sample seen (None for an empty iterable)
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate the cached dataset measures; recomputed on next `finalize`
        self._current_dataset_measures = None
        return last_sample

484 

485 

def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently"""

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    # percentile requests are grouped by (member, axes) so a single calculator
    # can serve all requested quantiles of that tensor/axes combination
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # requesting var or std implies computing mean, var and std together
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators

575 

576 

def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # only dataset measures were requested, so no sample calculators may exist
    assert not sample_calculators

    # single pass over the dataset, feeding every calculator
    for sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(sample)

    ret: Dict[DatasetMeasure, MeasureValue] = {}
    for calculator in dataset_calculators:
        ret.update(calculator.finalize())

    return ret

594 

595 

def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)
    # only sample measures were requested, so no dataset calculators may exist
    assert not dataset_calculators

    results: Dict[SampleMeasure, MeasureValue] = {}
    for calculator in sample_calculators:
        results.update(calculator.compute(sample))

    return results

608 

609 

def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`
    sample measures are computed for the last sample in `dataset`"""
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample: Optional[Sample] = None
    for last_sample in dataset:
        for calculator in dataset_calculators:
            calculator.update(last_sample)

    if last_sample is None:
        raise ValueError("empty dataset")

    combined: Dict[Measure, MeasureValue] = {}
    for calculator in dataset_calculators:
        combined.update(calculator.finalize())

    # sample measures refer to the last sample seen
    for calculator in sample_calculators:
        combined.update(calculator.compute(last_sample))

    return combined