Coverage for src/bioimageio/core/block_meta.py: 83%

125 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import itertools 

2from dataclasses import dataclass 

3from functools import cached_property 

4from math import floor, prod 

5from types import MappingProxyType 

6from typing import ( 

7 Any, 

8 Callable, 

9 Collection, 

10 Dict, 

11 Generator, 

12 Iterable, 

13 List, 

14 Optional, 

15 Tuple, 

16 Union, 

17) 

18 

19import pydantic 

20from loguru import logger 

21from typing_extensions import Self 

22 

23from ._axis_annotations import PerAxisAnno 

24from .axis import AxisId, PerAxis 

25from .common import ( 

26 BlockIndex, 

27 Halo, 

28 HaloLike, 

29 MemberId, 

30 PadWidth, 

31 PerMember, 

32 SliceInfo, 

33 TotalNumberOfBlocks, 

34) 

35 

36 

@dataclass
class LinearAxisTransform:
    """Linear mapping of an integer size/coordinate taken from a source axis.

    `compute` evaluates `round(s * scale) + offset`.
    """

    # id of the axis the transformed value is read from
    axis: AxisId
    # multiplicative factor applied before rounding
    scale: float
    # integer shift added after rounding
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Return the transformed value of `s`.

        Args:
            s: size or coordinate along `self.axis`.
            round: maps the scaled (float) value to an int; defaults to `floor`.
        """
        scaled = round(s * self.scale)
        return scaled + self.offset

45 

46 

@pydantic.dataclasses.dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```
    first block (at the sample origin)
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐
    ╷ halo(left)                         ╷
    ╷ padding outside the sample         ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │  outer  ╷
    ╷        ┃     region      │  region ╷
    ╷        ┃     /slice      │  /slice ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer region/slice      ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘

    ```

    Note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).
    - The halo may omit axes; a missing axis is treated as having no halo.

    """

    sample_shape: PerAxisAnno[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxisAnno[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxisAnno[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block"""
        return MappingProxyType(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        # per side: requested halo minus the part of it that `outer_slice`
        # could actually take from the sample
        return MappingProxyType(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        return MappingProxyType(
            {
                a: SliceInfo(
                    # start: inner start minus left halo, clipped to the sample
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    # stop: inner stop plus right halo, clipped to the sample
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return MappingProxyType(
            {a: s.stop - s.start for a, s in self.inner_slice.items()}
        )

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        return MappingProxyType(
            {
                a: (
                    SliceInfo(
                        h.left,
                        h.left + self.inner_shape[a],
                    )
                    if (h := self.halo.get(a)) is not None
                    else SliceInfo(0, self.inner_shape[a])
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the axis ids of the block (as a set)"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        """Check axis consistency and warn if the block exceeds the sample."""
        assert all(a in self.sample_shape for a in self.inner_slice), (
            "block has axes not present in sample"
        )

        assert all(a in self.inner_slice for a in self.halo), (
            "halo has axes not present in block"
        )

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Create the block meta data of a derived sample member.

        Args:
            new_axes: for each axis of the new block either
                - an `int`: the axis has this fixed size; the inner slice
                  spans it completely and it gets no halo, or
                - a `LinearAxisTransform`: sample size and inner slice
                  start/stop are computed from the source axis `trf.axis` of
                  this block; the halo of the source axis is copied unchanged
                  (it is not scaled).

        Returns:
            A new instance with the same `block_index` and `blocks_in_sample`.

        Raises:
            KeyError: if a transform refers to an axis that is not part of
                this block's `sample_shape`/`inner_slice`.
        """

        def transformed_halo(trf: Union[LinearAxisTransform, int]) -> Halo:
            if isinstance(trf, int):
                return Halo(0, 0)

            source_halo = self.halo.get(trf.axis)
            if source_halo is None:
                # The halo may omit axes (all other accessors treat a missing
                # entry as "no halo"), so do not fail with a KeyError here.
                return Halo(0, 0)

            return Halo(source_halo.left, source_halo.right)

        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={a: transformed_halo(trf) for a, trf in new_axes.items()},
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

247 

248 

def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a sample member of the given `shape` into blocks.

    Args:
        shape: axis sizes of the whole sample member.
        block_shape: axis sizes of a block **including** its halo.
            Axes missing here are not split (the full axis size is used).
        halo: halo per axis; axes missing here get a halo of 0.
        stride: distance between the starts of consecutive inner slices per
            axis. Defaults to the inner size (block size minus halo) for
            every axis not given.

    Returns:
        The total number of blocks and a generator yielding the `BlockMeta`
        of each block.

    Raises:
        ValueError: if `block_shape` or `halo` contain axes unknown to
            `shape`, if a block is larger than the sample, if the halo
            leaves no inner region, or if a stride is not positive.
    """
    unknown_axes = [a for a in block_shape if a not in shape]
    if unknown_axes:
        raise ValueError(
            f"unknown axes in block_shape: {unknown_axes} for shape {shape}"
        )

    if any(shape[a] < block_shape[a] for a in block_shape):
        # TODO: allow larger blockshape
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    # input validation must not rely on `assert` (stripped with `python -O`)
    unknown_halo_axes = [a for a in halo if a not in shape]
    if unknown_halo_axes:
        raise ValueError(
            f"unknown axes in halo: {unknown_halo_axes} for shape {shape}"
        )

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            # otherwise we would create empty/inverted slices or no blocks at all
            raise ValueError(
                f"block size {block_shape[a]} of axis '{a}' leaves no inner region"
                + f" after subtracting the halo {halo[a]}"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            raise ValueError(
                f"stride of axis '{a}' must be positive, but got {stride_1d}"
            )

        # The last block is shifted back such that it does not reach beyond
        # the sample while keeping the full inner size.
        inner_1d_slices[a] = [
            SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            for p in range(0, s, stride_1d)
        ]

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )

290 

291 

def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield a `BlockMeta` for every combination of the per-axis inner slices.

    Blocks are enumerated in the order of `itertools.product` over the values
    of `inner_1d_slices`; the running number becomes the `block_index`.
    Axes without an entry in `halo` get a halo created from 0.
    """
    assert not any(a not in sample_shape for a in halo)

    axes = tuple(inner_1d_slices)
    halo_per_axis = {a: Halo.create(halo.get(a, 0)) for a in axes}
    slices_per_axis = [inner_1d_slices[a] for a in axes]

    block_index = 0
    for nd_tile in itertools.product(*slices_per_axis):
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice=dict(zip(axes, nd_tile)),
            halo=halo_per_axis,
            block_index=block_index,
            blocks_in_sample=blocks_in_sample,
        )
        block_index += 1

312 

313 

def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several sample members into aligned blocks.

    Args:
        shapes: axis sizes per sample member.
        block_shapes: block axis sizes per sample member. If empty, `shapes`
            is used, i.e. every member forms a single block.
        halo: halo per member and axis; only allowed for members that have a
            block shape.
        strides: optional strides per member and axis; only allowed for
            members that have a block shape.
        broadcast: if True, members without a block shape are accepted and
            members that result in a single block are repeated for each block
            of the other members.

    Returns:
        The number of blocks `n` and an iterable yielding `n` mappings from
        member id to the `BlockMeta` of that member. Only members with a
        block shape appear in these mappings.

    Raises:
        ValueError: for block shapes, halos or strides of unknown members,
            for members without block shape (unless `broadcast` is set), or
            if the members do not split into a matching number of blocks.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast and (
        missing_blocks := [t for t in shapes if t not in block_shapes]
    ):
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    # Validate like `halo` above: raise instead of `assert`, which would be
    # skipped with `python -O` and silently ignore the unused strides.
    if unknown_block := [t for t in strides if t not in block_shapes]:
        raise ValueError(
            f"`stride` specified for tensors without block shape: {unknown_block}"
        )

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some members are unsplit (single block) while the others agree on n
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)

382 

383 

def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield `n` mappings, each holding the next block of every member.

    Every iterable in `blocks` is advanced in lockstep, so each one has to
    provide at least `n` blocks.
    """
    member_iterators = [
        (member, iter(member_blocks)) for member, member_blocks in blocks.items()
    ]
    for _ in range(n):
        aligned = {}
        for member, member_iterator in member_iterators:
            aligned[member] = next(member_iterator)

        yield aligned

390 

391 

def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the one block provided by `block_generator` `n` times.

    `block_generator` is expected to provide exactly one block; a second one
    trips the assertion.
    """
    already_repeated = False
    for single_block in block_generator:
        assert not already_repeated
        already_repeated = True
        for _ in itertools.repeat(None, n):
            yield single_block