Coverage for bioimageio/core/sample.py: 94%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from math import ceil, floor 

5from typing import ( 

6 Callable, 

7 Dict, 

8 Generic, 

9 Iterable, 

10 Optional, 

11 Tuple, 

12 TypeVar, 

13 Union, 

14) 

15 

16import numpy as np 

17from typing_extensions import Self 

18 

19from .axis import AxisId, PerAxis 

20from .block import Block 

21from .block_meta import ( 

22 BlockMeta, 

23 LinearAxisTransform, 

24 split_multiple_shapes_into_blocks, 

25) 

26from .common import ( 

27 BlockIndex, 

28 Halo, 

29 HaloLike, 

30 MemberId, 

31 PadMode, 

32 PerMember, 

33 SampleId, 

34 SliceInfo, 

35 TotalNumberOfBlocks, 

36) 

37from .stat_measures import Stat 

38from .tensor import Tensor 

39 

40# TODO: allow for lazy samples to read/write to disk 

41 

42 

@dataclass
class Sample:
    """A dataset sample"""

    members: Dict[MemberId, Tensor]
    """the sample's tensors"""

    stat: Stat
    """sample and dataset statistics"""

    id: SampleId
    """identifier within the sample's dataset"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """Sizes of every member tensor, keyed by member id (then by axis id)."""
        return {tid: t.sizes for tid, t in self.members.items()}

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: PadMode,
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into sample blocks.

        Args:
            block_shapes: block shape per member; every key must be a member
                of this sample.
            halo: halo per member; only allowed for members that also have an
                entry in `block_shapes`.
            pad_mode: padding mode forwarded to `sample_block_generator`.
            broadcast: forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and a generator over the sample blocks,
            each of which references this sample as its `origin`.
        """
        assert not (
            missing := [m for m in block_shapes if m not in self.members]
        ), f"`block_shapes` specified for unknown members: {missing}"
        assert not (
            missing := [m for m in halo if m not in block_shapes]
        ), f"`halo` specified for members without `block_shape`: {missing}"

        n_blocks, blocks = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode)

    def as_single_block(
        self, halo: Optional[PerMember[PerAxis[Halo]]] = None
    ) -> SampleBlockWithOrigin:
        """Wrap the whole sample as one sample block (block 0 of 1).

        The inner slice of every member spans its complete tensor.

        Args:
            halo: optional halo per member and axis; members without an entry
                get an empty halo mapping.
        """
        if halo is None:
            halo = {}
        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks={
                m: Block(
                    sample_shape=self.shape[m],
                    data=data,
                    inner_slice={
                        a: SliceInfo(0, s) for a, s in data.tagged_shape.items()
                    },
                    halo=halo.get(m, {}),
                    block_index=0,
                    blocks_in_sample=1,
                )
                for m, data in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Reassemble a sample from its sample blocks.

        Each member tensor is allocated at its full sample size, pre-filled
        with `fill_value`. Every block's inner data is then written into its
        `inner_slice`, so regions that no block covers keep `fill_value`. The
        statistics of the last block become the sample's statistics.

        Note:
            An empty `sample_blocks` iterable gives a sample without members,
            with empty statistics and with `id=None`.

        Raises:
            AssertionError: if the blocks belong to different samples.
            NotImplementedError: if a member's sample shape contains -1
                (data dependent axis).
        """
        members: PerMember[Tensor] = {}
        stat: Stat = {}
        sample_id = None
        for sample_block in sample_blocks:
            assert sample_id is None or sample_id == sample_block.sample_id, (
                "cannot merge blocks of different samples:"
                + f" {sample_id!r} != {sample_block.sample_id!r}"
            )
            sample_id = sample_block.sample_id
            stat = sample_block.stat  # the last block's statistics win
            for m, block in sample_block.blocks.items():
                if m not in members:
                    # first block seen for this member: allocate the full tensor
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    members[m] = Tensor(
                        np.full(
                            tuple(block.sample_shape[a] for a in block.data.dims),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=block.data.dims,
                    )

                members[m][block.inner_slice] = block.inner_data

        return cls(members=members, stat=stat, id=sample_id)

140 

141 

# Per-member block type of a sample block: `Block` (used by `SampleBlock`)
# or `BlockMeta` (used by `SampleBlockMeta`).
BlockT = TypeVar("BlockT", Block, BlockMeta)

143 

144 

@dataclass
class SampleBlockBase(Generic[BlockT]):
    """Fields and properties shared by `SampleBlockMeta` and `SampleBlock`.

    Generic over the per-member block type `BlockT`.
    """

    sample_shape: PerMember[PerAxis[int]]
    """shape (per member and axis) of the complete sample this block is part of"""

    sample_id: SampleId
    """id of the sample within its dataset"""

    blocks: Dict[MemberId, BlockT]
    """the per-member blocks that together form this sample block"""

    block_index: BlockIndex
    """0-based position of this block among all blocks of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """how many blocks the sample is split into in total"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """`shape` of each member block, keyed by member id"""
        member_blocks = self.blocks
        return {member_id: member_blocks[member_id].shape for member_id in member_blocks}

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """`inner_shape` of each member block, keyed by member id"""
        member_blocks = self.blocks
        return {
            member_id: member_blocks[member_id].inner_shape
            for member_id in member_blocks
        }

171 

172 

@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that also names the sample member it reads from.

    `SampleBlockMeta.get_transformed` looks up the source axis as
    `(member, axis)`. `member` selects the sample member, and the inherited
    `axis` selects the axis of that member.
    """

    member: MemberId  # id of the member that owns the source axis

176 

177 

@dataclass
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Meta data of a dataset sample block"""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the meta data of a new sample block from this one.

        `new_axes` describes every axis of every member of the new block:

        - an `int` fixes the axis size; its inner slice spans `0..size` and it
          gets no halo;
        - a `LinearSampleAxisTransform` maps axis `trf.axis` of member
          `trf.member` of this block onto the new axis. Sample size and inner
          slice bounds go through `trf.compute`, and halo widths are multiplied
          by `trf.scale`.

        Raises:
            ValueError: if a scaled halo is not integral, i.e. flooring and
                ceiling it give different results.
        """
        # phase 1: sample shape of the new block
        new_sample_shape: Dict[MemberId, Dict[AxisId, int]] = {}
        for member_id, axes in new_axes.items():
            new_sample_shape[member_id] = {
                axis_id: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.member][trf.axis])
                )
                for axis_id, trf in axes.items()
            }

        # phase 2: halo, scaled once rounding down and once rounding up
        def scaled_halo(
            member_id: MemberId, to_int: Callable[[float], int]
        ) -> Dict[AxisId, Halo]:
            member_halo: Dict[AxisId, Halo] = {}
            for axis_id, trf in new_axes[member_id].items():
                if isinstance(trf, int):
                    member_halo[axis_id] = Halo(0, 0)
                    continue

                source_halo = self.blocks[trf.member].halo
                if trf.axis not in source_halo:
                    member_halo[axis_id] = Halo(0, 0)
                else:
                    h = source_halo[trf.axis]
                    member_halo[axis_id] = Halo(
                        to_int(h.left * trf.scale), to_int(h.right * trf.scale)
                    )
            return member_halo

        halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for member_id in new_axes:
            halo[member_id] = scaled_halo(member_id, floor)
            if halo[member_id] != scaled_halo(member_id, ceil):
                raise ValueError(
                    f"failed to unambiguously scale halo {halo[member_id]} with {new_axes[member_id]}"
                    + f" for {member_id}."
                )

        # phase 3: inner slices, mapped through the same transforms
        def mapped_slice(trf: Union[LinearSampleAxisTransform, int]) -> SliceInfo:
            if isinstance(trf, int):
                return SliceInfo(0, trf)

            source_slice = self.blocks[trf.member].inner_slice[trf.axis]
            return SliceInfo(
                trf.compute(source_slice.start), trf.compute(source_slice.stop)
            )

        inner_slice = {
            member_id: {axis_id: mapped_slice(trf) for axis_id, trf in axes.items()}
            for member_id, axes in new_axes.items()
        }

        new_blocks = {
            member_id: BlockMeta(
                sample_shape=new_sample_shape[member_id],
                inner_slice=inner_slice[member_id],
                halo=halo[member_id],
                block_index=self.block_index,
                blocks_in_sample=self.blocks_in_sample,
            )
            for member_id in new_axes
        }
        return self.__class__(
            blocks=new_blocks,
            sample_shape=new_sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Attach tensor data to this meta data, producing a `SampleBlock`.

        Sample shape entries equal to -1 are replaced by the size of the same
        axis of the matching tensor in `data`.
        """
        resolved_sample_shape = {
            member_id: {
                axis_id: (
                    size if size != -1 else data[member_id].tagged_shape[axis_id]
                )
                for axis_id, size in member_shape.items()
            }
            for member_id, member_shape in self.sample_shape.items()
        }
        member_blocks = {
            member_id: Block.from_meta(meta, data=data[member_id])
            for member_id, meta in self.blocks.items()
        }
        return SampleBlock(
            sample_shape=resolved_sample_shape,
            sample_id=self.sample_id,
            blocks=member_blocks,
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

270 

271 

@dataclass
class SampleBlock(SampleBlockBase[Block]):
    """A block of a dataset sample, including its tensor data"""

    stat: Stat
    """statistics that accompany this block"""

    @property
    def members(self) -> PerMember[Tensor]:
        """the sample block's tensors (the `data` of each member block)"""
        member_blocks = self.blocks
        return {member_id: member_blocks[member_id].data for member_id in member_blocks}

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Transform this block's meta data as `SampleBlockMeta.get_transformed` does.

        Returns a `SampleBlockMeta`, i.e. the result carries no tensor data.
        """
        meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks={**self.blocks},
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return meta.get_transformed(new_axes)

294 

295 

@dataclass
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` with a reference (`origin`) to the whole `Sample`

    Instances are created by `Sample.as_single_block` and by
    `sample_block_generator`.
    """

    # as a dataclass subclass field, `origin` follows all fields of `SampleBlock`
    origin: Sample
    """the sample this sample block was taken from"""

302 

303 

class _ConsolidatedMemberBlocks:
    """The block index and block count shared by a set of member blocks.

    All member blocks that form one sample block must agree on `block_index`
    and `blocks_in_sample`. This helper checks that and exposes the common
    values.
    """

    def __init__(self, blocks: PerMember[BlockMeta]) -> None:
        """
        Args:
            blocks: the member blocks of one sample block.

        Raises:
            AssertionError: if `blocks` is empty or its blocks disagree on
                `block_index` or `blocks_in_sample`.
        """
        super().__init__()
        block_indices = {b.block_index for b in blocks.values()}
        assert len(block_indices) == 1, (
            "expected exactly one `block_index` across member blocks,"
            + f" got {block_indices}"
        )
        self.block_index = block_indices.pop()
        blocks_in_samples = {b.blocks_in_sample for b in blocks.values()}
        assert len(blocks_in_samples) == 1, (
            "expected exactly one `blocks_in_sample` across member blocks,"
            + f" got {blocks_in_samples}"
        )
        self.blocks_in_sample = blocks_in_samples.pop()

313 

314 

def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
):
    """Lazily wrap per-member block meta data into `SampleBlockMeta` objects.

    Args:
        blocks: one mapping of member id to `BlockMeta` per sample block. The
            member blocks of each mapping must agree on block index and block
            count.
        sample_shape: shape of the whole sample, attached to every result.
        sample_id: id of the sample, attached to every result.

    Yields:
        One `SampleBlockMeta` per entry of `blocks`.
    """
    for per_member in blocks:
        shared = _ConsolidatedMemberBlocks(per_member)
        meta = SampleBlockMeta(
            blocks=dict(per_member),
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )
        yield meta

330 

331 

def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: PadMode,
) -> Iterable[SampleBlockWithOrigin]:
    """Lazily cut sample blocks, including data, out of `origin`.

    For each per-member meta data mapping in `blocks`, every member tensor of
    `origin` is passed to `Block.from_sample_member` together with that
    member's `BlockMeta` and `pad_mode`. Each block in `blocks` must provide an
    entry for every member of `origin`.

    Yields:
        One `SampleBlockWithOrigin` per entry of `blocks`, carrying the
        statistics and id of `origin`.
    """
    for per_member in blocks:
        shared = _ConsolidatedMemberBlocks(per_member)
        member_blocks = {
            member_id: Block.from_sample_member(
                tensor, block=per_member[member_id], pad_mode=pad_mode
            )
            for member_id, tensor in origin.members.items()
        }
        yield SampleBlockWithOrigin(
            blocks=member_blocks,
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )
353 )