Coverage for src / bioimageio / core / block_meta.py: 84%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 09:46 +0000

1import itertools 

2from dataclasses import dataclass 

3from functools import cached_property 

4from math import floor, prod 

5from typing import ( 

6 Any, 

7 Callable, 

8 Collection, 

9 Dict, 

10 Generator, 

11 Iterable, 

12 List, 

13 Optional, 

14 Tuple, 

15 Union, 

16) 

17 

18from loguru import logger 

19from typing_extensions import Self 

20 

21from .axis import AxisId, PerAxis 

22from .common import ( 

23 BlockIndex, 

24 Frozen, 

25 Halo, 

26 HaloLike, 

27 MemberId, 

28 PadWidth, 

29 PerMember, 

30 SliceInfo, 

31 TotalNumberOfBlocks, 

32) 

33 

34 

@dataclass
class LinearAxisTransform:
    """A linear (scale-then-shift) mapping of a coordinate along one axis."""

    axis: AxisId
    scale: float
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Map coordinate *s*: scale, round (default: `floor`), then add the offset."""
        scaled = s * self.scale
        return round(scaled) + self.offset

43 

44 

@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```
    ┌ ─ ─ ─  first block (at the sample origin)  ─ ─ ┐
    ╷                  halo(left)                     ╷
    ╷           padding outside the sample            ╷
    ╷   (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷         ┃                 │         ╷   sample member
    ╷         ┃      inner      │  outer  ╷
    ╷         ┃      region     │  region ╷
    ╷         ┃      /slice     │  /slice ╷
    ╷         ┃                 │         ╷
    ╷         ┣─────────────────┘         ╷
    ╷         ┃    outer region/slice     ╷
    ╷         ┃             halo(right)   ╷
    └ ─ ─ ─ ─ ┃ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘

    ```

    Note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner region plus halo on both sides)"""
        return Frozen(
            {
                # axes without a halo entry contribute just their inner length
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        return Frozen(
            {
                a: PadWidth(
                    (
                        # the part of the left halo not covered by growing the
                        # slice within the sample must come from padding
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        # analogous for the right halo
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        return Frozen(
            {
                a: SliceInfo(
                    # start: inner start minus left halo, clamped to stay
                    # within [0, sample_shape - block shape]
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    # stop: inner stop plus right halo, clamped to the sample edge
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        return Frozen(
            {
                a: SliceInfo(
                    self.halo[a].left,
                    self.halo[a].left + self.inner_shape[a],
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the axes of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stiched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs
        # (object.__setattr__ is required because the dataclass is frozen)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(a in self.sample_shape for a in self.inner_slice), (
            "block has axes not present in sample"
        )

        assert all(a in self.inner_slice for a in self.halo), (
            "halo has axes not present in block"
        )

        # blocks larger than the sample are tolerated (padding makes up the
        # difference), but worth flagging
        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Return a new block meta with axes mapped according to `new_axes`.

        An `int` entry introduces an axis of that fixed length (full slice,
        zero halo); a `LinearAxisTransform` entry derives the new axis from
        the referenced existing axis by scaling and offsetting its
        sample length, inner slice, and (unscaled) halo.
        """
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    # NOTE(review): the halo is copied, not scaled by
                    # trf.scale — presumably intentional; confirm
                    Halo(0, 0)
                    if isinstance(trf, int)
                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

249 

250 

def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a tensor `shape` into blocks of `block_shape`.

    Axes missing from `block_shape` are kept whole; axes missing from `halo`
    get a zero halo. Returns the total number of blocks and a generator of
    their `BlockMeta`.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        # TODO: allow larger blockshape
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # normalize inputs: zero halo by default; unspecified block axes span the
    # full axis length
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, axis_len in shape.items():
        # the inner (halo-free) extent of a block along this axis; by default
        # blocks advance by exactly that amount (no overlap of inner slices)
        inner_len = block_shape[a] - sum(halo[a])
        step = stride.get(a, inner_len)
        slices_1d: List[SliceInfo] = []
        for pos in range(0, axis_len, step):
            # clamp trailing block(s) so they stay within the axis
            slices_1d.append(
                SliceInfo(
                    min(pos, axis_len - inner_len),
                    min(pos + inner_len, axis_len),
                )
            )
        inner_1d_slices[a] = slices_1d

    n_blocks = prod(len(s) for s in inner_1d_slices.values())

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )

290 

291 

def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield one `BlockMeta` per combination of the per-axis 1d inner slices."""
    assert all(a in sample_shape for a in halo)

    # normalize halo: every blocked axis gets an entry (default 0)
    normalized_halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices}
    axes = tuple(inner_1d_slices)
    # enumerate the cartesian product of all 1d slices -> n-d blocks
    for index, combo in enumerate(itertools.product(*inner_1d_slices.values())):
        per_axis_slice: PerAxis[SliceInfo] = dict(zip(axes, combo))
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice=per_axis_slice,
            halo=normalized_halo,
            block_index=index,
            blocks_in_sample=blocks_in_sample,
        )

312 

313 

def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several tensor shapes into aligned blocks.

    Args:
        shapes: axis sizes per tensor (sample member).
        block_shapes: desired block axis sizes per tensor; if empty, `shapes`
            is used, i.e. every tensor forms a single block.
        halo: per-tensor halo; only allowed for tensors with a block shape.
        strides: per-tensor block stride; only allowed for tensors with a
            block shape.
        broadcast: repeat tensors that form a single block for every block of
            the other tensors instead of raising on the mismatch.

    Returns:
        total number of blocks and an iterable of per-member block metas
        (one mapping per block, all members advanced in lockstep).

    Raises:
        ValueError: for block shapes, halos, or strides referring to unknown
            tensors, missing block shapes without `broadcast`, or mismatching
            total block counts.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast and (
        missing_blocks := [t for t in shapes if t not in block_shapes]
    ):
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    # raise instead of assert: input validation must survive `python -O`
    # and be consistent with the `halo` check above
    if unknown_block := [t for t in strides if t not in block_shapes]:
        raise ValueError(
            f"`stride` specified for tensors without block shape: {unknown_block}"
        )

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        # internal invariant: split_shape_into_blocks yields at least one block
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some tensors are unsplit (a single block); with `broadcast` they are
        # repeated for every block of the split tensors
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)

382 

383 

def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield `n` per-member block mappings, advancing all members in lockstep."""
    member_iters = {member: iter(source) for member, source in blocks.items()}
    count = 0
    while count < n:
        yield {member: next(it) for member, it in member_iters.items()}
        count += 1

390 

391 

def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the single block held by *block_generator* `n` times."""
    seen_one = False
    for blk in block_generator:
        # the input must hold exactly one block; a second item is a programming error
        assert not seen_one
        seen_one = True
        yield from itertools.repeat(blk, n)