Coverage for src/bioimageio/core/sample.py: 93%

159 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from dataclasses import dataclass 

5from math import ceil, floor 

6from types import MappingProxyType 

7from typing import ( 

8 Any, 

9 Callable, 

10 Dict, 

11 Generic, 

12 Iterable, 

13 Mapping, 

14 Optional, 

15 Tuple, 

16 TypeVar, 

17 Union, 

18) 

19 

20import numpy as np 

21import pydantic 

22import xarray as xr 

23from numpy.typing import NDArray 

24from typing_extensions import Self 

25 

26from ._common_annotations import PerMemberAnno 

27from .axis import AxisId, PerAxis 

28from .block import Block 

29from .block_meta import ( 

30 BlockMeta, 

31 LinearAxisTransform, 

32 split_multiple_shapes_into_blocks, 

33) 

34from .common import ( 

35 BlockIndex, 

36 Halo, 

37 HaloLike, 

38 MemberId, 

39 PadMode, 

40 PadWidthLike, 

41 PerMember, 

42 SampleId, 

43 SliceInfo, 

44 TotalNumberOfBlocks, 

45) 

46from .stat_measures import Stat 

47from .tensor import Tensor 

48 

49# TODO: allow for lazy samples to read/write to disk 

50 

51 

@dataclass
class Sample:
    """A dataset sample made up of one tensor per member.

    Several tensors are grouped into a single `Sample` through `members`.
    For example a `Sample` from a dataset with masked images may contain a
    `MemberId("raw")` and `MemberId("mask")` image.
    """

    members: Dict[MemberId, Tensor]
    """The sample's tensors, keyed by member id"""

    stat: Stat
    """Sample and dataset statistics"""

    id: SampleId
    """Identifies the `Sample` within the dataset -- typically a number or a string."""

    def __getitem__(
        self,
        key: PerMember[
            Union[
                SliceInfo,
                slice,
                int,
                PerAxis[Union[SliceInfo, slice, int]],
                Tensor,
                xr.DataArray,
            ]
        ],
    ) -> Self:
        """Index every member listed in `key` with its per-member index.

        Members of this sample that are not present in `key` are left out of the
        returned sample; `stat` and `id` are carried over unchanged.
        """
        selected = {}
        for member_id, tensor in self.members.items():
            if member_id in key:
                selected[member_id] = tensor[key[member_id]]

        return self.__class__(members=selected, stat=self.stat, id=self.id)

    def set_block(self, block: SampleBlock) -> None:
        """Write the inner data of `block` into this sample's tensors.

        Note:
            - Only members present in both the sample and the block are updated;
              block members unknown to the sample are ignored.
            - Sample members without a counterpart in the block stay untouched.

        Raises:
            ValueError: if the block and the sample have no member in common.
        """
        shared_members = [m for m in self.members if m in block.blocks]
        for member_id in shared_members:
            member_block = block.blocks[member_id]
            self.members[member_id][member_block.inner_slice] = (
                member_block.inner_data
            )

        if not shared_members:
            raise ValueError(
                f"block with members {list(block.blocks)} does not overlap with sample members {list(self.members)}"
            )

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `sizes` of every member tensor, keyed by member id."""
        sizes_per_member = {}
        for member_id, tensor in self.members.items():
            sizes_per_member[member_id] = tensor.sizes

        return sizes_per_member

    def as_arrays(self) -> Dict[MemberId, NDArray[Any]]:
        """Return the sample as a dictionary of arrays (`to_numpy()` per member)."""
        arrays = {}
        for member_id, tensor in self.members.items():
            arrays[member_id] = tensor.to_numpy()

        return arrays

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: Union[PadMode, PerMember[PadMode]],
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into blocks.

        Args:
            block_shapes: Block shape per member; every key must be a member of
                this sample.
            halo: Halo per member; every key must also be a key of `block_shapes`.
            pad_mode: One pad mode for all members or a pad mode per member.
            broadcast: Forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The number of blocks and an iterable of sample blocks that reference
            this sample as their origin.
        """
        unknown_members = [m for m in block_shapes if m not in self.members]
        assert not unknown_members, (
            f"`block_shapes` specified for unknown members: {unknown_members}"
        )
        halo_without_shape = [m for m in halo if m not in block_shapes]
        assert not halo_without_shape, (
            f"`halo` specified for members without `block_shape`: {halo_without_shape}"
        )

        n_blocks, block_metas = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        block_iter = sample_block_generator(
            block_metas, origin=self, pad_mode=pad_mode
        )
        return n_blocks, block_iter

    def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None):
        """Wrap the whole sample in one `SampleBlockWithOrigin`.

        Each member becomes a single `Block` (index 0 of 1) whose inner slice spans
        the full tagged shape of the member's tensor.

        Args:
            halo: Optional halo per member; members without an entry get an empty
                halo mapping.
        """
        halo_per_member = {} if halo is None else halo

        def whole_member_block(member_id, tensor):
            # the inner slice covers each axis from 0 to its full size
            full_extent = {}
            for axis_id, size in tensor.tagged_shape.items():
                full_extent[axis_id] = SliceInfo(0, size)

            return Block(
                sample_shape=self.shape[member_id],
                data=tensor,
                inner_slice=full_extent,
                halo=halo_per_member.get(member_id, {}),
                block_index=0,
                blocks_in_sample=1,
            )

        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks={
                member_id: whole_member_block(member_id, tensor)
                for member_id, tensor in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Create a `Sample` from an iterable of `SampleBlock`s.

        Exhausts `from_blocks_yield_intermediates` and returns the last sample it
        yields.

        Note:
            All sample blocks must have the same `sample_id`.

        Args:
            sample_blocks: The blocks to create the sample from.
            fill_value: The value to fill missing values with (default: `nan`).

        Raises:
            ValueError: if `from_blocks_yield_intermediates` yields nothing.
        """
        last = None
        intermediates = cls.from_blocks_yield_intermediates(
            sample_blocks, fill_value=fill_value
        )
        for last in intermediates:
            continue

        if last is not None:
            return last

        raise ValueError("no sample blocks provided")

    @classmethod
    def from_blocks_yield_intermediates(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ):
        """Create a `Sample` from an iterable of `SampleBlock`s, yielding the intermediate sample after each block.

        The same `Sample` object is updated in place and yielded repeatedly.

        Args:
            sample_blocks: The blocks to create the sample from.
            fill_value: The value to fill missing values with (default: `nan`).

        Raises:
            NotImplementedError: if a block's sample shape contains `-1`.
        """
        # NOTE(review): the listing this was reconstructed from lost its
        # indentation; the nesting of the two `yield` statements (one per sample
        # block, one after the loop) is an assumption -- confirm against the
        # repository.
        assembled = cls(members={}, stat={}, id=None)
        for sample_block in sample_blocks:
            if assembled.id is not None:
                assert assembled.id == sample_block.sample_id, (
                    "sample id changed between sample blocks"
                )
            else:
                assembled.id = sample_block.sample_id

            assembled.stat = sample_block.stat

            for member_id, member_block in sample_block.blocks.items():
                if member_id not in assembled.members:
                    if -1 in member_block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    # allocate the full-size member tensor on first encounter
                    full_shape = tuple(
                        member_block.sample_shape[a] for a in member_block.data.dims
                    )
                    assembled.members[member_id] = Tensor(
                        np.full(
                            full_shape,
                            fill_value,
                            dtype=member_block.data.dtype,
                        ),
                        dims=member_block.data.dims,
                    )

                assembled.members[member_id][member_block.inner_slice] = (
                    member_block.inner_data
                )
            yield assembled

        yield assembled

    def pad(
        self,
        pad_width: PerMember[PerAxis[Union[int, PadWidthLike]]],
        mode: Union[PerMember[PadMode], PadMode],
    ) -> Self:
        """Convenience method to pad sample members.

        Args:
            pad_width: Pad width per member; members without an entry are padded
                with an empty pad width mapping.
            mode: Either a single pad mode used for all members, or a mapping of
                pad modes per member (members missing from the mapping use
                "symmetric").
        """
        if isinstance(mode, collections.abc.Mapping):
            per_member_mode, fallback_mode = mode, "symmetric"
        else:
            per_member_mode, fallback_mode = {}, mode

        padded = {}
        for member_id, tensor in self.members.items():
            padded[member_id] = tensor.pad(
                pad_width=pad_width.get(member_id, {}),
                mode=per_member_mode.get(member_id, fallback_mode),
            )

        return self.__class__(members=padded, stat=self.stat, id=self.id)

264 

265 

# type of the per-member blocks held by a `SampleBlockBase` subclass
BlockT = TypeVar("BlockT", bound=BlockMeta)


@pydantic.dataclasses.dataclass(frozen=True)
class SampleBlockBase(Generic[BlockT]):
    """base class for `SampleBlockMeta` and `SampleBlock`"""

    sample_shape: PerMemberAnno[PerAxis[int]]
    """the sample shape this block represents a part of"""

    sample_id: SampleId
    """identifier for the sample within its dataset"""

    blocks: PerMemberAnno[BlockT]
    """Individual tensor blocks comprising this sample block"""

    block_index: BlockIndex
    """the n-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """Read-only mapping of each member block's `shape`."""
        per_member = {}
        for member_id, member_block in self.blocks.items():
            per_member[member_id] = member_block.shape

        return MappingProxyType(per_member)

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """Read-only mapping of each member block's `inner_shape`."""
        per_member = {}
        for member_id, member_block in self.blocks.items():
            per_member[member_id] = member_block.inner_shape

        return MappingProxyType(per_member)

295 

296 

@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that additionally names the sample member it refers to."""

    member: MemberId

300 

301 

@pydantic.dataclasses.dataclass(frozen=True)
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Meta data of a dataset sample block"""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the meta data of a sample block with the axes given by `new_axes`.

        Each new axis is described either by an `int` (a fixed size: the inner
        slice spans `0..size` and the halo is zero) or by a
        `LinearSampleAxisTransform` that refers to an axis of one of this block's
        members and is applied to that axis' sample size and inner slice bounds.

        Raises:
            ValueError: if scaling a halo gives different results when rounding
                down and when rounding up.
        """

        def resolve_size(trf):
            if isinstance(trf, int):
                return trf

            return trf.compute(self.sample_shape[trf.member][trf.axis])

        def resolve_halo(trf, rounding: Callable[[float], int]):
            if isinstance(trf, int):
                return Halo(0, 0)

            if trf.axis not in self.blocks[trf.member].halo:
                return Halo(0, 0)

            reference = self.blocks[trf.member].halo[trf.axis]
            return Halo(
                rounding(reference.left * trf.scale),
                rounding(reference.right * trf.scale),
            )

        def resolve_inner_slice(trf):
            if isinstance(trf, int):
                return SliceInfo(0, trf)

            reference = self.blocks[trf.member].inner_slice[trf.axis]
            return SliceInfo(
                trf.compute(reference.start),
                trf.compute(reference.stop),
            )

        # phase 1: sample shapes of all new members
        sample_shape = {}
        for m in new_axes:
            sample_shape[m] = {
                a: resolve_size(trf) for a, trf in new_axes[m].items()
            }

        # phase 2: halos; floor and ceil must agree for the scaling to be exact
        halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for m in new_axes:
            halo[m] = {a: resolve_halo(trf, floor) for a, trf in new_axes[m].items()}
            rounded_up = {
                a: resolve_halo(trf, ceil) for a, trf in new_axes[m].items()
            }
            if halo[m] != rounded_up:
                raise ValueError(
                    f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}"
                    + f" for {m}."
                )

        # phase 3: inner slices of all new members
        inner_slice = {}
        for m in new_axes:
            inner_slice[m] = {
                a: resolve_inner_slice(trf) for a, trf in new_axes[m].items()
            }

        transformed_blocks = {}
        for m in new_axes:
            transformed_blocks[m] = BlockMeta(
                sample_shape=sample_shape[m],
                inner_slice=inner_slice[m],
                halo=halo[m],
                block_index=self.block_index,
                blocks_in_sample=self.blocks_in_sample,
            )

        return self.__class__(
            blocks=transformed_blocks,
            sample_shape=sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Combine this meta data with `data` into a `SampleBlock`.

        Axis sizes recorded as `-1` in `sample_shape` are replaced by the
        corresponding size found in the tagged shape of the member's data.
        """
        resolved_shape = {}
        for m, member_shape in self.sample_shape.items():
            resolved_member = {}
            for a, s in member_shape.items():
                if s == -1:
                    resolved_member[a] = data[m].tagged_shape[a]
                else:
                    resolved_member[a] = s

            resolved_shape[m] = resolved_member

        data_blocks = {}
        for m, b in self.blocks.items():
            data_blocks[m] = Block.from_meta(b, data=data[m])

        return SampleBlock(
            sample_shape=resolved_shape,
            sample_id=self.sample_id,
            blocks=data_blocks,
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

394 

395 

@dataclass(frozen=True)
class SampleBlock(SampleBlockBase[Block]):
    """A block of a dataset sample"""

    blocks: Dict[MemberId, Block]
    """Individual tensor blocks comprising this sample block"""

    stat: Stat
    """computed statistics"""

    @property
    def members(self) -> PerMember[Tensor]:
        """the sample block's tensors"""
        tensors = {}
        for member_id, member_block in self.blocks.items():
            tensors[member_id] = member_block.data

        return tensors

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Build a `SampleBlockMeta` from this block and transform it with `new_axes`."""
        meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=dict(self.blocks),
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return meta.get_transformed(new_axes)

    @classmethod
    def from_meta(
        cls, meta: SampleBlockMeta, data: PerMember[Tensor], stat: Stat
    ) -> Self:
        """Create a sample block from `meta`, attaching `data[m]` to each member block."""
        data_blocks = {}
        for member_id, block_meta in meta.blocks.items():
            data_blocks[member_id] = Block.from_meta(block_meta, data=data[member_id])

        return cls(
            sample_shape=meta.sample_shape,
            sample_id=meta.sample_id,
            blocks=data_blocks,
            stat=stat,
            block_index=meta.block_index,
            blocks_in_sample=meta.blocks_in_sample,
        )

    def get_meta(self) -> SampleBlockMeta:
        """Return the meta data of this sample block (using each block's `get_meta()`)."""
        block_metas = {}
        for member_id, member_block in self.blocks.items():
            block_metas[member_id] = member_block.get_meta()

        return SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=block_metas,
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def as_sample(self) -> Sample:
        """Convert this sample block to a `Sample` with the shape of this block.

        Note:
            To rebuild a sample with the shape of the original, whole sample from
            one or more sample blocks, use `Sample.from_blocks()` instead.
        """
        block_members = dict(self.members)
        block_stat = dict(self.stat)
        return Sample(members=block_members, stat=block_stat, id=self.sample_id)

458 

459 

@dataclass(frozen=True)
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` that keeps a reference (`origin`) to the whole `Sample`"""

    origin: Sample
    """the sample this sample block was taken from"""

466 

467 

class _ConsolidatedMemberBlocks:
    """Extract the `block_index` and `blocks_in_sample` shared by all member blocks.

    Asserts that all values of `blocks` agree on each of the two attributes and
    stores the common value on the instance.
    """

    def __init__(self, blocks: PerMember[BlockMeta]):
        super().__init__()
        for attribute in ("block_index", "blocks_in_sample"):
            distinct_values = {getattr(b, attribute) for b in blocks.values()}
            # exactly one distinct value means all member blocks agree
            assert len(distinct_values) == 1
            setattr(self, attribute, distinct_values.pop())

477 

478 

def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
):
    """Yield one `SampleBlockMeta` per entry of `blocks`.

    Args:
        blocks: Per-member block meta data; the member blocks of one entry must
            agree on `block_index` and `blocks_in_sample`.
        sample_shape: Sample shape passed on to every yielded `SampleBlockMeta`.
        sample_id: Sample id passed on to every yielded `SampleBlockMeta`.
    """
    for per_member in blocks:
        shared = _ConsolidatedMemberBlocks(per_member)
        block_meta = SampleBlockMeta(
            blocks=dict(per_member),
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )
        yield block_meta

494 

495 

def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: Union[PadMode, PerMember[PadMode]],
) -> Iterable[SampleBlockWithOrigin]:
    """Yield one `SampleBlockWithOrigin` per entry of `blocks`, cut from `origin`.

    Args:
        blocks: Per-member block meta data; each entry is indexed with every
            member id of `origin`.
        origin: The sample the blocks are taken from; also supplies the shape,
            statistics and id of the yielded sample blocks.
        pad_mode: One pad mode for all members, or a mapping of pad modes per
            member (members missing from the mapping use "symmetric").
    """

    def member_pad_mode(member_id):
        if isinstance(pad_mode, collections.abc.Mapping):
            return pad_mode.get(member_id, "symmetric")

        return pad_mode

    for per_member in blocks:
        shared = _ConsolidatedMemberBlocks(per_member)
        data_blocks = {}
        for member_id in origin.members:
            data_blocks[member_id] = Block.from_sample_member(
                origin.members[member_id],
                block=per_member[member_id],
                pad_mode=member_pad_mode(member_id),
            )

        yield SampleBlockWithOrigin(
            blocks=data_blocks,
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )