Coverage for src/bioimageio/core/block_meta.py: 83%
125 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import itertools
2from dataclasses import dataclass
3from functools import cached_property
4from math import floor, prod
5from types import MappingProxyType
6from typing import (
7 Any,
8 Callable,
9 Collection,
10 Dict,
11 Generator,
12 Iterable,
13 List,
14 Optional,
15 Tuple,
16 Union,
17)
19import pydantic
20from loguru import logger
21from typing_extensions import Self
23from ._axis_annotations import PerAxisAnno
24from .axis import AxisId, PerAxis
25from .common import (
26 BlockIndex,
27 Halo,
28 HaloLike,
29 MemberId,
30 PadWidth,
31 PerMember,
32 SliceInfo,
33 TotalNumberOfBlocks,
34)
@dataclass
class LinearAxisTransform:
    """Affine transform deriving a value of a new axis from a reference axis.

    `BlockMeta.get_transformed` uses it to map the sample size and the inner
    slice bounds of the reference axis `axis` onto a new axis.
    """

    # reference axis whose values are transformed
    axis: AxisId
    # multiplicative factor applied before rounding
    scale: float
    # additive shift applied after rounding
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Map `s` to `round(s * scale) + offset`.

        Args:
            s: value (size or coordinate) along the reference axis.
            round: rounding applied to the scaled value; floors by default.

        Returns:
            The transformed value.
        """
        rounded = round(self.scale * s)
        return self.offset + rounded
@pydantic.dataclasses.dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```
     first block (at the sample origin)
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐
    ╷ halo(left)                         ╷
    ╷ padding outside the sample         ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │  outer  ╷
    ╷        ┃     region      │ region  ╷
    ╷        ┃     /slice      │ /slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer region/slice      ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
             ⬇
    ```

    Note:
    - Inner and outer slices are specified in sample member coordinates.
    - Axes missing from `halo` are treated as having no halo.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block. Blocks created by `split_shape_into_blocks`
      are shifted back inside the sample at the upper edge, so there their inner
      slices may overlap with the neighboring block as well.

    """

    sample_shape: PerAxisAnno[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxisAnno[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxisAnno[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner region plus halo on both sides)"""
        return MappingProxyType(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        return MappingProxyType(
            {
                a: PadWidth(
                    # the part of the left halo not covered by the outer slice
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    # the part of the right halo not covered by the outer slice
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample

        The inner slice enlarged by the halo, clipped to the sample bounds.
        """
        return MappingProxyType(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            # NOTE(review): this term is only smaller than the one
                            # above if the inner slice extends beyond the sample
                            # (it differs by sample_shape - inner_slice.stop);
                            # confirm that this clamp is intended.
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return MappingProxyType(
            {a: s.stop - s.start for a, s in self.inner_slice.items()}
        )

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample

        The inner region starts after the left halo within the (padded) block.
        """
        return MappingProxyType(
            {
                a: (
                    SliceInfo(
                        h.left,
                        h.left + self.inner_shape[a],
                    )
                    if (h := self.halo.get(a)) is not None
                    else SliceInfo(0, self.inner_shape[a])
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the axes of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        # NOTE(review): blocks shifted back inside the sample at the upper edge
        # (see `split_shape_into_blocks`) already overlap their neighbor.
        return self.inner_slice

    def __post_init__(self):
        # consistency checks run after construction
        assert all(a in self.sample_shape for a in self.inner_slice), (
            "block has axes not present in sample"
        )

        assert all(a in self.inner_slice for a in self.halo), (
            "halo has axes not present in block"
        )

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Derive the block meta for axes computed from this block's axes.

        Args:
            new_axes: for each axis of the new block either
                - an `int`: a fixed axis length; the axis is not split (its inner
                  slice spans the whole length) and it has no halo, or
                - a `LinearAxisTransform` referencing an axis of this block; the
                  sample size and inner slice bounds of that axis are mapped
                  through the transform and its halo is copied.

        Returns:
            A new instance with the same `block_index` and `blocks_in_sample`.
        """

        def new_halo(trf: Union[LinearAxisTransform, int]) -> Halo:
            if isinstance(trf, int):
                return Halo(0, 0)

            # Axes without a halo entry have no halo, consistent with
            # `shape`, `padding`, `outer_slice` and `local_slice`.
            src = self.halo.get(trf.axis)
            if src is None:
                return Halo(0, 0)

            # NOTE(review): the halo is copied unscaled even if trf.scale != 1;
            # confirm that this is intended.
            return Halo(src.left, src.right)

        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={a: new_halo(trf) for a, trf in new_axes.items()},
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a tensor of `shape` into blocks.

    Along each axis, inner regions of size `block_shape - halo` are placed
    every `stride` elements. The last inner region is shifted back so that it
    ends at the sample edge. Positions that would only repeat that last
    (shifted) region are skipped, so no duplicate blocks are created when
    `stride` is smaller than the inner region size.

    Args:
        shape: axis sizes of the whole tensor.
        block_shape: axis sizes of a block, including its halo. Axes that are
            not listed are not split; the block spans the whole axis.
        halo: halo per axis (default 0 for unlisted axes).
        stride: distance between the starts of consecutive inner regions per
            axis. Defaults to the inner region size (non-overlapping tiling).

    Returns:
        The total number of blocks and a generator lazily yielding the
        `BlockMeta` of each block.

    Raises:
        ValueError: if `block_shape` has axes not in `shape`, if `shape` is
            smaller than `block_shape`, if the halo leaves no inner region, or
            if a stride is not positive.
    """
    unknown_axes = [a for a in block_shape if a not in shape]
    if unknown_axes:
        raise ValueError(
            f"unknown axes in block_shape: {unknown_axes} for shape {shape}"
        )

    if any(shape[a] < block_shape[a] for a in block_shape):
        # TODO: allow larger blockshape
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            raise ValueError(
                f"block size {block_shape[a]} minus halo {halo[a]} along axis"
                + f" '{a}' leaves no inner region (inner size {inner_size})"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            raise ValueError(
                f"stride along axis '{a}' must be positive, got {stride_1d}"
            )

        slices_1d: List[SliceInfo] = []
        for p in range(0, s, stride_1d):
            # shift the region back inside the sample at the upper edge
            start = min(p, s - inner_size)
            stop = min(p + inner_size, s)
            slices_1d.append(SliceInfo(start, stop))
            if stop >= s:
                # every further position would clamp to this very same slice
                break

        inner_1d_slices[a] = slices_1d

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )
def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Lazily yield one `BlockMeta` per combination of the per-axis inner slices.

    Blocks are enumerated as the cartesian product of `inner_1d_slices`, in the
    axis order of that mapping. Axes without an entry in `halo` get a halo
    created from 0. All yielded blocks share the same halo mapping.
    """
    assert all(a in sample_shape for a in halo)

    axes = list(inner_1d_slices)
    full_halo = {a: Halo.create(halo.get(a, 0)) for a in axes}
    tiles = itertools.product(*(inner_1d_slices[a] for a in axes))
    for block_index, tile in enumerate(tiles):
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice={a: axis_slice for a, axis_slice in zip(axes, tile)},
            halo=full_halo,
            block_index=block_index,
            blocks_in_sample=blocks_in_sample,
        )
def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several tensors (members) into blocks that are processed together.

    Args:
        shapes: shape of each tensor.
        block_shapes: block shape per tensor (see `split_shape_into_blocks`).
            If empty, the full tensor shapes are used as block shapes.
        halo: halo per tensor; only allowed for tensors with a block shape.
        strides: stride per tensor; only allowed for tensors with a block shape.
        broadcast: if True, tensors without a block shape are treated as a
            single block spanning the whole tensor, and tensors split into a
            single block are repeated to match the other tensors' block count.

    Returns:
        The total number of blocks and an iterable yielding, for each block
        index, the `BlockMeta` of every tensor.

    Raises:
        ValueError: for block shapes of unknown tensors, for tensors without a
            block shape (unless `broadcast`), for halo or stride given for
            tensors without a block shape, or if the tensors' block counts
            cannot be aligned.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    missing_blocks = [t for t in shapes if t not in block_shapes]
    if missing_blocks and not broadcast:
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    # raise (instead of assert) so the check survives `python -O`,
    # consistent with the checks above
    if unknown_block := [t for t in strides if t not in block_shapes]:
        raise ValueError(
            f"`stride` specified for tensors without block shape: {unknown_block}"
        )

    if missing_blocks:
        # broadcast: a tensor without block shape becomes a single block
        # spanning the whole tensor; it is repeated for each block below.
        # (new dict: the caller's mapping is not mutated)
        block_shapes = {
            **block_shapes,
            **{t: shapes[t] for t in missing_blocks},
        }

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)
def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield `n` mappings from member id to block meta, one per block index.

    All member iterables are advanced in lockstep, so the i-th yielded mapping
    holds the i-th block of every member.
    """
    member_iters = {
        member_id: iter(member_blocks) for member_id, member_blocks in blocks.items()
    }
    for _ in itertools.repeat(None, n):
        yield {
            member_id: next(block_iter)
            for member_id, block_iter in member_iters.items()
        }
def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the single block of `block_generator` `n` times.

    Used to broadcast a tensor that was split into one block to the number of
    blocks of the other tensors. The assertion fails if `block_generator`
    yields a second block.
    """
    got_first = False
    for only_block in block_generator:
        assert not got_first
        yield from itertools.repeat(only_block, n)
        got_first = True