bioimageio.core.block_meta

  1import itertools
  2from dataclasses import dataclass
  3from functools import cached_property
  4from math import floor, prod
  5from typing import (
  6    Any,
  7    Callable,
  8    Collection,
  9    Dict,
 10    Generator,
 11    Iterable,
 12    List,
 13    Optional,
 14    Tuple,
 15    Union,
 16)
 17
 18from loguru import logger
 19from typing_extensions import Self
 20
 21from .axis import AxisId, PerAxis
 22from .common import (
 23    BlockIndex,
 24    Frozen,
 25    Halo,
 26    HaloLike,
 27    MemberId,
 28    PadWidth,
 29    PerMember,
 30    SliceInfo,
 31    TotalNumberOfBlocks,
 32)
 33
 34
 35@dataclass
 36class LinearAxisTransform:
 37    axis: AxisId
 38    scale: float
 39    offset: int
 40
 41    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
 42        return round(s * self.scale) + self.offset
 43
 44
 45@dataclass(frozen=True)
 46class BlockMeta:
 47    """Block meta data of a sample member (a tensor in a sample)
 48
 49    Figure for illustration:
 50    The first 2d block (dashed) of a sample member (**bold**).
 51    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
 52    The outer slice reaches from the sample member origin (0, 0) to the right halo point.
 53
 54    ```terminal
 55    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─  ─ ─ ─ ─ ─ ─ ─ ┐
 56    ╷ halo(left)                         ╷
 57    ╷                                    ╷
 58    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
 59    ╷        ┃                 │         ╷  sample member
 60    ╷        ┃      inner      │         ╷
 61    ╷        ┃   (and outer)   │  outer  ╷
 62    ╷        ┃      slice      │  slice  ╷
 63    ╷        ┃                 │         ╷
 64    ╷        ┣─────────────────┘         ╷
 65    ╷        ┃   outer slice             ╷
 66    ╷        ┃               halo(right) ╷
 67    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
 68
 69    ```
 70
 71    note:
 72    - Inner and outer slices are specified in sample member coordinates.
 73    - The outer_slice of a block at the sample edge may overlap by more than the
 74        halo with the neighboring block (the inner slices will not overlap though).
 75
 76    """
 77
 78    sample_shape: PerAxis[int]
 79    """the axis sizes of the whole (unblocked) sample"""
 80
 81    inner_slice: PerAxis[SliceInfo]
 82    """inner region (without halo) wrt the sample"""
 83
 84    halo: PerAxis[Halo]
 85    """halo enlarging the inner region to the block's sizes"""
 86
 87    block_index: BlockIndex
 88    """the i-th block of the sample"""
 89
 90    blocks_in_sample: TotalNumberOfBlocks
 91    """total number of blocks in the sample"""
 92
 93    @cached_property
 94    def shape(self) -> PerAxis[int]:
 95        """axis lengths of the block"""
 96        return Frozen(
 97            {
 98                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
 99                for a, s in self.inner_slice.items()
100            }
101        )
102
103    @cached_property
104    def padding(self) -> PerAxis[PadWidth]:
105        """padding to realize the halo at the sample edge
106        where we cannot simply enlarge the inner slice"""
107        return Frozen(
108            {
109                a: PadWidth(
110                    (
111                        self.halo[a].left
112                        - (self.inner_slice[a].start - self.outer_slice[a].start)
113                        if a in self.halo
114                        else 0
115                    ),
116                    (
117                        self.halo[a].right
118                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
119                        if a in self.halo
120                        else 0
121                    ),
122                )
123                for a in self.inner_slice
124            }
125        )
126
127    @cached_property
128    def outer_slice(self) -> PerAxis[SliceInfo]:
129        """slice of the outer block (without padding) wrt the sample"""
130        return Frozen(
131            {
132                a: SliceInfo(
133                    max(
134                        0,
135                        min(
136                            self.inner_slice[a].start
137                            - (self.halo[a].left if a in self.halo else 0),
138                            self.sample_shape[a]
139                            - self.inner_shape[a]
140                            - (self.halo[a].left if a in self.halo else 0),
141                        ),
142                    ),
143                    min(
144                        self.sample_shape[a],
145                        self.inner_slice[a].stop
146                        + (self.halo[a].right if a in self.halo else 0),
147                    ),
148                )
149                for a in self.inner_slice
150            }
151        )
152
153    @cached_property
154    def inner_shape(self) -> PerAxis[int]:
155        """axis lengths of the inner region (without halo)"""
156        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})
157
158    @cached_property
159    def local_slice(self) -> PerAxis[SliceInfo]:
160        """inner slice wrt the block, **not** the sample"""
161        return Frozen(
162            {
163                a: SliceInfo(
164                    self.halo[a].left,
165                    self.halo[a].left + self.inner_shape[a],
166                )
167                for a in self.inner_slice
168            }
169        )
170
171    @property
172    def dims(self) -> Collection[AxisId]:
173        return set(self.inner_shape)
174
175    @property
176    def tagged_shape(self) -> PerAxis[int]:
177        """alias for shape"""
178        return self.shape
179
180    @property
181    def inner_slice_wo_overlap(self):
182        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
183        stiched together trivially to form the original sample.
184
185        This can also be used to calculate statistics
186        without overrepresenting block edge regions."""
187        # TODO: update inner_slice_wo_overlap when adding block overlap
188        return self.inner_slice
189
190    def __post_init__(self):
191        # freeze mutable inputs
192        if not isinstance(self.sample_shape, Frozen):
193            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))
194
195        if not isinstance(self.inner_slice, Frozen):
196            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))
197
198        if not isinstance(self.halo, Frozen):
199            object.__setattr__(self, "halo", Frozen(self.halo))
200
201        assert all(
202            a in self.sample_shape for a in self.inner_slice
203        ), "block has axes not present in sample"
204
205        assert all(
206            a in self.inner_slice for a in self.halo
207        ), "halo has axes not present in block"
208
209        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
210            logger.warning(
211                "block {} larger than sample {}", self.shape, self.sample_shape
212            )
213
214    def get_transformed(
215        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
216    ) -> Self:
217        return self.__class__(
218            sample_shape={
219                a: (
220                    trf
221                    if isinstance(trf, int)
222                    else trf.compute(self.sample_shape[trf.axis])
223                )
224                for a, trf in new_axes.items()
225            },
226            inner_slice={
227                a: (
228                    SliceInfo(0, trf)
229                    if isinstance(trf, int)
230                    else SliceInfo(
231                        trf.compute(self.inner_slice[trf.axis].start),
232                        trf.compute(self.inner_slice[trf.axis].stop),
233                    )
234                )
235                for a, trf in new_axes.items()
236            },
237            halo={
238                a: (
239                    Halo(0, 0)
240                    if isinstance(trf, int)
241                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
242                )
243                for a, trf in new_axes.items()
244            },
245            block_index=self.block_index,
246            blocks_in_sample=self.blocks_in_sample,
247        )
248
249
def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a tensor shape into blocks.

    Args:
        shape: axis sizes of the whole tensor.
        block_shape: block size (including halo) per axis; axes not listed are
            not split (their block size is the full axis size).
        halo: halo per axis; axes not listed get no halo.
        stride: distance between the starts of neighboring inner slices per
            axis; defaults to the inner block size (non-overlapping inner slices).

    Returns:
        The total number of blocks and a generator yielding their `BlockMeta`.

    Raises:
        ValueError: if `shape` is smaller than `block_shape`, if the halo
            leaves no inner region (block size <= sum of halo), or if a stride
            is not positive.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            raise ValueError(
                f"block size {block_shape[a]} of axis '{a}' leaves no inner region"
                + f" for halo {halo[a]}"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            raise ValueError(f"stride {stride_1d} of axis '{a}' is not positive")

        slices_1d: List[SliceInfo] = []
        for p in range(0, s, stride_1d):
            # slices reaching past the axis end are shifted back to end at `s`;
            # with a stride smaller than the inner size several positions map
            # to the same clamped slice -- keep it only once
            slice_1d = SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            if not slices_1d or slices_1d[-1] != slice_1d:
                slices_1d.append(slice_1d)

        inner_1d_slices[a] = slices_1d

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )
288
289
def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield one `BlockMeta` per combination of the per-axis inner slices.

    Combinations are enumerated in `itertools.product` order over the axes of
    `inner_1d_slices`; the running count becomes the `block_index`.
    """
    assert all(a in sample_shape for a in halo)

    axes = list(inner_1d_slices)
    # every block shares the same (completed) halo mapping
    full_halo = {a: Halo.create(halo.get(a, 0)) for a in axes}
    tiles = itertools.product(*(inner_1d_slices[a] for a in axes))
    for block_index, tile in enumerate(tiles):
        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice={a: slice_1d for a, slice_1d in zip(axes, tile)},
            halo=full_halo,
            block_index=block_index,
            blocks_in_sample=blocks_in_sample,
        )
310
311
def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several sample members into the same number of aligned blocks.

    Args:
        shapes: shape of each sample member.
        block_shapes: block shape per member; if empty, every member uses its
            full shape (a single block each).
        halo: halo per member (only allowed for members with a block shape).
        strides: stride per member (only allowed for members with a block shape).
        broadcast: if True, members without a block shape and members that
            result in a single block are repeated as a whole for every block.

    Returns:
        The total number of blocks and an iterable yielding, per block, a
        mapping from member id to that member's `BlockMeta`.

    Raises:
        ValueError: for block shapes of unknown members, for members without
            block shape when `broadcast` is False, for halos of members without
            block shape, and for mismatching numbers of blocks.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    missing_blocks = [t for t in shapes if t not in block_shapes]
    if missing_blocks and not broadcast:
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    assert not (
        unknown_block := [t for t in strides if t not in block_shapes]
    ), f"`stride` specified for tensors without block shape: {unknown_block}"

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    # Members with a block shape come first (in their given order), followed by
    # the broadcast members. Those are kept whole: an empty block shape makes
    # `split_shape_into_blocks` use the full axis sizes, i.e. a single block,
    # which is repeated below.
    for t in [*block_shapes, *missing_blocks]:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes.get(t, {}),
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)
380
381
def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield `n` mappings, each holding the next block of every member."""
    member_ids = list(blocks)
    streams = [iter(blocks[m]) for m in member_ids]
    for _ in range(n):
        aligned = {}
        for member_id, stream in zip(member_ids, streams):
            # an exhausted stream raises StopIteration here, which surfaces as
            # RuntimeError from this generator (PEP 479)
            aligned[member_id] = next(stream)

        yield aligned
388
389
def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the only block of `block_generator` `n` times.

    Fails an assertion if `block_generator` produces a second block.
    """
    already_repeated = False
    for block in block_generator:
        assert not already_repeated
        yield from itertools.repeat(block, n)
        already_repeated = True
@dataclass
class LinearAxisTransform:
@dataclass
class LinearAxisTransform:
    """Affine mapping of a coordinate/size of the source axis `axis`.

    A value `s` is mapped to `round(s * scale) + offset`.
    """

    axis: AxisId  # source axis whose values are transformed
    scale: float  # multiplicative factor applied before rounding
    offset: int  # integer added after rounding

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        """Map `s` to `round(s * scale) + offset`.

        Args:
            s: value on the source axis.
            round: rounding function applied to the scaled value (default: floor).
        """
        scaled = s * self.scale
        return round(scaled) + self.offset
LinearAxisTransform(axis: bioimageio.spec.model.v0_5.AxisId, scale: float, offset: int)
scale: float
offset: int
def compute( self, s: int, round: Callable[[float], int] = <built-in function floor>) -> int:
42    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
43        return round(s * self.scale) + self.offset
@dataclass(frozen=True)
class BlockMeta:
 46@dataclass(frozen=True)
 47class BlockMeta:
 48    """Block meta data of a sample member (a tensor in a sample)
 49
 50    Figure for illustration:
 51    The first 2d block (dashed) of a sample member (**bold**).
 52    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
 53    The outer slice reaches from the sample member origin (0, 0) to the right halo point.
 54
 55    ```terminal
 56    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─  ─ ─ ─ ─ ─ ─ ─ ┐
 57    ╷ halo(left)                         ╷
 58    ╷                                    ╷
 59    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
 60    ╷        ┃                 │         ╷  sample member
 61    ╷        ┃      inner      │         ╷
 62    ╷        ┃   (and outer)   │  outer  ╷
 63    ╷        ┃      slice      │  slice  ╷
 64    ╷        ┃                 │         ╷
 65    ╷        ┣─────────────────┘         ╷
 66    ╷        ┃   outer slice             ╷
 67    ╷        ┃               halo(right) ╷
 68    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
 69
 70    ```
 71
 72    note:
 73    - Inner and outer slices are specified in sample member coordinates.
 74    - The outer_slice of a block at the sample edge may overlap by more than the
 75        halo with the neighboring block (the inner slices will not overlap though).
 76
 77    """
 78
 79    sample_shape: PerAxis[int]
 80    """the axis sizes of the whole (unblocked) sample"""
 81
 82    inner_slice: PerAxis[SliceInfo]
 83    """inner region (without halo) wrt the sample"""
 84
 85    halo: PerAxis[Halo]
 86    """halo enlarging the inner region to the block's sizes"""
 87
 88    block_index: BlockIndex
 89    """the i-th block of the sample"""
 90
 91    blocks_in_sample: TotalNumberOfBlocks
 92    """total number of blocks in the sample"""
 93
 94    @cached_property
 95    def shape(self) -> PerAxis[int]:
 96        """axis lengths of the block"""
 97        return Frozen(
 98            {
 99                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
100                for a, s in self.inner_slice.items()
101            }
102        )
103
104    @cached_property
105    def padding(self) -> PerAxis[PadWidth]:
106        """padding to realize the halo at the sample edge
107        where we cannot simply enlarge the inner slice"""
108        return Frozen(
109            {
110                a: PadWidth(
111                    (
112                        self.halo[a].left
113                        - (self.inner_slice[a].start - self.outer_slice[a].start)
114                        if a in self.halo
115                        else 0
116                    ),
117                    (
118                        self.halo[a].right
119                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
120                        if a in self.halo
121                        else 0
122                    ),
123                )
124                for a in self.inner_slice
125            }
126        )
127
128    @cached_property
129    def outer_slice(self) -> PerAxis[SliceInfo]:
130        """slice of the outer block (without padding) wrt the sample"""
131        return Frozen(
132            {
133                a: SliceInfo(
134                    max(
135                        0,
136                        min(
137                            self.inner_slice[a].start
138                            - (self.halo[a].left if a in self.halo else 0),
139                            self.sample_shape[a]
140                            - self.inner_shape[a]
141                            - (self.halo[a].left if a in self.halo else 0),
142                        ),
143                    ),
144                    min(
145                        self.sample_shape[a],
146                        self.inner_slice[a].stop
147                        + (self.halo[a].right if a in self.halo else 0),
148                    ),
149                )
150                for a in self.inner_slice
151            }
152        )
153
154    @cached_property
155    def inner_shape(self) -> PerAxis[int]:
156        """axis lengths of the inner region (without halo)"""
157        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})
158
159    @cached_property
160    def local_slice(self) -> PerAxis[SliceInfo]:
161        """inner slice wrt the block, **not** the sample"""
162        return Frozen(
163            {
164                a: SliceInfo(
165                    self.halo[a].left,
166                    self.halo[a].left + self.inner_shape[a],
167                )
168                for a in self.inner_slice
169            }
170        )
171
172    @property
173    def dims(self) -> Collection[AxisId]:
174        return set(self.inner_shape)
175
176    @property
177    def tagged_shape(self) -> PerAxis[int]:
178        """alias for shape"""
179        return self.shape
180
181    @property
182    def inner_slice_wo_overlap(self):
183        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
184        stiched together trivially to form the original sample.
185
186        This can also be used to calculate statistics
187        without overrepresenting block edge regions."""
188        # TODO: update inner_slice_wo_overlap when adding block overlap
189        return self.inner_slice
190
191    def __post_init__(self):
192        # freeze mutable inputs
193        if not isinstance(self.sample_shape, Frozen):
194            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))
195
196        if not isinstance(self.inner_slice, Frozen):
197            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))
198
199        if not isinstance(self.halo, Frozen):
200            object.__setattr__(self, "halo", Frozen(self.halo))
201
202        assert all(
203            a in self.sample_shape for a in self.inner_slice
204        ), "block has axes not present in sample"
205
206        assert all(
207            a in self.inner_slice for a in self.halo
208        ), "halo has axes not present in block"
209
210        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
211            logger.warning(
212                "block {} larger than sample {}", self.shape, self.sample_shape
213            )
214
215    def get_transformed(
216        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
217    ) -> Self:
218        return self.__class__(
219            sample_shape={
220                a: (
221                    trf
222                    if isinstance(trf, int)
223                    else trf.compute(self.sample_shape[trf.axis])
224                )
225                for a, trf in new_axes.items()
226            },
227            inner_slice={
228                a: (
229                    SliceInfo(0, trf)
230                    if isinstance(trf, int)
231                    else SliceInfo(
232                        trf.compute(self.inner_slice[trf.axis].start),
233                        trf.compute(self.inner_slice[trf.axis].stop),
234                    )
235                )
236                for a, trf in new_axes.items()
237            },
238            halo={
239                a: (
240                    Halo(0, 0)
241                    if isinstance(trf, int)
242                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
243                )
244                for a, trf in new_axes.items()
245            },
246            block_index=self.block_index,
247            blocks_in_sample=self.blocks_in_sample,
248        )

Block meta data of a sample member (a tensor in a sample)

Figure for illustration: The first 2d block (dashed) of a sample member (bold). The inner slice (thin) is expanded by a halo in both dimensions on both sides. The outer slice reaches from the sample member origin (0, 0) to the right halo point.

┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─  ─ ─ ─ ─ ─ ─ ─ ┐
╷ halo(left)                         ╷
╷                                    ╷
╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
╷        ┃                 │         ╷  sample member
╷        ┃      inner      │         ╷
╷        ┃   (and outer)   │  outer  ╷
╷        ┃      slice      │  slice  ╷
╷        ┃                 │         ╷
╷        ┣─────────────────┘         ╷
╷        ┃   outer slice             ╷
╷        ┃               halo(right) ╷
└ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
         ⬇

note:

  • Inner and outer slices are specified in sample member coordinates.
  • The outer_slice of a block at the sample edge may overlap by more than the halo with the neighboring block (the inner slices will not overlap though).
BlockMeta( sample_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int], inner_slice: Mapping[bioimageio.spec.model.v0_5.AxisId, bioimageio.core.common.SliceInfo], halo: Mapping[bioimageio.spec.model.v0_5.AxisId, bioimageio.core.common.Halo], block_index: int, blocks_in_sample: int)
sample_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]

the axis sizes of the whole (unblocked) sample

inner region (without halo) wrt the sample

halo enlarging the inner region to the block's sizes

block_index: int

the i-th block of the sample

blocks_in_sample: int

total number of blocks in the sample

shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]
 94    @cached_property
 95    def shape(self) -> PerAxis[int]:
 96        """axis lengths of the block"""
 97        return Frozen(
 98            {
 99                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
100                for a, s in self.inner_slice.items()
101            }
102        )

axis lengths of the block

104    @cached_property
105    def padding(self) -> PerAxis[PadWidth]:
106        """padding to realize the halo at the sample edge
107        where we cannot simply enlarge the inner slice"""
108        return Frozen(
109            {
110                a: PadWidth(
111                    (
112                        self.halo[a].left
113                        - (self.inner_slice[a].start - self.outer_slice[a].start)
114                        if a in self.halo
115                        else 0
116                    ),
117                    (
118                        self.halo[a].right
119                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
120                        if a in self.halo
121                        else 0
122                    ),
123                )
124                for a in self.inner_slice
125            }
126        )

padding to realize the halo at the sample edge where we cannot simply enlarge the inner slice

128    @cached_property
129    def outer_slice(self) -> PerAxis[SliceInfo]:
130        """slice of the outer block (without padding) wrt the sample"""
131        return Frozen(
132            {
133                a: SliceInfo(
134                    max(
135                        0,
136                        min(
137                            self.inner_slice[a].start
138                            - (self.halo[a].left if a in self.halo else 0),
139                            self.sample_shape[a]
140                            - self.inner_shape[a]
141                            - (self.halo[a].left if a in self.halo else 0),
142                        ),
143                    ),
144                    min(
145                        self.sample_shape[a],
146                        self.inner_slice[a].stop
147                        + (self.halo[a].right if a in self.halo else 0),
148                    ),
149                )
150                for a in self.inner_slice
151            }
152        )

slice of the outer block (without padding) wrt the sample

inner_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]
154    @cached_property
155    def inner_shape(self) -> PerAxis[int]:
156        """axis lengths of the inner region (without halo)"""
157        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

axis lengths of the inner region (without halo)

159    @cached_property
160    def local_slice(self) -> PerAxis[SliceInfo]:
161        """inner slice wrt the block, **not** the sample"""
162        return Frozen(
163            {
164                a: SliceInfo(
165                    self.halo[a].left,
166                    self.halo[a].left + self.inner_shape[a],
167                )
168                for a in self.inner_slice
169            }
170        )

inner slice wrt the block, not the sample

dims: Collection[bioimageio.spec.model.v0_5.AxisId]
172    @property
173    def dims(self) -> Collection[AxisId]:
174        return set(self.inner_shape)
tagged_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int]
176    @property
177    def tagged_shape(self) -> PerAxis[int]:
178        """alias for shape"""
179        return self.shape

alias for shape

inner_slice_wo_overlap
181    @property
182    def inner_slice_wo_overlap(self):
183        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
184        stiched together trivially to form the original sample.
185
186        This can also be used to calculate statistics
187        without overrepresenting block edge regions."""
188        # TODO: update inner_slice_wo_overlap when adding block overlap
189        return self.inner_slice

subslice of the inner slice, such that all inner_slice_wo_overlap can be stitched together trivially to form the original sample.

This can also be used to calculate statistics without overrepresenting block edge regions.

def get_transformed( self, new_axes: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[LinearAxisTransform, int]]) -> Self:
215    def get_transformed(
216        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
217    ) -> Self:
218        return self.__class__(
219            sample_shape={
220                a: (
221                    trf
222                    if isinstance(trf, int)
223                    else trf.compute(self.sample_shape[trf.axis])
224                )
225                for a, trf in new_axes.items()
226            },
227            inner_slice={
228                a: (
229                    SliceInfo(0, trf)
230                    if isinstance(trf, int)
231                    else SliceInfo(
232                        trf.compute(self.inner_slice[trf.axis].start),
233                        trf.compute(self.inner_slice[trf.axis].stop),
234                    )
235                )
236                for a, trf in new_axes.items()
237            },
238            halo={
239                a: (
240                    Halo(0, 0)
241                    if isinstance(trf, int)
242                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
243                )
244                for a, trf in new_axes.items()
245            },
246            block_index=self.block_index,
247            blocks_in_sample=self.blocks_in_sample,
248        )
def split_shape_into_blocks( shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int], block_shape: Mapping[bioimageio.spec.model.v0_5.AxisId, int], halo: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.Halo]], stride: Optional[Mapping[bioimageio.spec.model.v0_5.AxisId, int]] = None) -> Tuple[int, Generator[BlockMeta, Any, NoneType]]:
def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a tensor shape into blocks.

    Args:
        shape: axis sizes of the whole tensor.
        block_shape: block size (including halo) per axis; axes not listed are
            not split (their block size is the full axis size).
        halo: halo per axis; axes not listed get no halo.
        stride: distance between the starts of neighboring inner slices per
            axis; defaults to the inner block size (non-overlapping inner slices).

    Returns:
        The total number of blocks and a generator yielding their `BlockMeta`.

    Raises:
        ValueError: if `shape` is smaller than `block_shape`, if the halo
            leaves no inner region (block size <= sum of halo), or if a stride
            is not positive.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            raise ValueError(
                f"block size {block_shape[a]} of axis '{a}' leaves no inner region"
                + f" for halo {halo[a]}"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            raise ValueError(f"stride {stride_1d} of axis '{a}' is not positive")

        slices_1d: List[SliceInfo] = []
        for p in range(0, s, stride_1d):
            # slices reaching past the axis end are shifted back to end at `s`;
            # with a stride smaller than the inner size several positions map
            # to the same clamped slice -- keep it only once
            slice_1d = SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            if not slices_1d or slices_1d[-1] != slice_1d:
                slices_1d.append(slice_1d)

        inner_1d_slices[a] = slices_1d

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )
def split_multiple_shapes_into_blocks( shapes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]], block_shapes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]], *, halo: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.Halo]]], strides: Optional[Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]] = None, broadcast: bool = False) -> Tuple[int, Iterable[Mapping[bioimageio.spec.model.v0_5.TensorId, BlockMeta]]]:
def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several sample members into the same number of aligned blocks.

    Args:
        shapes: shape of each sample member.
        block_shapes: block shape per member; if empty, every member uses its
            full shape (a single block each).
        halo: halo per member (only allowed for members with a block shape).
        strides: stride per member (only allowed for members with a block shape).
        broadcast: if True, members without a block shape and members that
            result in a single block are repeated as a whole for every block.

    Returns:
        The total number of blocks and an iterable yielding, per block, a
        mapping from member id to that member's `BlockMeta`.

    Raises:
        ValueError: for block shapes of unknown members, for members without
            block shape when `broadcast` is False, for halos of members without
            block shape, and for mismatching numbers of blocks.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    missing_blocks = [t for t in shapes if t not in block_shapes]
    if missing_blocks and not broadcast:
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    assert not (
        unknown_block := [t for t in strides if t not in block_shapes]
    ), f"`stride` specified for tensors without block shape: {unknown_block}"

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    # Members with a block shape come first (in their given order), followed by
    # the broadcast members. Those are kept whole: an empty block shape makes
    # `split_shape_into_blocks` use the full axis sizes, i.e. a single block,
    # which is repeated below.
    for t in [*block_shapes, *missing_blocks]:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes.get(t, {}),
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)