Coverage for src/bioimageio/core/block_meta.py: 84%
126 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
1import itertools
2from dataclasses import dataclass
3from functools import cached_property
4from math import floor, prod
5from typing import (
6 Any,
7 Callable,
8 Collection,
9 Dict,
10 Generator,
11 Iterable,
12 List,
13 Optional,
14 Tuple,
15 Union,
16)
18from loguru import logger
19from typing_extensions import Self
21from .axis import AxisId, PerAxis
22from .common import (
23 BlockIndex,
24 Frozen,
25 Halo,
26 HaloLike,
27 MemberId,
28 PadWidth,
29 PerMember,
30 SliceInfo,
31 TotalNumberOfBlocks,
32)
@dataclass
class LinearAxisTransform:
    """Linear map between axis coordinates: ``new = round(old * scale) + offset``."""

    axis: AxisId  # the (source) axis this transform is derived from
    scale: float
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Transform coordinate `s`; `round` converts the scaled float to an int."""
        scaled = s * self.scale
        return round(scaled) + self.offset
@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```
            first block (at the sample origin)
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐
    ╷  halo(left)                         ╷
    ╷                                     ╷  padding outside the sample
    ╷   (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷         ┃                 │         ╷    sample member
    ╷         ┃     inner       │  outer  ╷
    ╷         ┃     region      │  region ╷
    ╷         ┃     /slice      │  /slice ╷
    ╷         ┃                 │         ╷
    ╷         ┣─────────────────┘         ╷
    ╷         ┃   outer region/slice      ╷
    ╷         ┃            halo(right)    ╷
    └ ─ ─ ─ ─ ┃ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
              ⬇
    ```

    Note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner region plus halo on both sides)"""
        return Frozen(
            {
                # axes without a halo entry contribute only their inner length
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        # the halo is realized by extending the inner slice towards the outer
        # slice; whatever part does not fit into the sample (the outer slice is
        # clipped at the sample edge, see `outer_slice`) has to be padded
        return Frozen(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        return Frozen(
            {
                a: SliceInfo(
                    # start: inner start minus left halo, clipped to the sample
                    # (never negative)
                    # NOTE(review): the second `min` argument additionally caps
                    # the start by `sample_shape - inner_shape - halo.left` —
                    # presumably to keep edge blocks inside the sample; confirm
                    # intended behavior for strided/overlapping blocks
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    # stop: inner stop plus right halo, clipped at the sample edge
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        return Frozen(
            {
                # within the block, the inner region starts right after the
                # left halo; guard `a in self.halo` like `shape`/`outer_slice`
                # do (axes without a halo entry previously raised a KeyError
                # here, although `__post_init__` allows a partial halo)
                a: SliceInfo(
                    (self.halo[a].left if a in self.halo else 0),
                    (self.halo[a].left if a in self.halo else 0)
                    + self.inner_shape[a],
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the axes of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stiched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs; the dataclass is frozen, so attributes can
        # only be replaced via `object.__setattr__`
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(a in self.sample_shape for a in self.inner_slice), (
            "block has axes not present in sample"
        )

        assert all(a in self.inner_slice for a in self.halo), (
            "halo has axes not present in block"
        )

        # a block larger than the sample is possible (halo > sample), but
        # usually indicates a questionable blocking choice
        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Return a new block meta with axes mapped through `new_axes`.

        An `int` value creates an axis of that fixed length (fully covered by
        the inner slice, zero halo); a `LinearAxisTransform` derives the new
        axis from an existing axis of this block.
        """
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    Halo(0, 0)
                    if isinstance(trf, int)
                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a sample `shape` into (possibly overlapping) blocks.

    Args:
        shape: axis sizes of the sample to split.
        block_shape: axis sizes of a block (including halo); axes missing here
            are not split (the block spans the full axis).
        halo: per-axis halo enlarging the inner block region; axes missing
            here get a zero halo.
        stride: per-axis step between consecutive inner regions; defaults to
            the inner block size (block size minus halo on both sides).

    Returns:
        total number of blocks and a generator of their `BlockMeta`.

    Raises:
        ValueError: if `shape` is smaller than `block_shape`, or if a halo
            leaves no inner region.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        # TODO: allow larger blockshape
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            # fail early with a clear message instead of `range()` below
            # choking on a zero/negative step (or silently yielding no blocks)
            raise ValueError(
                f"halo {halo[a]} leaves no inner region for axis {a} with"
                + f" block size {block_shape[a]}"
            )

        stride_1d = stride.get(a, inner_size)
        # the last block is shifted left (`min(p, s - inner_size)`) so that it
        # ends exactly at the sample edge; its inner slice may therefore
        # overlap the previous block's
        inner_1d_slices[a] = [
            SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            for p in range(0, s, stride_1d)
        ]

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )
def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield a `BlockMeta` for every combination of per-axis inner slices.

    Blocks are enumerated in row-major order over `inner_1d_slices`.
    """
    assert all(a in sample_shape for a in halo)

    normalized_halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices}
    axes = tuple(inner_1d_slices)
    for index, slices_nd in enumerate(itertools.product(*inner_1d_slices.values())):
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice={a: s for a, s in zip(axes, slices_nd)},
            halo=normalized_halo,
            block_index=index,
            blocks_in_sample=blocks_in_sample,
        )
def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split the shapes of multiple sample members into aligned blocks.

    Args:
        shapes: axis sizes per member.
        block_shapes: block axis sizes per member; if empty, `shapes` is used
            (every member forms a single block).
        halo: halo per member; only allowed for members with a block shape.
        strides: stride per member; only allowed for members with a block shape.
        broadcast: repeat members that yield a single block for every block of
            the split members (instead of raising on the count mismatch).

    Returns:
        total number of blocks and an iterable over dicts mapping each member
        to its block meta, aligned block-by-block across members.

    Raises:
        ValueError: for inconsistent arguments or mismatching block counts.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast and (
        missing_blocks := [t for t in shapes if t not in block_shapes]
    ):
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    # raise (instead of `assert`, which `python -O` strips) for consistency
    # with the `halo` validation above
    if unknown_block := [t for t in strides if t not in block_shapes]:
        raise ValueError(
            f"`stride` specified for tensors without block shape: {unknown_block}"
        )

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks

    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some members are split, others form a single block each; repeat the
        # latter for every block if broadcasting was requested
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)
384def _aligned_blocks_generator(
385 n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
386):
387 iterators = {t: iter(gen) for t, gen in blocks.items()}
388 for _ in range(n):
389 yield {t: next(it) for t, it in iterators.items()}
392def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
393 round_two = False
394 for block in block_generator:
395 assert not round_two
396 for _ in range(n):
397 yield block
399 round_two = True