Coverage for bioimageio/core/block.py: 86%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from dataclasses import dataclass 

2from typing import ( 

3 Any, 

4 Generator, 

5 Iterable, 

6 Optional, 

7 Tuple, 

8 Union, 

9) 

10 

11from typing_extensions import Self 

12 

13from .axis import PerAxis 

14from .block_meta import BlockMeta, LinearAxisTransform, split_shape_into_blocks 

15from .common import ( 

16 Halo, 

17 HaloLike, 

18 PadMode, 

19 SliceInfo, 

20 TotalNumberOfBlocks, 

21) 

22from .tensor import Tensor 

23 

24 

@dataclass(frozen=True)
class Block(BlockMeta):
    """A block/tile of a (larger) tensor.

    Adds the block's actual tensor content (`data`) to the block metadata
    inherited from `BlockMeta`.
    """

    data: Tensor
    """tensor content of this block, e.g. a (padded) slice of some larger, original tensor"""

    @property
    def inner_data(self):
        """`data` indexed by the inherited `local_slice`.

        NOTE(review): presumably this is the halo-free inner region of the
        block; `local_slice` is defined in `BlockMeta` -- confirm there.
        """
        cropping = self.local_slice
        return self.data[cropping]

    def __post_init__(self):
        super().__post_init__()
        # every sample axis size must be resolved (no -1 placeholders left)
        assert all(
            size != -1 for size in self.sample_shape.values()
        ), self.sample_shape
        # each axis of `data` must span exactly: left halo + inner region + right halo
        for axis_id, actual_size in self.data.sizes.items():
            inner = self.inner_slice[axis_id]
            axis_halo = self.halo.get(axis_id, Halo(0, 0))
            assert actual_size == (
                axis_halo.left + (inner.stop - inner.start) + axis_halo.right
            ), (actual_size, inner, axis_halo)

    @classmethod
    def from_sample_member(
        cls,
        sample_member: Tensor,
        block: BlockMeta,
        *,
        pad_mode: PadMode,
    ) -> Self:
        """Cut the region described by `block` out of `sample_member`.

        The region `block.outer_slice` is taken from `sample_member` and padded
        by `block.padding` using `pad_mode`. Block metadata (inner slice, halo,
        index, block count) is copied from `block`; the sample shape is taken
        from `sample_member.tagged_shape`.
        """
        region = sample_member[block.outer_slice]
        padded_region = region.pad(block.padding, pad_mode)
        return cls(
            data=padded_region,
            sample_shape=sample_member.tagged_shape,
            inner_slice=block.inner_slice,
            halo=block.halo,
            block_index=block.block_index,
            blocks_in_sample=block.blocks_in_sample,
        )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Not supported for data-carrying blocks; always raises NotImplementedError."""
        raise NotImplementedError

    @classmethod
    def from_meta(cls, meta: BlockMeta, data: Tensor) -> Self:
        """Combine block metadata `meta` with the tensor `data` into a `Block`.

        -1 placeholders in `meta` are resolved from the shape of `data`:
        a sample axis size of -1 is replaced by that axis' size in `data`, and
        an inner slice with `stop == -1` is made to stop at that axis' size in
        `data` (keeping its `start`). All other metadata is copied unchanged.
        """
        resolved_shape = {}
        for axis_id, size in meta.sample_shape.items():
            if size == -1:
                size = data.tagged_shape[axis_id]
            resolved_shape[axis_id] = size

        resolved_inner = {}
        for axis_id, inner in meta.inner_slice.items():
            if inner.stop == -1:
                inner = SliceInfo(start=inner.start, stop=data.tagged_shape[axis_id])
            resolved_inner[axis_id] = inner

        return cls(
            sample_shape=resolved_shape,
            inner_slice=resolved_inner,
            halo=meta.halo,
            block_index=meta.block_index,
            blocks_in_sample=meta.blocks_in_sample,
            data=data,
        )

90 

91 

def split_tensor_into_blocks(
    tensor: Tensor,
    block_shape: PerAxis[int],
    *,
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
    pad_mode: PadMode,
) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]:
    """Divide a sample tensor into tensor blocks.

    The block layout is computed by `split_shape_into_blocks` from the
    tensor's tagged shape; the `Block` objects themselves are created lazily.

    Args:
        tensor: the sample tensor to split.
        block_shape: block size per axis.
        halo: halo per axis (forwarded to `split_shape_into_blocks`).
        stride: optional stride per axis (forwarded to `split_shape_into_blocks`).
        pad_mode: padding mode used when cutting each block out of `tensor`.

    Returns:
        The total number of blocks, and a generator yielding each `Block`.
    """
    total, block_metas = split_shape_into_blocks(
        tensor.tagged_shape, block_shape=block_shape, halo=halo, stride=stride
    )
    lazy_blocks = _block_generator(tensor, block_metas, pad_mode=pad_mode)
    return total, lazy_blocks

105 

106 

def _block_generator(sample: Tensor, blocks: Iterable[BlockMeta], *, pad_mode: PadMode):
    """Lazily turn each block description in `blocks` into a `Block` cut from `sample`."""
    yield from (
        Block.from_sample_member(sample, meta, pad_mode=pad_mode) for meta in blocks
    )