Coverage for bioimageio/core/block_meta.py: 84%
126 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import itertools
2from dataclasses import dataclass
3from functools import cached_property
4from math import floor, prod
5from typing import (
6 Any,
7 Callable,
8 Collection,
9 Dict,
10 Generator,
11 Iterable,
12 List,
13 Optional,
14 Tuple,
15 Union,
16)
18from loguru import logger
19from typing_extensions import Self
21from .axis import AxisId, PerAxis
22from .common import (
23 BlockIndex,
24 Frozen,
25 Halo,
26 HaloLike,
27 MemberId,
28 PadWidth,
29 PerMember,
30 SliceInfo,
31 TotalNumberOfBlocks,
32)
@dataclass
class LinearAxisTransform:
    """Affine map ``s -> round(s * scale) + offset`` on sizes/coordinates of
    the referenced source ``axis``."""

    axis: AxisId  # source axis whose size or coordinate gets transformed
    scale: float  # factor applied before rounding
    offset: int  # shift applied after rounding

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Transform `s`: scale it, round with `round` (default `math.floor`),
        then add `offset`."""
        scaled = s * self.scale
        return self.offset + round(scaled)
@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─  ┐
    ╷ halo(left)                         ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
             ⬇
    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block.
    - NOTE(review): `split_shape_into_blocks` shifts the last inner slice of an
      axis back so it ends at the sample edge, so inner slices of neighboring
      blocks can overlap there as well.
    - Axes of `inner_slice` without an entry in `halo` are treated as having a
      zero halo by all derived properties.

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner length plus left and right halo)"""
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        # padding = the part of the halo that `outer_slice` could not cover
        # because it was clipped at the sample boundary
        return Frozen(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        # start: inner start moved left by the left halo, clipped at 0.
        # NOTE(review): the second `min` term is only smaller than the first
        # when the inner slice extends past the sample end.
        # stop: inner stop moved right by the right halo, clipped at the sample size.
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample

        The inner region starts after the left halo of the block; axes without
        a halo entry start at 0 (consistent with `shape` and `padding`)."""
        local: Dict[AxisId, SliceInfo] = {}
        for a in self.inner_slice:
            # guard missing halo entries like the other derived properties do
            left = self.halo[a].left if a in self.halo else 0
            local[a] = SliceInfo(left, left + self.inner_shape[a])

        return Frozen(local)

    @property
    def dims(self) -> Collection[AxisId]:
        """the axis ids of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self) -> PerAxis[SliceInfo]:
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions.

        NOTE(review): currently identical to `inner_slice`; inner slices shifted
        back at the sample edge by `split_shape_into_blocks` can still overlap.
        """
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs (object.__setattr__ is needed on a frozen dataclass)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"

        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Derive block meta data for a new set of axes.

        Args:
            new_axes: for every axis of the result either
                - an `int`: a new axis of that size, covered entirely by the
                  inner slice and without halo, or
                - a `LinearAxisTransform`: sample size and inner slice bounds
                  are mapped from the referenced axis of this block; the halo
                  of that axis is copied unscaled (zero if it has no halo entry).

        Returns:
            a new instance with the same `block_index` and `blocks_in_sample`.
        """
        # NOTE(review): the inner slice is scaled by the transform but the halo
        # is not; confirm this is intended for transforms with scale != 1.
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                    if not isinstance(trf, int) and trf.axis in self.halo
                    else Halo(0, 0)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a sample shape into blocks.

    Args:
        shape: axis sizes of the whole sample.
        block_shape: block sizes per axis, *including* the halo; axes not listed
            use the full sample size (the axis is not split).
        halo: halo per axis; axes not listed get `Halo.create(0)`.
        stride: distance between consecutive inner slice starts per axis;
            defaults to the inner block size (halo excluded).

    Returns:
        the total number of blocks and a lazy generator of their `BlockMeta`.

    Raises:
        ValueError: if the sample is smaller than the block along an axis, if
            the halo leaves no inner region, or if a stride is not positive.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            raise ValueError(
                f"halo {halo[a]} leaves no inner region for block size"
                + f" {block_shape[a]} along axis '{a}'"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            # `range` would either fail cryptically (0) or yield no blocks (< 0)
            raise ValueError(
                f"stride along axis '{a}' must be positive, got {stride_1d}"
            )

        # slices that would run past the sample end are shifted back to end at
        # `s`, so the last slice may overlap its predecessor
        inner_1d_slices[a] = [
            SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            for p in range(0, s, stride_1d)
        ]

    # one block per combination of 1d slices
    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )
def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Lazily yield one `BlockMeta` per combination of per-axis inner slices.

    Block indices follow `itertools.product` order over `inner_1d_slices`
    (the last axis varies fastest). Axes missing from `halo` receive
    `Halo.create(0)`.
    """
    # being a generator, this check only runs on the first `next()`
    assert all(a in sample_shape for a in halo)

    axis_ids = list(inner_1d_slices)
    block_halo = {a: Halo.create(halo.get(a, 0)) for a in axis_ids}
    slice_combinations = itertools.product(*inner_1d_slices.values())
    for block_index, slices in enumerate(slice_combinations):
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice=dict(zip(axis_ids, slices)),
            halo=block_halo,
            block_index=block_index,
            blocks_in_sample=blocks_in_sample,
        )
def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several sample members into blocks that are iterated in lockstep.

    Args:
        shapes: axis sizes per member.
        block_shapes: block shape per member; if empty, the members' full
            shapes are used as block shapes.
        halo: halo per member; only allowed for members with a block shape.
        strides: stride per member; only allowed for members with a block shape.
        broadcast: skip the "every member needs a block shape" check, and repeat
            members that split into a single block so that they match the
            block count of the other members.

    Returns:
        the common total number of blocks and an iterable that yields, per
        block index, a mapping from member id to that member's `BlockMeta`.

    Raises:
        ValueError: for block shapes of unknown members, halos or strides of
            members without block shape, missing block shapes (unless
            `broadcast`), or mismatching numbers of blocks.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast and (
        missing_blocks := [t for t in shapes if t not in block_shapes]
    ):
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    # raise (instead of assert) like the checks above, so the validation is
    # not stripped when running with `python -O`
    if unknown_block := [t for t in strides if t not in block_shapes]:
        raise ValueError(
            f"`stride` specified for tensors without block shape: {unknown_block}"
        )

    # NOTE(review): with `broadcast=True`, members of `shapes` that have no
    # block shape are not part of the returned blocks (only `block_shapes`
    # members are split below); confirm callers handle them separately.
    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some members were split, others are a single block
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)
382def _aligned_blocks_generator(
383 n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
384):
385 iterators = {t: iter(gen) for t, gen in blocks.items()}
386 for _ in range(n):
387 yield {t: next(it) for t, it in iterators.items()}
390def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
391 round_two = False
392 for block in block_generator:
393 assert not round_two
394 for _ in range(n):
395 yield block
397 round_two = True