Coverage for src/bioimageio/core/block.py: 83%
36 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from dataclasses import dataclass
2from typing import (
3 Any,
4 Generator,
5 Iterable,
6 Optional,
7 Tuple,
8 Union,
9)
11from typing_extensions import Self
13from .axis import PerAxis
14from .block_meta import BlockMeta, LinearAxisTransform, split_shape_into_blocks
15from .common import (
16 Halo,
17 HaloLike,
18 PadMode,
19 SliceInfo,
20 TotalNumberOfBlocks,
21)
22from .tensor import Tensor
@dataclass(frozen=True)
class Block(BlockMeta):
    """A block/tile of a (larger) tensor.

    Combines the block meta information inherited from `BlockMeta`
    (sample shape, inner slice, halo, block index, ...) with the block's
    actual tensor data.
    """

    data: Tensor
    """the block's tensor, e.g. a (padded) slice of some larger, original tensor"""

    @property
    def inner_data(self):
        """The block data restricted to `local_slice` (i.e. without the halo)."""
        return self.data[self.local_slice]

    def __post_init__(self):
        """Validate that `data` is consistent with the block meta information.

        Raises:
            ValueError: if `sample_shape` still contains the placeholder size -1,
                or if the size of `data` along some axis differs from
                ``halo.left + (local slice length) + halo.right``.
        """
        super().__post_init__()
        # -1 appears to act as an "unknown size" placeholder (`from_meta`
        # replaces it with the data size). Raise explicitly rather than using
        # `assert`, so the check is not stripped when running with `python -O`.
        if any(v == -1 for v in self.sample_shape.values()):
            raise ValueError(
                f"Block sample shape must not contain unknown sizes (-1): {self.sample_shape}"
            )

        for a, s in self.data.sizes.items():
            slice_ = self.local_slice[a]
            # axes without an explicit halo are treated as having no halo
            halo = self.halo.get(a, Halo(0, 0))
            if s != (s_exp := halo.left + (slice_.stop - slice_.start) + halo.right):
                raise ValueError(
                    f"Block data shape does not match block meta information for axis {a}: data size is {s} but expected {s_exp} (halo: {halo}, slice: {slice_})",
                )

    @classmethod
    def from_sample_member(
        cls,
        sample_member: Tensor,
        block: BlockMeta,
        *,
        pad_mode: PadMode,
    ) -> Self:
        """Cut the block described by `block` out of `sample_member`.

        The tensor is indexed with `block.outer_slice`, and the result is padded
        with `block.padding` using `pad_mode`. The sample shape is taken from
        `sample_member.tagged_shape`; the remaining meta information is copied
        from `block`.
        """
        return cls(
            data=sample_member[block.outer_slice].pad(block.padding, pad_mode),
            sample_shape=sample_member.tagged_shape,
            inner_slice=block.inner_slice,
            halo=block.halo,
            block_index=block.block_index,
            blocks_in_sample=block.blocks_in_sample,
        )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Not implemented for `Block`.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @classmethod
    def from_meta(cls, meta: BlockMeta, data: Tensor) -> Self:
        """Create a block from meta information and its data.

        Placeholder values of -1 are resolved from `data.tagged_shape`:
        an entry of -1 in `meta.sample_shape` becomes the data size along that
        axis, and an inner slice with `stop == -1` gets the data size along
        that axis as its stop.
        """
        return cls(
            sample_shape={
                k: data.tagged_shape[k] if v == -1 else v
                for k, v in meta.sample_shape.items()
            },
            inner_slice={
                # NOTE(review): the resolved stop is the full data size, which
                # includes any halo; presumably only used for blocks spanning
                # the whole sample without halo -- confirm.
                k: (
                    SliceInfo(start=v.start, stop=data.tagged_shape[k])
                    if v.stop == -1
                    else v
                )
                for k, v in meta.inner_slice.items()
            },
            halo=meta.halo,
            block_index=meta.block_index,
            blocks_in_sample=meta.blocks_in_sample,
            data=data,
        )
def split_tensor_into_blocks(
    tensor: Tensor,
    block_shape: PerAxis[int],
    *,
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
    pad_mode: PadMode,
) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]:
    """Divide a sample tensor into tensor blocks.

    Args:
        tensor: the tensor to split; its `tagged_shape` determines the blocking.
        block_shape: block size per axis.
        halo: halo per axis, forwarded to `split_shape_into_blocks`.
        stride: optional stride per axis, forwarded to `split_shape_into_blocks`.
        pad_mode: padding mode passed on to `Block.from_sample_member`.

    Returns:
        The total number of blocks, and a generator that lazily yields one
        `Block` per block meta produced by `split_shape_into_blocks`.
    """
    total, block_metas = split_shape_into_blocks(
        tensor.tagged_shape,
        block_shape=block_shape,
        halo=halo,
        stride=stride,
    )
    blocks = _block_generator(tensor, block_metas, pad_mode=pad_mode)
    return total, blocks
def _block_generator(sample: Tensor, blocks: Iterable[BlockMeta], *, pad_mode: PadMode):
    """Lazily turn each block meta in `blocks` into a `Block` cut from `sample`."""
    yield from (
        Block.from_sample_member(sample, meta, pad_mode=pad_mode) for meta in blocks
    )