Coverage for src / bioimageio / core / sample.py: 92%

129 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from dataclasses import dataclass 

5from math import ceil, floor 

6from typing import ( 

7 Any, 

8 Callable, 

9 Dict, 

10 Generic, 

11 Iterable, 

12 Mapping, 

13 Optional, 

14 Tuple, 

15 TypeVar, 

16 Union, 

17) 

18 

19import numpy as np 

20import xarray as xr 

21from numpy.typing import NDArray 

22from typing_extensions import Self 

23 

24from .axis import AxisId, PerAxis 

25from .block import Block 

26from .block_meta import ( 

27 BlockMeta, 

28 LinearAxisTransform, 

29 split_multiple_shapes_into_blocks, 

30) 

31from .common import ( 

32 BlockIndex, 

33 Halo, 

34 HaloLike, 

35 MemberId, 

36 PadMode, 

37 PadWidthLike, 

38 PerMember, 

39 SampleId, 

40 SliceInfo, 

41 TotalNumberOfBlocks, 

42) 

43from .stat_measures import Stat 

44from .tensor import Tensor 

45 

46# TODO: allow for lazy samples to read/write to disk 

47 

48 

@dataclass
class Sample:
    """A dataset sample.

    A `Sample` has `members`, which allow combining multiple tensors into a single
    sample.
    For example, a `Sample` from a dataset with masked images may contain a
    `MemberId("raw")` and a `MemberId("mask")` image.
    """

    members: Dict[MemberId, Tensor]
    """The sample's tensors"""

    stat: Stat
    """Sample and dataset statistics"""

    id: SampleId
    """Identifies the `Sample` within the dataset -- typically a number or a string."""

    def __getitem__(
        self,
        key: PerMember[
            Union[
                SliceInfo,
                slice,
                int,
                PerAxis[Union[SliceInfo, slice, int]],
                Tensor,
                xr.DataArray,
            ]
        ],
    ) -> Self:
        """Index individual members and return the result as a new sample.

        Each member `m` that appears in `key` is indexed with `key[m]`; members
        that do not appear in `key` are dropped from the result.
        `stat` and `id` are carried over unchanged.
        """
        return self.__class__(
            members={m: t[key[m]] for m, t in self.members.items() if m in key},
            stat=self.stat,
            id=self.id,
        )

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `sizes` of every member tensor."""
        return {m: t.sizes for m, t in self.members.items()}

    def as_arrays(self) -> Dict[MemberId, NDArray[Any]]:
        """Return sample as dictionary of arrays."""
        return {m: t.to_numpy() for m, t in self.members.items()}

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: Union[PadMode, PerMember[PadMode]],
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split the sample into blocks.

        Args:
            block_shapes: block shape per member and axis; every key must be a
                member of this sample.
            halo: halo per member and axis; every key must also appear in
                `block_shapes`.
            pad_mode: one pad mode for all members, or a mapping from member to
                pad mode (members missing from the mapping use "symmetric").
            broadcast: forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and a lazily evaluated iterable of
            `SampleBlockWithOrigin`s.
        """
        assert not (missing := [m for m in block_shapes if m not in self.members]), (
            f"`block_shapes` specified for unknown members: {missing}"
        )
        assert not (missing := [m for m in halo if m not in block_shapes]), (
            f"`halo` specified for members without `block_shape`: {missing}"
        )

        n_blocks, blocks = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode)

    def as_single_block(
        self, halo: Optional[PerMember[PerAxis[Halo]]] = None
    ) -> SampleBlockWithOrigin:
        """Wrap the whole sample as a single block (block 0 of 1).

        Args:
            halo: optional halo per member and axis; members without an entry
                get an empty halo mapping.
        """
        if halo is None:
            halo = {}

        # `shape` builds a new dict on every access; compute it once.
        sample_shape = self.shape
        return SampleBlockWithOrigin(
            sample_shape=sample_shape,
            sample_id=self.id,
            blocks={
                m: Block(
                    sample_shape=sample_shape[m],
                    data=data,
                    # the inner region of the single block is the whole tensor
                    inner_slice={
                        a: SliceInfo(0, s) for a, s in data.tagged_shape.items()
                    },
                    halo=halo.get(m, {}),
                    block_index=0,
                    blocks_in_sample=1,
                )
                for m, data in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Assemble a sample from its blocks.

        For each member, a tensor of the full sample shape filled with
        `fill_value` is allocated when the member is first encountered. Each
        block's `inner_data` is then written into its `inner_slice`. Regions
        not covered by any block keep `fill_value`. `stat` is taken from the
        last block. If `sample_blocks` is empty, the result has no members and
        its `id` is None.

        NOTE(review): a NaN `fill_value` cannot be represented by integer
        dtypes (numpy casts it unsafely), so uncovered regions of integer
        members hold arbitrary values.

        Raises:
            AssertionError: if the blocks belong to different samples.
            NotImplementedError: if a member's sample shape contains a
                data-dependent (-1) axis.
        """
        members: PerMember[Tensor] = {}
        stat: Stat = {}
        sample_id = None
        for sample_block in sample_blocks:
            assert sample_id is None or sample_id == sample_block.sample_id, (
                "cannot merge blocks of different samples: "
                + f"{sample_id!r} != {sample_block.sample_id!r}"
            )
            sample_id = sample_block.sample_id
            stat = sample_block.stat
            for m, block in sample_block.blocks.items():
                if m not in members:
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    members[m] = Tensor(
                        np.full(
                            tuple(block.sample_shape[a] for a in block.data.dims),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=block.data.dims,
                    )

                members[m][block.inner_slice] = block.inner_data

        return cls(members=members, stat=stat, id=sample_id)

    def pad(
        self,
        pad_width: PerMember[PerAxis[Union[int, PadWidthLike]]],
        mode: Union[PerMember[PadMode], PadMode],
    ) -> Self:
        """Convenience method to pad sample members.

        Args:
            pad_width: pad width per member and axis; members without an entry
                receive an empty pad-width mapping.
            mode: one pad mode for all members, or a mapping from member to
                pad mode; members missing from that mapping use "symmetric".
        """
        # Declare both names once so type checkers see a single consistent type.
        mode_per_member: Mapping[MemberId, PadMode]
        default_mode: PadMode
        if isinstance(mode, collections.abc.Mapping):
            mode_per_member = mode
            default_mode = "symmetric"
        else:
            mode_per_member = {}
            default_mode = mode

        return self.__class__(
            members={
                m: t.pad(
                    pad_width=pad_width.get(m, {}),
                    mode=mode_per_member.get(m, default_mode),
                )
                for m, t in self.members.items()
            },
            stat=self.stat,
            id=self.id,
        )

200 

201 

# Constrained type variable for the kind of per-member block a `SampleBlockBase`
# holds: `BlockMeta` (used by `SampleBlockMeta`) or `Block` (used by
# `SampleBlock`, whose `members` exposes each block's `data`).
BlockT = TypeVar("BlockT", Block, BlockMeta)

203 

204 

@dataclass
class SampleBlockBase(Generic[BlockT]):
    """Shared fields and shape helpers of `SampleBlockMeta` and `SampleBlock`."""

    sample_shape: PerMember[PerAxis[int]]
    """shape of the full sample that this block is a part of"""

    sample_id: SampleId
    """identifier of the sample within its dataset"""

    blocks: Dict[MemberId, BlockT]
    """the per-member tensor blocks that make up this sample block"""

    block_index: BlockIndex
    """position of this block within its sample (the n-th block)"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks the sample is split into"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `shape` of each member's block."""
        shapes = {}
        for member_id, member_block in self.blocks.items():
            shapes[member_id] = member_block.shape
        return shapes

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """The `inner_shape` of each member's block."""
        inner_shapes = {}
        for member_id, member_block in self.blocks.items():
            inner_shapes[member_id] = member_block.inner_shape
        return inner_shapes

231 

232 

@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that also names the sample member it refers to.

    In `SampleBlockMeta.get_transformed`, the transform's `axis` is looked up on
    the member `member`. That axis's sample size and inner slice bounds are
    mapped with `compute`, and its halo widths are multiplied by `scale`.
    """

    member: MemberId  # member whose axis `axis` this transform is based on

236 

237 

@dataclass
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Meta data of a dataset sample block (block geometry, no tensor data)."""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive block metadata for a new set of members and axes.

        Every axis in `new_axes[m]` is either
        - an `int`: a fixed size with a full inner slice `0..size` and no halo, or
        - a `LinearSampleAxisTransform`: linked to axis `trf.axis` of member
          `trf.member` of this block. Its sample size and inner slice bounds are
          mapped with `trf.compute`, and its halo is scaled by `trf.scale`.

        Only the members listed in `new_axes` appear in the result.

        Raises:
            ValueError: if a scaled halo is not a whole number, i.e. rounding
                it down and up gives different results.
        """
        # phase 1: full sample shape of every new member
        sample_shape = {}
        for member_id, member_axes in new_axes.items():
            member_shape = {}
            for axis_id, trf in member_axes.items():
                if isinstance(trf, int):
                    member_shape[axis_id] = trf
                else:
                    member_shape[axis_id] = trf.compute(
                        self.sample_shape[trf.member][trf.axis]
                    )
            sample_shape[member_id] = member_shape

        def scaled_halo(
            member_id: MemberId, rounding: Callable[[float], int]
        ) -> Dict[AxisId, Halo]:
            """Halo per new axis of `member_id`, scaled from its reference axis."""
            member_halo: Dict[AxisId, Halo] = {}
            for axis_id, trf in new_axes[member_id].items():
                if isinstance(trf, int):
                    member_halo[axis_id] = Halo(0, 0)
                    continue

                reference_halo = self.blocks[trf.member].halo
                if trf.axis not in reference_halo:
                    member_halo[axis_id] = Halo(0, 0)
                    continue

                ref = reference_halo[trf.axis]
                member_halo[axis_id] = Halo(
                    rounding(ref.left * trf.scale),
                    rounding(ref.right * trf.scale),
                )
            return member_halo

        # phase 2: halos; floor and ceil must agree, otherwise the scaled halo
        # is fractional and cannot be expressed in whole pixels
        halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for member_id in new_axes:
            lower = scaled_halo(member_id, floor)
            halo[member_id] = lower
            if lower != scaled_halo(member_id, ceil):
                raise ValueError(
                    f"failed to unambiguously scale halo {lower} with "
                    + f"{new_axes[member_id]} for {member_id}."
                )

        # phase 3: inner slices mapped from the reference member's inner slice
        inner_slice = {}
        for member_id, member_axes in new_axes.items():
            member_slice = {}
            for axis_id, trf in member_axes.items():
                if isinstance(trf, int):
                    member_slice[axis_id] = SliceInfo(0, trf)
                else:
                    ref_slice = self.blocks[trf.member].inner_slice[trf.axis]
                    member_slice[axis_id] = SliceInfo(
                        trf.compute(ref_slice.start),
                        trf.compute(ref_slice.stop),
                    )
            inner_slice[member_id] = member_slice

        new_blocks = {
            member_id: BlockMeta(
                sample_shape=sample_shape[member_id],
                inner_slice=inner_slice[member_id],
                halo=halo[member_id],
                block_index=self.block_index,
                blocks_in_sample=self.blocks_in_sample,
            )
            for member_id in new_axes
        }
        return self.__class__(
            blocks=new_blocks,
            sample_shape=sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Attach tensor data to this metadata, producing a `SampleBlock`.

        Axes whose sample size is -1 (data dependent) take their size from the
        corresponding axis of `data[m]`.
        """
        resolved_shape = {}
        for member_id, member_shape in self.sample_shape.items():
            resolved_shape[member_id] = {
                axis_id: (
                    size if size != -1 else data[member_id].tagged_shape[axis_id]
                )
                for axis_id, size in member_shape.items()
            }

        data_blocks = {
            member_id: Block.from_meta(meta, data=data[member_id])
            for member_id, meta in self.blocks.items()
        }
        return SampleBlock(
            sample_shape=resolved_shape,
            sample_id=self.sample_id,
            blocks=data_blocks,
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

330 

331 

@dataclass
class SampleBlock(SampleBlockBase[Block]):
    """A block of a dataset sample, holding the tensor data of each member."""

    stat: Stat
    """statistics associated with this sample block"""

    @property
    def members(self) -> PerMember[Tensor]:
        """the sample block's tensors (the `data` of each member block)"""
        return dict((member_id, blk.data) for member_id, blk in self.blocks.items())

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Apply `SampleBlockMeta.get_transformed` to this block's geometry.

        The result holds block metadata only; no tensor data is carried over.
        """
        meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=dict(self.blocks),
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return meta.get_transformed(new_axes)

354 

355 

@dataclass
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` with a reference (`origin`) to the whole `Sample`.

    Instances are created by `Sample.as_single_block` and by
    `sample_block_generator` (used by `Sample.split_into_blocks`).
    """

    origin: Sample
    """the sample this sample block was taken from"""

362 

363 

364class _ConsolidatedMemberBlocks: 

365 def __init__(self, blocks: PerMember[BlockMeta]): 

366 super().__init__() 

367 block_indices = {b.block_index for b in blocks.values()} 

368 assert len(block_indices) == 1 

369 self.block_index = block_indices.pop() 

370 blocks_in_samples = {b.blocks_in_sample for b in blocks.values()} 

371 assert len(blocks_in_samples) == 1 

372 self.blocks_in_sample = blocks_in_samples.pop() 

373 

374 

def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
):
    """Lazily yield one `SampleBlockMeta` per entry of `blocks`.

    Each entry maps members to their `BlockMeta`. The shared `block_index` and
    `blocks_in_sample` are taken from (and checked across) these member blocks.
    `sample_shape` and `sample_id` are attached to every yielded item.
    """
    for per_member in blocks:
        consolidated = _ConsolidatedMemberBlocks(per_member)
        meta_blocks = dict(per_member)
        yield SampleBlockMeta(
            sample_id=sample_id,
            sample_shape=sample_shape,
            blocks=meta_blocks,
            block_index=consolidated.block_index,
            blocks_in_sample=consolidated.blocks_in_sample,
        )

390 

391 

def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: Union[PadMode, PerMember[PadMode]],
) -> Iterable[SampleBlockWithOrigin]:
    """Lazily cut `origin` into `SampleBlockWithOrigin`s, one per entry of `blocks`.

    Args:
        blocks: per-member block metadata for each sample block. For every
            member of `origin`, `member_blocks[m]` is looked up, so each entry
            needs a block for every member.
        origin: the sample to take the data from. Its `stat`, `id` and `shape`
            are attached to every yielded block.
        pad_mode: one pad mode for all members, or a mapping from member to
            pad mode; members missing from the mapping use "symmetric".

    Yields:
        One `SampleBlockWithOrigin` per entry of `blocks`, in order.
    """
    # The pad mode of a member is identical for every block: resolve it once
    # (still lazily, on the first iteration) instead of per member per block.
    if isinstance(pad_mode, collections.abc.Mapping):
        member_pad_mode = {m: pad_mode.get(m, "symmetric") for m in origin.members}
    else:
        member_pad_mode = {m: pad_mode for m in origin.members}

    for member_blocks in blocks:
        cons = _ConsolidatedMemberBlocks(member_blocks)
        yield SampleBlockWithOrigin(
            blocks={
                m: Block.from_sample_member(
                    origin.members[m],
                    block=member_blocks[m],
                    pad_mode=member_pad_mode[m],
                )
                for m in origin.members
            },
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=cons.block_index,
            blocks_in_sample=cons.blocks_in_sample,
        )