Coverage for bioimageio/core/block_meta.py: 84%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import itertools 

2from dataclasses import dataclass 

3from functools import cached_property 

4from math import floor, prod 

5from typing import ( 

6 Any, 

7 Callable, 

8 Collection, 

9 Dict, 

10 Generator, 

11 Iterable, 

12 List, 

13 Optional, 

14 Tuple, 

15 Union, 

16) 

17 

18from loguru import logger 

19from typing_extensions import Self 

20 

21from .axis import AxisId, PerAxis 

22from .common import ( 

23 BlockIndex, 

24 Frozen, 

25 Halo, 

26 HaloLike, 

27 MemberId, 

28 PadWidth, 

29 PerMember, 

30 SliceInfo, 

31 TotalNumberOfBlocks, 

32) 

33 

34 

@dataclass
class LinearAxisTransform:
    """Linear mapping of a coordinate along a single axis: ``s ↦ round(s * scale) + offset``."""

    axis: AxisId  # the (source) axis this transform refers to
    scale: float
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Transform coordinate `s`.

        Args:
            s: coordinate along `self.axis`.
            round: rounding function applied to ``s * scale`` before adding
                `offset` (defaults to `math.floor`; intentionally shadows the
                builtin `round` so callers can pass it explicitly).
        """
        return round(s * self.scale) + self.offset

43 

44 

@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
    ╷   halo(left)                       ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃           halo(right)     ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘

    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block"""
        # inner length plus left+right halo for axes that have a halo entry;
        # axes without a halo entry keep their inner length
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        # Near a sample edge `outer_slice` is clamped to the sample bounds, so
        # the realized halo (distance between inner and outer bounds) may fall
        # short of the requested halo; the remainder is realized as padding.
        return Frozen(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        # start: inner start minus left halo, additionally limited to
        # sample length - inner length - left halo, then clamped to >= 0;
        # stop: inner stop plus right halo, clamped to the sample length.
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        # NOTE(review): unlike `shape`/`padding`/`outer_slice` this indexes
        # `self.halo[a]` without an `a in self.halo` guard, so a partial halo
        # mapping would raise a KeyError here — confirm all constructors pass a
        # halo entry for every axis (as `_block_meta_generator` does).
        return Frozen(
            {
                a: SliceInfo(
                    self.halo[a].left,
                    self.halo[a].left + self.inner_shape[a],
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        # the set of axis ids present in this block
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs
        # (object.__setattr__ is required because the dataclass is frozen)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"

        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        # a block larger than the sample is tolerated (padding makes up the
        # difference), but worth flagging
        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Return a new block meta with axes mapped through `new_axes`.

        An `int` value fixes that axis to the given length (inner slice
        ``[0, int)``, zero halo); a `LinearAxisTransform` derives the new axis
        from this block's axis `trf.axis` by scaling and offsetting its
        sample length and inner slice bounds (halo is copied unscaled).
        """
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    Halo(0, 0)
                    if isinstance(trf, int)
                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

248 

249 

def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split `shape` into (halo-enlarged) blocks of `block_shape`.

    Args:
        shape: axis sizes of the sample to split.
        block_shape: outer (halo-including) block sizes; axes missing here
            default to the full axis length, i.e. are not split.
        halo: per-axis overlap with neighboring blocks; missing axes get 0.
        stride: distance between consecutive inner-slice starts; defaults to
            the inner block size (block minus halo) per axis.

    Returns:
        total number of blocks and a generator of their `BlockMeta`.

    Raises:
        ValueError: if `shape` is smaller than `block_shape` on any axis.
    """
    assert all(a in shape for a in block_shape), (tuple(shape), set(block_shape))
    if any(shape[a] < block_shape[a] for a in block_shape):
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # normalize: every sample axis gets a halo (default: 0) and an outer block
    # length (default: the full axis length, i.e. the axis is not split)
    full_halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    full_block_shape = {a: block_shape.get(a, size) for a, size in shape.items()}
    stride_map = {} if stride is None else stride

    per_axis_slices: Dict[AxisId, List[SliceInfo]] = {}
    for axis, axis_len in shape.items():
        # inner (halo-free) extent of a block along this axis
        inner_len = full_block_shape[axis] - sum(full_halo[axis])
        step = stride_map.get(axis, inner_len)
        # clamp the last block(s) into the sample; edge blocks may therefore
        # overlap their neighbor by more than the halo
        per_axis_slices[axis] = [
            SliceInfo(min(start, axis_len - inner_len), min(start + inner_len, axis_len))
            for start in range(0, axis_len, step)
        ]

    total = prod(len(slices) for slices in per_axis_slices.values())

    return total, _block_meta_generator(
        shape,
        blocks_in_sample=total,
        inner_1d_slices=per_axis_slices,
        halo=full_halo,
    )

288 

289 

def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield one `BlockMeta` per element of the cartesian product of the
    per-axis inner slices, indexed in product order."""
    assert all(a in sample_shape for a in halo)

    # normalize halo to a `Halo` per sliced axis (default: 0)
    normalized_halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices}
    axes = list(inner_1d_slices)
    index = 0
    for combo in itertools.product(*inner_1d_slices.values()):
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice=dict(zip(axes, combo)),
            halo=normalized_halo,
            block_index=index,
            blocks_in_sample=blocks_in_sample,
        )
        index += 1

310 

311 

def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split the shapes of multiple sample members into aligned blocks.

    Args:
        shapes: axis sizes per member.
        block_shapes: block sizes per member; if empty, `shapes` is used
            (each member becomes a single block).
        halo: per-member halo; only allowed for members with a block shape.
        strides: per-member strides; only allowed for members with a block shape.
        broadcast: repeat single-block members once per block of the split
            members instead of raising on the block-count mismatch.

    Returns:
        total number of blocks and an iterable yielding, per block index,
        a mapping of member id to that member's `BlockMeta`.
    """
    unknown_blocks = [t for t in block_shapes if t not in shapes]
    if unknown_blocks:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    missing_blocks = [t for t in shapes if t not in block_shapes]
    if missing_blocks and not broadcast:
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    extra_halo = [t for t in halo if t not in block_shapes]
    if extra_halo:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    stride_only = [t for t in strides if t not in block_shapes]
    assert (
        not stride_only
    ), f"`stride` specified for tensors without block shape: {stride_only}"

    # split each member independently, then check the counts line up
    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for member in block_shapes:
        n_blocks[member], blocks[member] = split_shape_into_blocks(
            shape=shapes[member],
            block_shape=block_shapes[member],
            halo=halo.get(member, {}),
            stride=strides.get(member),
        )
        assert n_blocks[member] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some members yielded a single block; with `broadcast` they are
        # repeated once per block of the split members
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            member: (
                _repeat_single_block(gen, n) if n_blocks[member] == 1 else gen
            )
            for member, gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)

380 

381 

def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield `n` mappings; the i-th pairs every member with the i-th item of
    its block iterable (all iterables advance in lockstep)."""
    member_iterators = [(member, iter(source)) for member, source in blocks.items()]
    for _ in range(n):
        aligned: Dict[MemberId, BlockMeta] = {}
        for member, it in member_iterators:
            aligned[member] = next(it)
        yield aligned

388 

389 

def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the single block of `block_generator` `n` times.

    Asserts that `block_generator` produces at most one block; an empty
    generator simply yields nothing.
    """
    exhausted_once = False
    for only_block in block_generator:
        # a second iteration means the "single block" precondition is violated
        assert not exhausted_once
        yield from itertools.repeat(only_block, n)

        exhausted_once = True