Coverage for bioimageio/core/block_meta.py: 84%

126 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import itertools 

2from dataclasses import dataclass 

3from functools import cached_property 

4from math import floor, prod 

5from typing import ( 

6 Any, 

7 Callable, 

8 Collection, 

9 Dict, 

10 Generator, 

11 Iterable, 

12 List, 

13 Optional, 

14 Tuple, 

15 Union, 

16) 

17 

18from loguru import logger 

19from typing_extensions import Self 

20 

21from .axis import AxisId, PerAxis 

22from .common import ( 

23 BlockIndex, 

24 Frozen, 

25 Halo, 

26 HaloLike, 

27 MemberId, 

28 PadWidth, 

29 PerMember, 

30 SliceInfo, 

31 TotalNumberOfBlocks, 

32) 

33 

34 

@dataclass
class LinearAxisTransform:
    """Linear mapping of a coordinate along one axis: `round(s * scale) + offset`."""

    axis: AxisId  # the source axis the transform reads from
    scale: float
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Transform coordinate `s`; `round` controls the float→int conversion."""
        scaled = s * self.scale
        return self.offset + round(scaled)

43 

44 

@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
    ╷   halo(left)                       ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘

    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner region plus halo on both sides)"""
        # axes without a halo entry contribute their plain inner length
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        # For each side: the requested halo minus what `outer_slice` could
        # actually cover inside the sample; zero for axes without a halo.
        return Frozen(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        # start/stop are clamped to [0, sample_shape]; any halo that would
        # extend beyond the sample is made up for by `padding` instead.
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        # within the block, the inner region starts right after the left halo
        return Frozen(
            {
                a: SliceInfo(
                    self.halo[a].left,
                    self.halo[a].left + self.inner_shape[a],
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the axis ids of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self):
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stiched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs
        # (object.__setattr__ is needed because the dataclass is frozen)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"

        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        # a block larger than the sample is tolerated (padding covers the
        # difference), but worth warning about
        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Derive a new `BlockMeta` whose axes are given by `new_axes`:
        an `int` introduces a fixed-size axis (no halo, slice from 0);
        a `LinearAxisTransform` maps an existing axis' coordinates linearly."""
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    Halo(0, 0)
                    if isinstance(trf, int)
                    # NOTE(review): halo is copied, not scaled by trf.scale —
                    # presumably intentional; confirm for scale != 1.
                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

248 

249 

def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a sample `shape` into blocks of `block_shape`.

    Returns the total number of blocks and a generator of `BlockMeta`."""
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        # TODO: allow larger blockshape
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # normalize: every axis gets a halo (default 0) and a block axis length
    # (default: the full axis length from the sample shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    stride = {} if stride is None else stride

    per_axis_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, axis_len in shape.items():
        inner_len = block_shape[a] - sum(halo[a])
        step = stride.get(a, inner_len)
        # clamp the final block(s) so every inner slice stays inside the sample
        per_axis_slices[a] = [
            SliceInfo(min(start, axis_len - inner_len), min(start + inner_len, axis_len))
            for start in range(0, axis_len, step)
        ]

    total = prod(len(slices) for slices in per_axis_slices.values())

    return total, _block_meta_generator(
        shape,
        blocks_in_sample=total,
        inner_1d_slices=per_axis_slices,
        halo=halo,
    )

289 

290 

def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield a `BlockMeta` for every combination of the per-axis inner slices."""
    assert all(a in sample_shape for a in halo)

    halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices}
    axes = tuple(inner_1d_slices)
    # iterate the cartesian product of the 1d slices in a fixed axis order
    for index, combo in enumerate(itertools.product(*inner_1d_slices.values())):
        per_axis_slice: PerAxis[SliceInfo] = dict(zip(axes, combo))

        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice=per_axis_slice,
            halo=halo,
            block_index=index,
            blocks_in_sample=blocks_in_sample,
        )

311 

312 

def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split the shapes of multiple sample members into aligned blocks.

    Args:
        shapes: axis sizes per member.
        block_shapes: block axis sizes per member; empty means `shapes` (one
            block per member).
        halo: per-member halo; only allowed for members with a block shape.
        strides: per-member stride; only allowed for members with a block shape.
        broadcast: repeat single-block members for each block of the others.

    Returns:
        total number of blocks and an iterable of per-member `BlockMeta` dicts.

    Raises:
        ValueError: on inconsistent member keys or mismatching block counts.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast and (
        missing_blocks := [t for t in shapes if t not in block_shapes]
    ):
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    # raise (not assert) for consistency with the checks above and so that
    # this input validation is not stripped under `python -O`
    if unknown_strides := [t for t in strides if t not in block_shapes]:
        raise ValueError(
            f"`stride` specified for tensors without block shape: {unknown_strides}"
        )

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some members were split, others are a single block: broadcast the
        # single-block members (if allowed) so all generators align
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)

381 

382 

def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield, `n` times, a dict holding the next block of every member."""
    member_iters = {t: iter(source) for t, source in blocks.items()}
    emitted = 0
    while emitted < n:
        yield {t: next(member_iter) for t, member_iter in member_iters.items()}
        emitted += 1

389 

390 

def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the sole block of `block_generator` `n` times."""
    seen_first = False
    for block in block_generator:
        # sanity check: the generator must hold exactly one block
        assert not seen_first
        seen_first = True
        yield from itertools.repeat(block, n)