Coverage for bioimageio/core/sample.py: 93%

118 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from math import ceil, floor 

5from typing import ( 

6 Any, 

7 Callable, 

8 Dict, 

9 Generic, 

10 Iterable, 

11 Optional, 

12 Tuple, 

13 TypeVar, 

14 Union, 

15) 

16 

17import numpy as np 

18from numpy.typing import NDArray 

19from typing_extensions import Self 

20 

21from .axis import AxisId, PerAxis 

22from .block import Block 

23from .block_meta import ( 

24 BlockMeta, 

25 LinearAxisTransform, 

26 split_multiple_shapes_into_blocks, 

27) 

28from .common import ( 

29 BlockIndex, 

30 Halo, 

31 HaloLike, 

32 MemberId, 

33 PadMode, 

34 PerMember, 

35 SampleId, 

36 SliceInfo, 

37 TotalNumberOfBlocks, 

38) 

39from .stat_measures import Stat 

40from .tensor import Tensor 

41 

42# TODO: allow for lazy samples to read/write to disk 

43 

44 

@dataclass
class Sample:
    """A dataset sample made up of one or more named tensors.

    The tensors are stored in `members`, keyed by `MemberId`. For example, a
    sample from a dataset of masked images could hold both a `MemberId("raw")`
    and a `MemberId("mask")` tensor.
    """

    members: Dict[MemberId, Tensor]
    """tensors of this sample, keyed by member id"""

    stat: Stat
    """statistics of this sample and of its dataset"""

    id: SampleId
    """identifier of this `Sample` within its dataset, typically a number or a string"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """Axis sizes (`Tensor.sizes`) of every member tensor, keyed by member id."""
        sizes: Dict[MemberId, PerAxis[int]] = {}
        for member_id, tensor in self.members.items():
            sizes[member_id] = tensor.sizes
        return sizes

    def as_arrays(self) -> Dict[str, NDArray[Any]]:
        """Return the member tensors as numpy arrays, keyed by `str(member_id)`."""
        arrays: Dict[str, NDArray[Any]] = {}
        for member_id, tensor in self.members.items():
            arrays[str(member_id)] = tensor.data.to_numpy()
        return arrays

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: PadMode,
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into blocks.

        Args:
            block_shapes: block shape per member; every key must be a member of
                this sample.
            halo: halo per member; every key must also appear in `block_shapes`.
            pad_mode: padding mode passed on to `Block.from_sample_member`.
            broadcast: passed on to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and a lazy iterable that creates one
            `SampleBlockWithOrigin` per block while being iterated.

        Raises:
            AssertionError: if `block_shapes` or `halo` name members they may not.
        """
        assert not (
            unknown := [m for m in block_shapes if m not in self.members]
        ), f"`block_shapes` specified for unknown members: {unknown}"
        assert not (
            unblocked := [m for m in halo if m not in block_shapes]
        ), f"`halo` specified for members without `block_shape`: {unblocked}"

        total, member_block_metas = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        # `sample_block_generator` is a generator: blocks are only cut on iteration
        sample_blocks = sample_block_generator(
            member_block_metas, origin=self, pad_mode=pad_mode
        )
        return total, sample_blocks

    def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None):
        """Wrap the whole sample in one `SampleBlockWithOrigin` (block 0 of 1).

        Each member's inner slice spans its complete tensor.

        Args:
            halo: optional halo per member; members without an entry get an
                empty halo mapping.
        """
        member_halo = {} if halo is None else halo

        single_blocks: Dict[MemberId, Block] = {}
        for member_id, tensor in self.members.items():
            single_blocks[member_id] = Block(
                sample_shape=self.shape[member_id],
                data=tensor,
                inner_slice={
                    axis: SliceInfo(0, size)
                    for axis, size in tensor.tagged_shape.items()
                },
                halo=member_halo.get(member_id, {}),
                block_index=0,
                blocks_in_sample=1,
            )

        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks=single_blocks,
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Reassemble a `Sample` from its sample blocks.

        For every member, a tensor of the full sample shape is allocated
        (dtype and dims of the first block seen for that member) and filled
        with `fill_value`. Each block's `inner_data` is then written into its
        `inner_slice`, so regions not covered by any block keep `fill_value`.

        The `stat` of the last block is used. All blocks must share one
        `sample_id` (asserted). With no blocks at all, the result has no
        members, an empty `stat` and `id` None.

        NOTE(review): the default NaN `fill_value` has no exact representation
        in integer dtypes -- confirm the intended behavior for integer data.

        Raises:
            NotImplementedError: if a member's sample shape contains -1
                (a data dependent axis).
        """
        merged: PerMember[Tensor] = {}
        merged_stat: Stat = {}
        merged_id = None
        for sample_block in sample_blocks:
            assert merged_id is None or merged_id == sample_block.sample_id
            merged_id = sample_block.sample_id
            merged_stat = sample_block.stat
            for member_id, block in sample_block.blocks.items():
                if member_id not in merged:
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    dims = block.data.dims
                    full_shape = tuple(block.sample_shape[axis] for axis in dims)
                    merged[member_id] = Tensor(
                        np.full(full_shape, fill_value, dtype=block.data.dtype),
                        dims=dims,
                    )

                merged[member_id][block.inner_slice] = block.inner_data

        return cls(members=merged, stat=merged_stat, id=merged_id)

152 

153 

# Per-member block type held by a `SampleBlockBase`, constrained to exactly
# `Block` or `BlockMeta`: `SampleBlockMeta` uses `BlockMeta`, `SampleBlock`
# uses `Block`.
BlockT = TypeVar("BlockT", Block, BlockMeta)

155 

156 

@dataclass
class SampleBlockBase(Generic[BlockT]):
    """Shared fields and shape helpers of `SampleBlockMeta` and `SampleBlock`.

    `BlockT` is the per-member block type: `BlockMeta` for metadata-only sample
    blocks, `Block` for sample blocks that carry tensor data.
    """

    sample_shape: PerMember[PerAxis[int]]
    """shape of the whole sample (per member and axis) this block is a part of"""

    sample_id: SampleId
    """identifies the sample within its dataset"""

    blocks: Dict[MemberId, BlockT]
    """one block per member; together they make up this sample block"""

    block_index: BlockIndex
    """position of this block within its sample (the n-th block)"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks the sample is split into"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """`shape` of every member block, keyed by member id."""
        per_member: Dict[MemberId, PerAxis[int]] = {}
        for member_id, member_block in self.blocks.items():
            per_member[member_id] = member_block.shape
        return per_member

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """`inner_shape` of every member block, keyed by member id."""
        per_member: Dict[MemberId, PerAxis[int]] = {}
        for member_id, member_block in self.blocks.items():
            per_member[member_id] = member_block.inner_shape
        return per_member

183 

184 

@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that also names the sample member it refers to.

    `SampleBlockMeta.get_transformed` looks up `axis` of this `member` (in the
    sample shape, block halo and block inner slice) and then applies the
    transform (`compute` / `scale`) to the values found there.
    """

    # id of the sample member whose axis (`axis`, from the base class) is transformed
    member: MemberId

188 

189 

@dataclass
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Metadata of one block of a dataset sample (shapes, halo, inner slices; no data)."""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the block metadata for a new set of members and axes.

        Args:
            new_axes: for every member of the result, a mapping from axis id to
                either a fixed `int` size or a `LinearSampleAxisTransform`
                referencing an axis of one of this block's members.

        For an `int` size the axis gets that sample size, inner slice
        `SliceInfo(0, size)` and `Halo(0, 0)`. For a transform, the sample
        size and the inner slice bounds are mapped with `trf.compute`, and the
        halo of the referenced axis (if it has one) is multiplied by
        `trf.scale`.

        Returns:
            A new instance of the same class containing only the members in
            `new_axes`, with unchanged `sample_id`, `block_index` and
            `blocks_in_sample`.

        Raises:
            ValueError: if a scaled halo is not a whole number, i.e. rounding
                it down and up gives different results.
        """

        def new_size(trf: Union[LinearSampleAxisTransform, int]) -> int:
            if isinstance(trf, int):
                return trf
            return trf.compute(self.sample_shape[trf.member][trf.axis])

        def scaled_halo(
            trf: Union[LinearSampleAxisTransform, int],
            rounding: Callable[[float], int],
        ) -> Halo:
            if isinstance(trf, int):
                return Halo(0, 0)
            source_halos = self.blocks[trf.member].halo
            if trf.axis not in source_halos:
                return Halo(0, 0)
            source = source_halos[trf.axis]
            return Halo(
                rounding(source.left * trf.scale),
                rounding(source.right * trf.scale),
            )

        def new_inner_slice(trf: Union[LinearSampleAxisTransform, int]) -> SliceInfo:
            if isinstance(trf, int):
                return SliceInfo(0, trf)
            source = self.blocks[trf.member].inner_slice[trf.axis]
            return SliceInfo(trf.compute(source.start), trf.compute(source.stop))

        out_sample_shape: Dict[MemberId, Dict[AxisId, int]] = {}
        for m in new_axes:
            out_sample_shape[m] = {
                a: new_size(trf) for a, trf in new_axes[m].items()
            }

        out_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for m in new_axes:
            out_halo[m] = {a: scaled_halo(trf, floor) for a, trf in new_axes[m].items()}
            # a halo only scales unambiguously if rounding down and up agree
            rounded_up = {a: scaled_halo(trf, ceil) for a, trf in new_axes[m].items()}
            if out_halo[m] != rounded_up:
                raise ValueError(
                    f"failed to unambiguously scale halo {out_halo[m]} with {new_axes[m]}"
                    + f" for {m}."
                )

        out_inner_slice: Dict[MemberId, Dict[AxisId, SliceInfo]] = {}
        for m in new_axes:
            out_inner_slice[m] = {
                a: new_inner_slice(trf) for a, trf in new_axes[m].items()
            }

        out_blocks: Dict[MemberId, BlockMeta] = {}
        for m in new_axes:
            out_blocks[m] = BlockMeta(
                sample_shape=out_sample_shape[m],
                inner_slice=out_inner_slice[m],
                halo=out_halo[m],
                block_index=self.block_index,
                blocks_in_sample=self.blocks_in_sample,
            )

        return self.__class__(
            blocks=out_blocks,
            sample_shape=out_sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Combine this metadata with tensor `data` into a `SampleBlock`.

        Args:
            data: tensor for every member in `self.blocks` (and for every
                member whose `sample_shape` contains -1).
            stat: statistics attached to the resulting block.

        Every -1 size in `sample_shape` (used for data dependent axes) is
        replaced by the size of that axis in `data[m]`.
        """
        resolved_shape: Dict[MemberId, Dict[AxisId, int]] = {}
        for m, member_shape in self.sample_shape.items():
            resolved_shape[m] = {
                a: (s if s != -1 else data[m].tagged_shape[a])
                for a, s in member_shape.items()
            }

        data_blocks: Dict[MemberId, Block] = {}
        for m, meta in self.blocks.items():
            data_blocks[m] = Block.from_meta(meta, data=data[m])

        return SampleBlock(
            sample_shape=resolved_shape,
            sample_id=self.sample_id,
            blocks=data_blocks,
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

282 

283 

@dataclass
class SampleBlock(SampleBlockBase[Block]):
    """One block of a dataset sample, including its tensor data."""

    stat: Stat
    """statistics attached to this block (e.g. those of the sample it was cut from)"""

    @property
    def members(self) -> PerMember[Tensor]:
        """The tensor data of every member block, keyed by member id."""
        tensors: Dict[MemberId, Tensor] = {}
        for member_id, member_block in self.blocks.items():
            tensors[member_id] = member_block.data
        return tensors

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Return this block's metadata transformed to `new_axes`.

        Delegates to `SampleBlockMeta.get_transformed`; the result holds
        metadata only (no tensor data).
        """
        own_meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks={**self.blocks},
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return own_meta.get_transformed(new_axes)

306 

307 

@dataclass
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` with a reference (`origin`) to the whole `Sample`.

    Instances are created by `sample_block_generator` (used by
    `Sample.split_into_blocks`) and by `Sample.as_single_block`.
    """

    origin: Sample
    """the sample this sample block was taken from"""

314 

315 

class _ConsolidatedMemberBlocks:
    """Block position shared by all member blocks of one sample block.

    The member blocks that together form a sample block must agree on
    `block_index` and `blocks_in_sample`. This checks that (via `assert`) and
    exposes the common values as `self.block_index` and
    `self.blocks_in_sample`.
    """

    def __init__(self, blocks: PerMember[BlockMeta]):
        # Messages name the offending values/members; an empty `blocks`
        # mapping also fails here (no common value exists).
        block_indices = {b.block_index for b in blocks.values()}
        assert len(block_indices) == 1, (
            "member blocks must share exactly one `block_index`,"
            + f" got {block_indices} for members {list(blocks)}"
        )
        self.block_index = block_indices.pop()

        blocks_in_samples = {b.blocks_in_sample for b in blocks.values()}
        assert len(blocks_in_samples) == 1, (
            "member blocks must share exactly one `blocks_in_sample`,"
            + f" got {blocks_in_samples} for members {list(blocks)}"
        )
        self.blocks_in_sample = blocks_in_samples.pop()

325 

326 

def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
) -> Iterable[SampleBlockMeta]:
    """Lazily wrap per-member block metadata into `SampleBlockMeta` objects.

    Args:
        blocks: one mapping of member id -> `BlockMeta` per sample block.
        sample_shape: shape of the whole sample, attached to every yielded block.
        sample_id: id of the sample, attached to every yielded block.

    Yields:
        One `SampleBlockMeta` per entry of `blocks`. Its `block_index` and
        `blocks_in_sample` are taken from the member blocks, which must agree
        with each other (checked by `_ConsolidatedMemberBlocks`).
    """
    for member_blocks in blocks:
        cons = _ConsolidatedMemberBlocks(member_blocks)
        yield SampleBlockMeta(
            # shallow copy, so the yielded object does not share the input mapping
            blocks=dict(member_blocks),
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=cons.block_index,
            blocks_in_sample=cons.blocks_in_sample,
        )

342 

343 

def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: PadMode,
) -> Iterable[SampleBlockWithOrigin]:
    """Lazily cut `origin` into `SampleBlockWithOrigin` objects.

    Args:
        blocks: one mapping of member id -> `BlockMeta` per sample block; every
            member of `origin` must have an entry in each mapping.
        origin: the sample to cut; it is referenced by every yielded block.
        pad_mode: padding mode passed on to `Block.from_sample_member`.

    Yields:
        One `SampleBlockWithOrigin` per entry of `blocks`, carrying the
        `stat`, `id` and `shape` of `origin`.
    """
    for metas in blocks:
        shared = _ConsolidatedMemberBlocks(metas)

        cut_blocks = {}
        for member_id, member_tensor in origin.members.items():
            cut_blocks[member_id] = Block.from_sample_member(
                member_tensor, block=metas[member_id], pad_mode=pad_mode
            )

        yield SampleBlockWithOrigin(
            blocks=cut_blocks,
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )