Coverage for src/bioimageio/core/stat_calculators.py: 75%

327 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

1from __future__ import annotations 

2 

3import collections 

4import warnings 

5from itertools import product 

6from typing import ( 

7 Any, 

8 Collection, 

9 Dict, 

10 Iterable, 

11 Iterator, 

12 List, 

13 Mapping, 

14 Optional, 

15 OrderedDict, 

16 Sequence, 

17 Set, 

18 Tuple, 

19 Type, 

20 Union, 

21) 

22 

23import numpy as np 

24import xarray as xr 

25from bioimageio.spec.model.v0_5 import BATCH_AXIS_ID 

26from loguru import logger 

27from numpy.typing import NDArray 

28from typing_extensions import assert_never 

29 

30from .axis import AxisId, PerAxis 

31from .common import MemberId 

32from .sample import Sample 

33from .stat_measures import ( 

34 DatasetMean, 

35 DatasetMeasure, 

36 DatasetMeasureBase, 

37 DatasetPercentile, 

38 DatasetStd, 

39 DatasetVar, 

40 Measure, 

41 MeasureValue, 

42 SampleMean, 

43 SampleMeasure, 

44 SampleQuantile, 

45 SampleStd, 

46 SampleVar, 

47) 

48from .tensor import Tensor 

49 

try:
    # `crick` is an optional dependency providing a streaming t-digest.
    import crick  # pyright: ignore[reportMissingTypeStubs]

except Exception:
    crick = None

    # no-op stub so that type references to `TDigest` below stay valid;
    # `CrickPercentilesCalculator` is never selected when crick is missing
    # (see the `DatasetPercentilesCalculator` assignment at module bottom).
    class TDigest:
        def update(self, obj: Any):
            pass

        def quantile(self, q: Any) -> Any:
            pass

else:
    TDigest = crick.TDigest  # type: ignore

65 

66 

class MeanCalculator:
    """Running mean of one tensor member: per-sample means plus an
    incrementally aggregated dataset mean (for in-memory samples)."""

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id
        self._sample_mean = SampleMean(member_id=member_id, axes=self._axes)
        self._dataset_mean = DatasetMean(member_id=member_id, axes=self._axes)

    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute this sample's mean without touching the running dataset mean."""
        return {self._sample_mean: self._compute_impl(sample)}

    def _compute_impl(self, sample: Sample) -> Tensor:
        # accumulate in float64 to limit precision loss across many samples
        member = sample.members[self._member_id].astype("float64", copy=False)
        return member.mean(dim=self._axes)

    def update(self, sample: Sample) -> None:
        """Fold `sample` into the running dataset mean."""
        self._update_impl(sample.members[self._member_id], self._compute_impl(sample))

    def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
        """Compute the sample mean and also fold it into the dataset mean."""
        sample_mean = self._compute_impl(sample)
        self._update_impl(sample.members[self._member_id], sample_mean)
        return {self._sample_mean: sample_mean}

    def _update_impl(self, tensor: Tensor, tensor_mean: Tensor):
        assert tensor_mean.dtype == "float64"
        # number of voxels reduced into each entry of the mean
        n_new = int(tensor.size / tensor_mean.size)

        if self._mean is None:
            # first sample: adopt its statistics directly
            assert self._n == 0
            self._n = n_new
            self._mean = tensor_mean
            return

        assert self._n != 0
        n_old = self._n
        previous = self._mean
        self._n = n_old + n_new
        # count-weighted merge of old and new means
        self._mean = (n_old * previous + n_new * tensor_mean) / self._n
        assert self._mean.dtype == "float64"

    def finalize(self) -> Dict[DatasetMean, MeasureValue]:
        """Return the aggregated dataset mean; empty if no sample was seen."""
        return {} if self._mean is None else {self._dataset_mean: self._mean}

117 

118 

class MeanVarStdCalculator:
    """to calculate sample and dataset mean, variance or standard deviation

    Sample statistics are computed directly per sample; dataset statistics are
    aggregated incrementally in `update` via a pairwise (Chan et al. style)
    merge of counts, means and sums of squared deviations (M2).
    """

    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
        super().__init__()
        # axes to reduce over; None means "reduce over all axes"
        self._axes = None if axes is None else tuple(map(AxisId, axes))
        self._member_id = member_id
        # running reduced-voxel count, mean, and sum of squared deviations
        self._n: int = 0
        self._mean: Optional[Tensor] = None
        self._m2: Optional[Tensor] = None

    def compute(
        self, sample: Sample
    ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]:
        """compute mean, variance and standard deviation of a single sample"""
        tensor = sample.members[self._member_id]
        mean = tensor.mean(dim=self._axes)
        # centered values; `.data` presumably yields the underlying
        # xr.DataArray (matches the `xr.dot` usage below) — TODO confirm
        c = (tensor - mean).data
        if self._axes is None:
            n = tensor.size
        else:
            # number of elements reduced per output entry
            n = int(np.prod([tensor.sizes[d] for d in self._axes]))

        # xarray renamed `xr.dot`'s `dims` argument to `dim` after the 2023
        # releases; support both
        if xr.__version__.startswith("2023"):
            var = xr.dot(c, c, dims=self._axes) / n
        else:
            var = xr.dot(c, c, dim=self._axes) / n

        assert isinstance(var, xr.DataArray)
        std = np.sqrt(var)
        assert isinstance(std, xr.DataArray)
        return {
            SampleMean(axes=self._axes, member_id=self._member_id): mean,
            SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                var
            ),
            SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray(
                std
            ),
        }

    def update(self, sample: Sample):
        """fold `sample` into the running dataset statistics

        NOTE(review): updates are skipped when explicit axes exclude the batch
        axis — presumably because such statistics do not aggregate across
        samples (see the matching guard in `finalize`); confirm intent.
        """
        if self._axes is not None and BATCH_AXIS_ID not in self._axes:
            return

        tensor = sample.members[self._member_id].astype("float64", copy=False)
        mean_b = tensor.mean(dim=self._axes)
        assert mean_b.dtype == "float64"
        # reduced voxel count
        n_b = int(tensor.size / mean_b.size)
        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
        assert m2_b.dtype == "float64"
        if self._mean is None:
            # first sample: adopt its statistics directly
            assert self._m2 is None
            self._n = n_b
            self._mean = mean_b
            self._m2 = m2_b
        else:
            # merge (n_a, mean_a, m2_a) with the new chunk (n_b, mean_b, m2_b)
            n_a = self._n
            mean_a = self._mean
            m2_a = self._m2
            self._n = n = n_a + n_b
            self._mean = (n_a * mean_a + n_b * mean_b) / n
            assert self._mean.dtype == "float64"
            d = mean_b - mean_a
            # cross term accounts for the shift between the two chunk means
            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
            assert self._m2.dtype == "float64"

    def finalize(
        self,
    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
        """return aggregated dataset mean/var/std; empty if nothing aggregated"""
        if (
            self._axes is not None
            and BATCH_AXIS_ID not in self._axes
            or self._mean is None
        ):
            return {}
        else:
            assert self._m2 is not None
            # population variance: M2 / n
            var = self._m2 / self._n
            sqrt = var**0.5
            if isinstance(sqrt, (int, float)):
                # var and mean are scalar tensors, let's keep it consistent
                sqrt = Tensor.from_xarray(xr.DataArray(sqrt))

            assert isinstance(sqrt, Tensor), type(sqrt)
            return {
                DatasetMean(member_id=self._member_id, axes=self._axes): self._mean,
                DatasetVar(member_id=self._member_id, axes=self._axes): var,
                DatasetStd(member_id=self._member_id, axes=self._axes): sqrt,
            }

209 

210 

class SamplePercentilesCalculator:
    """Computes per-sample quantiles of one tensor member."""

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # deduplicate and keep a deterministic (sorted) order
        self._qs = sorted(set(qs))
        self._axes = tuple(axes) if axes is not None else None
        self._member_id = member_id

    def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]:
        """Return one `SampleQuantile` measure value per requested q."""
        member = sample.members[self._member_id]
        quantile_values = member.quantile(self._qs, dim=self._axes)
        result: Dict[SampleQuantile, MeasureValue] = {}
        for q, value in zip(self._qs, quantile_values):
            key = SampleQuantile(q=q, axes=self._axes, member_id=self._member_id)
            result[key] = value
        return result

233 

234 

class MeanPercentilesCalculator:
    """to calculate dataset percentiles heuristically by averaging across samples

    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # deduplicated, sorted quantile values in [0, 1]
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # running reduced voxel count and weighted-average percentile estimates
        self._n: int = 0
        self._estimates: Optional[Tensor] = None

    def update(self, sample: Sample):
        """fold `sample`'s percentiles into the running weighted average"""
        tensor = sample.members[self._member_id]
        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
            "float64", copy=False
        )

        # reduced voxel count; the first axis of `sample_estimates` is the
        # quantile axis, so it is excluded from the per-entry element count
        n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:]))

        if self._estimates is None:
            assert self._n == 0
            self._estimates = sample_estimates
        else:
            # weight each sample's estimates by its reduced voxel count
            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
                self._n + n
            )
            assert self._estimates.dtype == "float64"

        self._n += n

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """return the (approximate) dataset percentiles; empty if no samples seen"""
        if self._estimates is None:
            return {}
        else:
            warnings.warn(
                "Computed dataset percentiles naively by averaging percentiles of samples."
            )
            return {
                DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e
                for q, e in zip(self._qs, self._estimates)
            }

285 

286 

class CrickPercentilesCalculator:
    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)

    Keeps one streaming `TDigest` per output element (one per position along
    the non-reduced axes); every sample streams its values into these digests.
    """

    def __init__(
        self,
        member_id: MemberId,
        axes: Optional[Sequence[AxisId]],
        qs: Collection[float],
    ):
        warnings.warn(
            "Computing dataset percentiles with experimental 'crick' library."
        )
        super().__init__()
        assert all(0.0 <= q <= 1.0 for q in qs)
        # "_percentiles" is used as a synthetic output axis name, see `_initialize`
        assert axes is None or "_percentiles" not in axes
        self._qs = sorted(set(qs))
        self._axes = None if axes is None else tuple(axes)
        self._member_id = member_id
        # all of the following are lazily initialized from the first sample
        self._digest: Optional[List[TDigest]] = None
        self._dims: Optional[Tuple[AxisId, ...]] = None
        # materialized index tuples, one per digest (see bug note in `_initialize`)
        self._indices: Optional[List[Tuple[int, ...]]] = None
        self._shape: Optional[Tuple[int, ...]] = None

    def _initialize(self, tensor_sizes: PerAxis[int]):
        """set up digests and output layout from the first sample's axis sizes"""
        assert crick is not None
        out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict(
            _percentiles=len(self._qs)
        )
        if self._axes is not None:
            # non-reduced axes are kept in the output
            for d, s in tensor_sizes.items():
                if d not in self._axes:
                    out_sizes[d] = s

        self._dims, self._shape = zip(*out_sizes.items())
        assert self._shape is not None
        d = int(np.prod(self._shape[1:]))
        self._digest = [TDigest() for _ in range(d)]
        # BUGFIX: this previously stored the `itertools.product` iterator
        # itself, which is single-use: the first `update` call exhausted it and
        # every subsequent sample was silently ignored by the digests.
        # Materialize the indices so each `update` can iterate them again.
        self._indices = list(product(*map(range, self._shape[1:])))

    def update(self, part: Sample):
        """stream `part`'s values into the per-element digests"""
        # NOTE(review): the isinstance branch looks vestigial — `part` is
        # annotated as `Sample`, so the else case appears unreachable; kept as-is.
        tensor = (
            part.members[self._member_id]
            if isinstance(part, Sample)
            else part.members[self._member_id].data
        )
        assert "_percentiles" not in tensor.dims
        if self._digest is None:
            self._initialize(tensor.tagged_shape)

        assert self._digest is not None
        assert self._indices is not None
        assert self._dims is not None
        for i, idx in enumerate(self._indices):
            # select one position along all non-reduced axes; the remaining
            # (reduced) values feed digest i
            self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))])

    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
        """query the digests for all requested quantiles; empty if no samples seen"""
        if self._digest is None:
            return {}
        else:
            assert self._dims is not None
            assert self._shape is not None

            # shape: (len(qs), *non-reduced axis sizes)
            vs: NDArray[Any] = np.asarray(
                [[d.quantile(q) for d in self._digest] for q in self._qs]
            ).reshape(self._shape)
            return {
                DatasetPercentile(
                    q=q, axes=self._axes, member_id=self._member_id
                ): Tensor(v, dims=self._dims[1:])
                for q, v in zip(self._qs, vs)
            }

358 

359 

# Prefer the t-digest based calculator when `crick` is installed; otherwise
# fall back to the naive per-sample averaging heuristic.
if crick is None:
    DatasetPercentilesCalculator: Type[
        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
    ] = MeanPercentilesCalculator
else:
    DatasetPercentilesCalculator = CrickPercentilesCalculator

366 

367 

class NaiveSampleMeasureCalculator:
    """Adapter that gives a single `SampleMeasure` the common calculator
    interface (a `compute(sample)` method returning a measure->value dict)."""

    def __init__(self, member_id: MemberId, measure: SampleMeasure):
        super().__init__()
        self.tensor_name = member_id
        self.measure = measure

    def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        """Delegate directly to the wrapped measure."""
        value = self.measure.compute(sample)
        return {self.measure: value}

378 

379 

# Union of all calculators that can produce per-sample measures.
SampleMeasureCalculator = Union[
    MeanCalculator,
    MeanVarStdCalculator,
    SamplePercentilesCalculator,
    NaiveSampleMeasureCalculator,
]
# Union of all calculators that can aggregate measures across a dataset.
DatasetMeasureCalculator = Union[
    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
]

389 

390 

class StatsCalculator:
    """Estimates dataset statistics and computes sample statistics efficiently"""

    def __init__(
        self,
        measures: Collection[Measure],
        initial_dataset_measures: Optional[
            Mapping[DatasetMeasure, MeasureValue]
        ] = None,
    ):
        """
        Args:
            measures: sample and dataset measures to compute/estimate
            initial_dataset_measures: previously computed dataset measures;
                only used when they cover *all* required dataset measures,
                otherwise they are ignored (with a debug log message).
        """
        super().__init__()
        # NOTE(review): incremented once per `_update` call, even when an
        # iterable of several samples is passed — confirm this is intended
        self.sample_count = 0
        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
            measures
        )
        if not initial_dataset_measures:
            # cache of aggregated dataset measures; None means "needs (re)compute"
            self._current_dataset_measures: Optional[
                Dict[DatasetMeasure, MeasureValue]
            ] = None
        else:
            # accept the initial values only if every required dataset
            # measure is covered
            missing_dataset_meas = {
                m
                for m in measures
                if isinstance(m, DatasetMeasureBase)
                and m not in initial_dataset_measures
            }
            if missing_dataset_meas:
                logger.debug(
                    f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}"
                )
                self._current_dataset_measures = None
            else:
                self._current_dataset_measures = dict(initial_dataset_measures)

    @property
    def has_dataset_measures(self):
        # True once dataset measures were computed (or provided initially)
        return self._current_dataset_measures is not None

    def update(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> None:
        """update the running dataset statistics with `sample`(s)"""
        _ = self._update(sample)

    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
        """returns aggregated dataset statistics"""
        # lazily (re)computed; `_update` invalidates the cache
        if self._current_dataset_measures is None:
            self._current_dataset_measures = {}
            for calc in self.dataset_calculators:
                values = calc.finalize()
                self._current_dataset_measures.update(values.items())

        return self._current_dataset_measures

    def update_and_get_all(
        self,
        sample: Union[Sample, Iterable[Sample]],
    ) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as updated dataset statistics"""
        last_sample = self._update(sample)
        if last_sample is None:
            raise ValueError("`sample` was not a `Sample`, nor did it yield any.")

        # sample measures are computed for the last sample seen
        return {**self._compute(last_sample), **self.finalize()}

    def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]:
        """Returns sample as well as previously computed dataset statistics"""
        return {**self._compute(sample), **self.finalize()}

    def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
        # compute all sample measures for `sample`
        ret: Dict[SampleMeasure, MeasureValue] = {}
        for calc in self.sample_calculators:
            values = calc.compute(sample)
            ret.update(values.items())

        return ret

    def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]:
        # feed sample(s) to all dataset calculators; returns the last sample seen
        self.sample_count += 1
        samples = [sample] if isinstance(sample, Sample) else sample
        last_sample = None
        for el in samples:
            last_sample = el
            for calc in self.dataset_calculators:
                calc.update(el)

        # invalidate cached dataset measures so `finalize` recomputes them
        self._current_dataset_measures = None
        return last_sample

479 

480 

def get_measure_calculators(
    required_measures: Iterable[Measure],
) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
    """determines which calculators are needed to compute the required measures efficiently

    Returns:
        a list of sample measure calculators and a list of dataset measure
        calculators; measures that can be computed by one calculator
        (e.g. mean/var/std over the same axes) share a single instance.
    """

    sample_calculators: List[SampleMeasureCalculator] = []
    dataset_calculators: List[DatasetMeasureCalculator] = []

    # split required measures into groups
    required_sample_means: Set[SampleMean] = set()
    required_dataset_means: Set[DatasetMean] = set()
    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = (
        set()
    )
    # percentile measures are grouped by (member, axes) so one calculator
    # computes all requested qs for that member/axes combination
    required_sample_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}
    required_dataset_percentiles: Dict[
        Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float]
    ] = {}

    for rm in required_measures:
        if isinstance(rm, SampleMean):
            required_sample_means.add(rm)
        elif isinstance(rm, DatasetMean):
            required_dataset_means.add(rm)
        elif isinstance(rm, (SampleVar, SampleStd)):
            # var/std require the mean anyway; request all three together
            required_sample_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (SampleMean, SampleStd, SampleVar)
                }
            )
            assert rm in required_sample_mean_var_std
        elif isinstance(rm, (DatasetVar, DatasetStd)):
            required_dataset_mean_var_std.update(
                {
                    msv(axes=rm.axes, member_id=rm.member_id)
                    for msv in (DatasetMean, DatasetStd, DatasetVar)
                }
            )
            assert rm in required_dataset_mean_var_std
        elif isinstance(rm, SampleQuantile):
            required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        elif isinstance(rm, DatasetPercentile):
            required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add(
                rm.q
            )
        else:
            assert_never(rm)

    for rm in required_sample_means:
        if rm in required_sample_mean_var_std:
            # computed together with var and std
            continue

        sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_sample_mean_var_std:
        sample_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for rm in required_dataset_means:
        if rm in required_dataset_mean_var_std:
            # computed together with var and std
            continue

        dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes))

    for rm in required_dataset_mean_var_std:
        dataset_calculators.append(
            MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes)
        )

    for (tid, axes), qs in required_sample_percentiles.items():
        sample_calculators.append(
            SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    for (tid, axes), qs in required_dataset_percentiles.items():
        dataset_calculators.append(
            DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs)
        )

    return sample_calculators, dataset_calculators

570 

571 

def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    # dataset measures must not require any sample calculators
    assert not sample_calculators

    # stream every sample through every dataset calculator
    for sample in dataset:
        for calc in calculators:
            calc.update(sample)

    # then collect the aggregated results
    ret: Dict[DatasetMeasure, MeasureValue] = {}
    for calc in calculators:
        ret.update(calc.finalize().items())
    return ret

589 

590 

def compute_sample_measures(
    measures: Iterable[SampleMeasure], sample: Sample
) -> Dict[SampleMeasure, MeasureValue]:
    """compute all sample `measures` for the given `sample`"""
    calculators, dataset_calculators = get_measure_calculators(measures)
    # sample measures must not require any dataset calculators
    assert not dataset_calculators

    ret: Dict[SampleMeasure, MeasureValue] = {}
    for calc in calculators:
        ret.update(calc.compute(sample).items())
    return ret

603 

604 

def compute_measures(
    measures: Iterable[Measure], dataset: Iterable[Sample]
) -> Dict[Measure, MeasureValue]:
    """compute all `measures` for the given `dataset`

    Sample measures are computed for the last sample in `dataset`.

    Raises:
        ValueError: if `dataset` yields no samples.
    """
    sample_calculators, dataset_calculators = get_measure_calculators(measures)

    last_sample: Optional[Sample] = None
    for current in dataset:
        last_sample = current
        for calc in dataset_calculators:
            calc.update(current)
    if last_sample is None:
        raise ValueError("empty dataset")

    ret: Dict[Measure, MeasureValue] = {}
    for calc in dataset_calculators:
        ret.update(calc.finalize().items())
    for calc in sample_calculators:
        ret.update(calc.compute(last_sample).items())
    return ret