Coverage for bioimageio/core/block.py: 86%
35 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from dataclasses import dataclass
2from typing import (
3 Any,
4 Generator,
5 Iterable,
6 Optional,
7 Tuple,
8 Union,
9)
11from typing_extensions import Self
13from .axis import PerAxis
14from .block_meta import BlockMeta, LinearAxisTransform, split_shape_into_blocks
15from .common import (
16 Halo,
17 HaloLike,
18 PadMode,
19 SliceInfo,
20 TotalNumberOfBlocks,
21)
22from .tensor import Tensor
@dataclass(frozen=True)
class Block(BlockMeta):
    """A block/tile of a (larger) tensor, carrying the block's data together
    with the block metadata inherited from `BlockMeta`."""

    data: Tensor
    """the block's tensor, e.g. a (padded) slice of some larger, original tensor"""

    @property
    def inner_data(self):
        """`data` indexed with `self.local_slice`.

        NOTE(review): `local_slice` comes from `BlockMeta`; presumably it selects
        the inner (halo-free) region of `data` — confirm against `BlockMeta`.
        """
        return self.data[self.local_slice]

    def __post_init__(self):
        """Validate the block right after dataclass initialisation.

        Checks (via `assert`) that
        - no axis of `sample_shape` is still the unknown placeholder `-1`, and
        - for every axis of `data`, its size equals the left halo, plus the
          length of the inner slice, plus the right halo. Axes without an
          entry in `halo` count as having zero halo.
        """
        super().__post_init__()
        assert all(
            size != -1 for size in self.sample_shape.values()
        ), self.sample_shape

        for axis_id, size in self.data.sizes.items():
            inner = self.inner_slice[axis_id]
            axis_halo = self.halo.get(axis_id, Halo(0, 0))
            assert (
                size == axis_halo.left + (inner.stop - inner.start) + axis_halo.right
            ), (size, inner, axis_halo)

    @classmethod
    def from_sample_member(
        cls,
        sample_member: Tensor,
        block: BlockMeta,
        *,
        pad_mode: PadMode,
    ) -> Self:
        """Cut the region described by `block` out of `sample_member`.

        The data is `sample_member[block.outer_slice]`, padded with
        `block.padding` using `pad_mode`. All block metadata is copied from
        `block`, and the sample shape is taken from `sample_member`.
        """
        padded_data = sample_member[block.outer_slice].pad(block.padding, pad_mode)
        return cls(
            data=padded_data,
            sample_shape=sample_member.tagged_shape,
            inner_slice=block.inner_slice,
            halo=block.halo,
            block_index=block.block_index,
            blocks_in_sample=block.blocks_in_sample,
        )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Not implemented for `Block`; always raises `NotImplementedError`."""
        raise NotImplementedError

    @classmethod
    def from_meta(cls, meta: BlockMeta, data: Tensor) -> Self:
        """Combine block metadata `meta` with the tensor `data` into a `Block`.

        Placeholders of `-1` in `meta` are resolved from the shape of `data`:
        - a `sample_shape` entry of `-1` becomes the size of `data` along that axis;
        - an `inner_slice` entry whose `stop` is `-1` gets that size as its new `stop`.

        NOTE(review): the size of `data` includes any halo (see `__post_init__`).
        Confirm that using it as the inner `stop` / sample size is intended.
        """
        resolved_shape = {
            axis_id: (data.tagged_shape[axis_id] if size == -1 else size)
            for axis_id, size in meta.sample_shape.items()
        }
        resolved_inner = {
            axis_id: (
                SliceInfo(start=sl.start, stop=data.tagged_shape[axis_id])
                if sl.stop == -1
                else sl
            )
            for axis_id, sl in meta.inner_slice.items()
        }
        return cls(
            sample_shape=resolved_shape,
            inner_slice=resolved_inner,
            halo=meta.halo,
            block_index=meta.block_index,
            blocks_in_sample=meta.blocks_in_sample,
            data=data,
        )
def split_tensor_into_blocks(
    tensor: Tensor,
    block_shape: PerAxis[int],
    *,
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
    pad_mode: PadMode,
) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]:
    """Divide a sample tensor into tensor blocks.

    The block layout is computed from `tensor.tagged_shape` by
    `split_shape_into_blocks` (using `block_shape`, `halo` and `stride`).
    Each resulting block description is then turned into a `Block` whose data
    is padded with `pad_mode`.

    Returns:
        A pair made of the total number of blocks and a generator. The
        generator produces the `Block`s lazily, one per block description.
    """
    total_blocks, block_metas = split_shape_into_blocks(
        tensor.tagged_shape,
        block_shape=block_shape,
        halo=halo,
        stride=stride,
    )
    lazy_blocks = _block_generator(tensor, block_metas, pad_mode=pad_mode)
    return (total_blocks, lazy_blocks)
def _block_generator(
    sample: Tensor, blocks: Iterable[BlockMeta], *, pad_mode: PadMode
) -> Generator[Block, Any, None]:
    """Yield, one at a time, a `Block` cut out of `sample` for each entry of `blocks`.

    Blocks are built on demand via `Block.from_sample_member`, so nothing is
    computed until the generator is advanced.
    """
    yield from (
        Block.from_sample_member(sample, meta, pad_mode=pad_mode) for meta in blocks
    )