Coverage for src / bioimageio / core / block.py: 83%

36 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from dataclasses import dataclass 

2from typing import ( 

3 Any, 

4 Generator, 

5 Iterable, 

6 Optional, 

7 Tuple, 

8 Union, 

9) 

10 

11from typing_extensions import Self 

12 

13from .axis import PerAxis 

14from .block_meta import BlockMeta, LinearAxisTransform, split_shape_into_blocks 

15from .common import ( 

16 Halo, 

17 HaloLike, 

18 PadMode, 

19 SliceInfo, 

20 TotalNumberOfBlocks, 

21) 

22from .tensor import Tensor 

23 

24 

@dataclass(frozen=True)
class Block(BlockMeta):
    """A block/tile of a (larger) tensor, bundled with its block meta information."""

    # The tensor holding this block's values, e.g. a (padded) slice of some
    # larger, original tensor.
    data: Tensor

    @property
    def inner_data(self):
        """Return the region of `data` selected by `local_slice`."""
        return self.data[self.local_slice]

    def __post_init__(self):
        """Validate the block after the dataclass fields have been set.

        Runs the parent's `__post_init__` first. Then checks that no sample axis
        size is the -1 placeholder (see `from_meta`, which resolves -1 entries).
        Finally checks that every axis of `data` has size
        `halo.left + local slice length + halo.right`.

        Raises:
            AssertionError: if a sample axis size is still -1.
            ValueError: if a data axis size disagrees with halo + local slice.
        """
        super().__post_init__()
        assert all(size != -1 for size in self.sample_shape.values()), self.sample_shape

        for axis_id, actual in self.data.sizes.items():
            inner = self.local_slice[axis_id]
            # axes without an explicit halo entry are treated as having no halo
            axis_halo = self.halo.get(axis_id, Halo(0, 0))
            expected = axis_halo.left + (inner.stop - inner.start) + axis_halo.right
            if actual != expected:
                raise ValueError(
                    f"Block data shape does not match block meta information for axis {axis_id}: data size is {actual} but expected {expected} (halo: {axis_halo}, slice: {inner})",
                )

    @classmethod
    def from_sample_member(
        cls,
        sample_member: Tensor,
        block: BlockMeta,
        *,
        pad_mode: PadMode,
    ) -> Self:
        """Cut out `block` from `sample_member` and pad it.

        Args:
            sample_member: the full tensor the block is taken from.
            block: meta information describing the block to extract.
            pad_mode: padding mode forwarded to `Tensor.pad`.

        Returns:
            A new block whose `data` is `sample_member[block.outer_slice]`
            padded by `block.padding`, and whose sample shape is
            `sample_member.tagged_shape`.
        """
        padded = sample_member[block.outer_slice].pad(block.padding, pad_mode)
        return cls(
            data=padded,
            sample_shape=sample_member.tagged_shape,
            inner_slice=block.inner_slice,
            halo=block.halo,
            block_index=block.block_index,
            blocks_in_sample=block.blocks_in_sample,
        )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Not implemented yet; always raises `NotImplementedError`."""
        raise NotImplementedError

    @classmethod
    def from_meta(cls, meta: BlockMeta, data: Tensor) -> Self:
        """Build a block from `meta` and `data`, filling in -1 placeholders.

        Any sample shape entry equal to -1 is replaced by the size of that
        axis in `data`. Any inner slice whose `stop` is -1 gets its `stop`
        replaced by the size of that axis in `data`. All other meta
        information is copied unchanged.
        """
        resolved_sample_shape = {}
        for axis_id, size in meta.sample_shape.items():
            if size == -1:
                resolved_sample_shape[axis_id] = data.tagged_shape[axis_id]
            else:
                resolved_sample_shape[axis_id] = size

        resolved_inner_slice = {}
        for axis_id, inner in meta.inner_slice.items():
            if inner.stop == -1:
                resolved_inner_slice[axis_id] = SliceInfo(
                    start=inner.start, stop=data.tagged_shape[axis_id]
                )
            else:
                resolved_inner_slice[axis_id] = inner

        return cls(
            sample_shape=resolved_sample_shape,
            inner_slice=resolved_inner_slice,
            halo=meta.halo,
            block_index=meta.block_index,
            blocks_in_sample=meta.blocks_in_sample,
            data=data,
        )

89 

90 

def split_tensor_into_blocks(
    tensor: Tensor,
    block_shape: PerAxis[int],
    *,
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
    pad_mode: PadMode,
) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]:
    """Divide a sample tensor into tensor blocks.

    The block layout is computed by `split_shape_into_blocks` from the
    tensor's tagged shape. The blocks themselves are cut out and padded
    lazily, one at a time, as the returned generator is consumed.

    Args:
        tensor: the sample tensor to split.
        block_shape: per-axis block size, forwarded to `split_shape_into_blocks`.
        halo: per-axis halo, forwarded to `split_shape_into_blocks`.
        stride: optional per-axis stride, forwarded to `split_shape_into_blocks`.
        pad_mode: padding mode used when extracting each block.

    Returns:
        A pair of the total number of blocks and a generator of `Block`s.
    """
    total, block_metas = split_shape_into_blocks(
        tensor.tagged_shape,
        block_shape=block_shape,
        halo=halo,
        stride=stride,
    )
    blocks = _block_generator(tensor, block_metas, pad_mode=pad_mode)
    return total, blocks

104 

105 

def _block_generator(
    sample: Tensor, blocks: Iterable[BlockMeta], *, pad_mode: PadMode
) -> Generator[Block, Any, None]:
    """Lazily yield one `Block` per entry of `blocks`.

    Args:
        sample: the tensor the blocks are cut from.
        blocks: meta information describing each block to extract.
        pad_mode: forwarded to `Block.from_sample_member` to pad each block.

    Yields:
        `Block.from_sample_member(sample, meta, pad_mode=pad_mode)` for each
        `meta` in `blocks`, in iteration order.
    """
    for block_meta in blocks:
        yield Block.from_sample_member(sample, block_meta, pad_mode=pad_mode)