bioimageio.core.block_meta
"""Block metadata of sample members and helpers to split shapes into blocks."""

import itertools
from dataclasses import dataclass
from functools import cached_property
from math import floor, prod
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

from loguru import logger
from typing_extensions import Self

from .axis import AxisId, PerAxis
from .common import (
    BlockIndex,
    Frozen,
    Halo,
    HaloLike,
    MemberId,
    PadWidth,
    PerMember,
    SliceInfo,
    TotalNumberOfBlocks,
)


@dataclass
class LinearAxisTransform:
    """Linear map of coordinates along the source axis `axis`.

    `compute(s)` evaluates `round(s * scale) + offset`; `round` defaults to
    `math.floor`.
    """

    axis: AxisId
    scale: float
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        # the parameter name `round` shadows the builtin; kept for API compatibility
        return round(s * self.scale) + self.offset


@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐
    ╷ halo(left)                         ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
             ⬇
    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).
    - Axes of the block without an entry in `halo` are treated as having a halo of 0.

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner region plus left and right halo)"""
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        return Frozen(
            {
                a: PadWidth(
                    (
                        # halo not covered by the outer slice on the left
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        # halo not covered by the outer slice on the right
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample

        The inner slice is enlarged by the halo and clipped to `[0, sample_shape)`.
        """
        # NOTE(review): when inner_slice[a].stop <= sample_shape[a] (as produced by
        # split_shape_into_blocks), the second min() term is never the smaller
        # one — confirm whether a different clamp was intended.
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample

        The inner region starts after the left halo; axes without a halo entry
        start at 0 (consistent with `shape`, `padding` and `outer_slice`).
        """
        left = {
            a: (self.halo[a].left if a in self.halo else 0) for a in self.inner_slice
        }
        return Frozen(
            {
                a: SliceInfo(left[a], left[a] + self.inner_shape[a])
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the set of axis ids of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self):
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs (object.__setattr__ bypasses the frozen dataclass)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"

        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Create a block with axes `new_axes` derived from this block.

        - An `int` creates an unblocked axis of that size
          (inner slice `0..size`, no halo).
        - A `LinearAxisTransform` derives the axis from `trf.axis` of this block:
          `trf.compute` is applied to the sample size and to the inner slice bounds;
          the halo of the source axis is copied (0 if the source axis has none).

        Block index and number of blocks are kept.
        """
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            # NOTE(review): the halo is copied without applying `trf.scale` —
            # confirm this is intended for scaled axes.
            halo={
                a: (
                    Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                    if not isinstance(trf, int) and trf.axis in self.halo
                    else Halo(0, 0)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )


def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a sample member of size `shape` into blocks.

    Args:
        shape: axis sizes of the (unblocked) sample member.
        block_shape: block size (inner region plus halo) per axis;
            axes not listed use the full axis size.
        halo: halo per axis; axes not listed get a halo of 0.
        stride: distance between the starts of consecutive inner slices per axis;
            axes not listed default to the inner size (block size minus halo).

    Returns:
        The total number of blocks and a generator yielding their `BlockMeta`.

    Raises:
        ValueError: if `shape` is smaller than `block_shape` along any axis,
            if the halo leaves no inner region along an axis,
            or if a stride is not positive.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            # a non-positive step below would make `range` fail cryptically (0)
            # or silently produce no blocks (negative)
            raise ValueError(
                f"block size {block_shape[a]} along axis '{a}' leaves no inner"
                + f" region with halo {halo[a]}"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            raise ValueError(
                f"stride along axis '{a}' must be positive, got {stride_1d}"
            )

        # slices reaching past the end are shifted back to keep the full inner size
        # NOTE(review): with stride_1d < inner_size several trailing positions can
        # clamp to the same slice, yielding duplicate blocks — confirm intended.
        inner_1d_slices[a] = [
            SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            for p in range(0, s, stride_1d)
        ]

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )


def _block_meta_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    """Yield one `BlockMeta` per combination (cartesian product) of 1d inner slices.

    Block indices count from 0 in `itertools.product` order.
    """
    assert all(a in sample_shape for a in halo)

    halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices}
    for i, nd_tile in enumerate(itertools.product(*inner_1d_slices.values())):
        inner_slice: PerAxis[SliceInfo] = dict(zip(inner_1d_slices, nd_tile))

        yield BlockMeta(
            sample_shape=sample_shape,
            inner_slice=inner_slice,
            halo=halo,
            block_index=i,
            blocks_in_sample=blocks_in_sample,
        )


def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several sample members into blocks that are iterated in lockstep.

    Args:
        shapes: axis sizes per member.
        block_shapes: block shape per member; if empty, each member's full shape
            is used as its block shape.
        halo: halo per member (only for members with a block shape).
        strides: stride per member (only for members with a block shape).
        broadcast: allow members without a block shape and repeat members that
            are split into a single block to match the other members.

    Returns:
        The common number of blocks `n` and an iterable of `n` dicts mapping each
        member with a block shape to its `BlockMeta` (members without a block
        shape are not included).

    Raises:
        ValueError: for block shapes of unknown members, halos of members without
            block shape, members without block shape (unless `broadcast`), and
            mismatching numbers of blocks.
    """
    if unknown_blocks := [t for t in block_shapes if t not in shapes]:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast and (
        missing_blocks := [t for t in shapes if t not in block_shapes]
    ):
        raise ValueError(
            f"no block shape specified for {missing_blocks}."
            + " Set `broadcast` to True if these tensors should be repeated"
            + " as a whole for each block."
        )

    if extra_halo := [t for t in halo if t not in block_shapes]:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    if strides is None:
        strides = {}

    assert not (
        unknown_block := [t for t in strides if t not in block_shapes]
    ), f"`stride` specified for tensors without block shape: {unknown_block}"

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for t in block_shapes:
        n_blocks[t], blocks[t] = split_shape_into_blocks(
            shape=shapes[t],
            block_shape=block_shapes[t],
            halo=halo.get(t, {}),
            stride=strides.get(t),
        )
        assert n_blocks[t] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    unique_n_blocks = set(n_blocks.values())
    n = max(unique_n_blocks)
    if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
        # some members are split into n blocks, others are a single block
        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                + f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                + " repeat unsplit (single block) tensors."
            )

        blocks = {
            t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen
            for t, block_gen in blocks.items()
        }
    elif len(unique_n_blocks) != 1:
        raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

    return n, _aligned_blocks_generator(n, blocks)


def _aligned_blocks_generator(
    n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]]
):
    """Yield `n` dicts mapping each member to its next block (advanced in lockstep)."""
    iterators = {t: iter(gen) for t, gen in blocks.items()}
    for _ in range(n):
        yield {t: next(it) for t, it in iterators.items()}


def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks):
    """Yield the only block of `block_generator` `n` times."""
    round_two = False
    for block in block_generator:
        # a second block would mean the member was not a single block
        assert not round_two
        for _ in range(n):
            yield block

        round_two = True
@dataclass
class LinearAxisTransform:
    """Linear map of coordinates along the source axis `axis`.

    `compute(s)` evaluates `round(s * scale) + offset`; `round` defaults to
    `math.floor`.
    """

    axis: AxisId
    scale: float
    offset: int

    def compute(self, s: int, round: Callable[[float], int] = floor) -> int:
        scaled = s * self.scale
        rounded = round(scaled)
        return rounded + self.offset
@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐
    ╷ halo(left)                         ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
             ⬇
    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).
    - Axes of the block without an entry in `halo` are treated as having a halo of 0.

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block (inner region plus left and right halo)"""
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        return Frozen(
            {
                a: PadWidth(
                    (
                        # halo not covered by the outer slice on the left
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        # halo not covered by the outer slice on the right
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample

        The inner slice is enlarged by the halo and clipped to `[0, sample_shape)`.
        """
        # NOTE(review): when inner_slice[a].stop <= sample_shape[a] (as produced by
        # split_shape_into_blocks), the second min() term is never the smaller
        # one — confirm whether a different clamp was intended.
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample

        The inner region starts after the left halo; axes without a halo entry
        start at 0 (consistent with `shape`, `padding` and `outer_slice`).
        """
        left = {
            a: (self.halo[a].left if a in self.halo else 0) for a in self.inner_slice
        }
        return Frozen(
            {
                a: SliceInfo(left[a], left[a] + self.inner_shape[a])
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        """the set of axis ids of this block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self):
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs (object.__setattr__ bypasses the frozen dataclass)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"

        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Create a block with axes `new_axes` derived from this block.

        - An `int` creates an unblocked axis of that size
          (inner slice `0..size`, no halo).
        - A `LinearAxisTransform` derives the axis from `trf.axis` of this block:
          `trf.compute` is applied to the sample size and to the inner slice bounds;
          the halo of the source axis is copied (0 if the source axis has none).

        Block index and number of blocks are kept.
        """
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            # NOTE(review): the halo is copied without applying `trf.scale` —
            # confirm this is intended for scaled axes.
            halo={
                a: (
                    Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                    if not isinstance(trf, int) and trf.axis in self.halo
                    else Halo(0, 0)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
Block meta data of a sample member (a tensor in a sample)
Figure for illustration: The first 2d block (dashed) of a sample member (bold). The inner slice (thin) is expanded by a halo in both dimensions on both sides. The outer slice reaches from the sample member origin (0, 0) to the right halo point.
┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
╷ halo(left) ╷
╷ ╷
╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
╷ ┃ │ ╷ sample member
╷ ┃ inner │ ╷
╷ ┃ (and outer) │ outer ╷
╷ ┃ slice │ slice ╷
╷ ┃ │ ╷
╷ ┣─────────────────┘ ╷
╷ ┃ outer slice ╷
╷ ┃ halo(right) ╷
└ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
⬇
note:
- Inner and outer slices are specified in sample member coordinates.
- The outer_slice of a block at the sample edge may overlap by more than the halo with the neighboring block (the inner slices will not overlap though).
the axis sizes of the whole (unblocked) sample
inner region (without halo) wrt the sample
halo enlarging the inner region to the block's sizes
@cached_property
def shape(self) -> PerAxis[int]:
    """axis lengths of the block: inner length plus left and right halo
    (a missing halo entry counts as 0)"""
    lengths = {}
    for a, sl in self.inner_slice.items():
        halo_total = sum(self.halo[a]) if a in self.halo else 0
        lengths[a] = (sl.stop - sl.start) + halo_total
    return Frozen(lengths)
axis lengths of the block
@cached_property
def padding(self) -> PerAxis[PadWidth]:
    """padding to realize the halo at the sample edge
    where we cannot simply enlarge the inner slice

    Per side this is the part of the halo not covered by `outer_slice`;
    axes without a halo entry get no padding.
    """
    widths = {}
    for a in self.inner_slice:
        if a not in self.halo:
            widths[a] = PadWidth(0, 0)
            continue

        inner, outer, h = self.inner_slice[a], self.outer_slice[a], self.halo[a]
        widths[a] = PadWidth(
            h.left - (inner.start - outer.start),
            h.right - (outer.stop - inner.stop),
        )
    return Frozen(widths)
padding to realize the halo at the sample edge where we cannot simply enlarge the inner slice
@cached_property
def outer_slice(self) -> PerAxis[SliceInfo]:
    """slice of the outer block (without padding) wrt the sample

    The inner slice is enlarged by the halo (0 for axes without a halo entry)
    and clipped to `[0, sample_shape)`.
    """
    result = {}
    for a, inner in self.inner_slice.items():
        if a in self.halo:
            left, right = self.halo[a].left, self.halo[a].right
        else:
            left, right = 0, 0

        size = self.sample_shape[a]
        start = min(inner.start - left, size - self.inner_shape[a] - left)
        stop = min(size, inner.stop + right)
        result[a] = SliceInfo(max(0, start), stop)
    return Frozen(result)
slice of the outer block (without padding) wrt the sample
@cached_property
def inner_shape(self) -> PerAxis[int]:
    """axis lengths of the inner region (halo excluded)"""
    sizes = {}
    for a, sl in self.inner_slice.items():
        sizes[a] = sl.stop - sl.start
    return Frozen(sizes)
axis lengths of the inner region (without halo)
@cached_property
def local_slice(self) -> PerAxis[SliceInfo]:
    """inner slice wrt the block, **not** the sample

    The inner region starts after the left halo; axes without a halo entry
    start at 0 (consistent with `shape`, `padding` and `outer_slice`, which
    also treat a missing halo as 0).
    """
    left = {a: (self.halo[a].left if a in self.halo else 0) for a in self.inner_slice}
    return Frozen(
        {a: SliceInfo(left[a], left[a] + self.inner_shape[a]) for a in self.inner_slice}
    )
inner slice wrt the block, not the sample
@property
def tagged_shape(self) -> PerAxis[int]:
    """Same as `shape`: the block's axis lengths including halo."""
    block_shape = self.shape
    return block_shape
alias for shape
@property
def inner_slice_wo_overlap(self):
    """Subslice of the inner slice such that all `inner_slice_wo_overlap` slices
    can be stitched together trivially to form the original sample.

    This can also be used to calculate statistics without overrepresenting
    block edge regions."""
    # TODO: update inner_slice_wo_overlap when adding block overlap
    # NOTE(review): split_shape_into_blocks already accepts a stride smaller than
    # the inner size (overlapping inner slices); in that case returning the
    # inner slice does not yet satisfy the contract above.
    without_overlap = self.inner_slice
    return without_overlap
subslice of the inner slice, such that all inner_slice_wo_overlap slices can be stitched together trivially to form the original sample.
This can also be used to calculate statistics without overrepresenting block edge regions.
def get_transformed(
    self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
) -> Self:
    """Create a block with axes `new_axes` derived from this block.

    - An `int` creates an unblocked axis of that size
      (inner slice `0..size`, no halo).
    - A `LinearAxisTransform` derives the axis from `trf.axis` of this block:
      `trf.compute` is applied to the sample size and to the inner slice bounds;
      the halo of the source axis is copied (0 if the source axis has none).

    Block index and number of blocks are kept.
    """
    return self.__class__(
        sample_shape={
            a: (
                trf
                if isinstance(trf, int)
                else trf.compute(self.sample_shape[trf.axis])
            )
            for a, trf in new_axes.items()
        },
        inner_slice={
            a: (
                SliceInfo(0, trf)
                if isinstance(trf, int)
                else SliceInfo(
                    trf.compute(self.inner_slice[trf.axis].start),
                    trf.compute(self.inner_slice[trf.axis].stop),
                )
            )
            for a, trf in new_axes.items()
        },
        # NOTE(review): the halo is copied without applying `trf.scale` —
        # confirm this is intended for scaled axes.
        halo={
            a: (
                Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                if not isinstance(trf, int) and trf.axis in self.halo
                else Halo(0, 0)
            )
            for a, trf in new_axes.items()
        },
        block_index=self.block_index,
        blocks_in_sample=self.blocks_in_sample,
    )
def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
    """Split a sample member of size `shape` into blocks.

    Args:
        shape: axis sizes of the (unblocked) sample member.
        block_shape: block size (inner region plus halo) per axis;
            axes not listed use the full axis size.
        halo: halo per axis; axes not listed get a halo of 0.
        stride: distance between the starts of consecutive inner slices per axis;
            axes not listed default to the inner size (block size minus halo).

    Returns:
        The total number of blocks and a generator yielding their `BlockMeta`.

    Raises:
        ValueError: if `shape` is smaller than `block_shape` along any axis,
            if the halo leaves no inner region along an axis,
            or if a stride is not positive.
    """
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    if any(shape[a] < block_shape[a] for a in block_shape):
        raise ValueError(f"shape {shape} is smaller than block shape {block_shape}")

    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block axis length (from tensor shape)
    halo = {a: Halo.create(halo.get(a, 0)) for a in shape}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        if inner_size <= 0:
            # a non-positive step below would make `range` fail cryptically (0)
            # or silently produce no blocks (negative)
            raise ValueError(
                f"block size {block_shape[a]} along axis '{a}' leaves no inner"
                + f" region with halo {halo[a]}"
            )

        stride_1d = stride.get(a, inner_size)
        if stride_1d <= 0:
            raise ValueError(
                f"stride along axis '{a}' must be positive, got {stride_1d}"
            )

        # slices reaching past the end are shifted back to keep the full inner size
        # NOTE(review): with stride_1d < inner_size several trailing positions can
        # clamp to the same slice, yielding duplicate blocks — confirm intended.
        inner_1d_slices[a] = [
            SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            for p in range(0, s, stride_1d)
        ]

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_meta_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )
def split_multiple_shapes_into_blocks(
    shapes: PerMember[PerAxis[int]],
    block_shapes: PerMember[PerAxis[int]],
    *,
    halo: PerMember[PerAxis[HaloLike]],
    strides: Optional[PerMember[PerAxis[int]]] = None,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
    """Split several sample members into blocks that are iterated in lockstep.

    Args:
        shapes: axis sizes per member.
        block_shapes: block shape per member; if empty, each member's full shape
            is used as its block shape.
        halo: halo per member (only for members with a block shape).
        strides: stride per member (only for members with a block shape).
        broadcast: allow members without a block shape and repeat members that
            are split into a single block to match the other members.

    Returns:
        The common number of blocks `n` and an iterable of `n` dicts mapping each
        member with a block shape to its `BlockMeta` (members without a block
        shape are not included).

    Raises:
        ValueError: for block shapes of unknown members, halos of members without
            block shape, members without block shape (unless `broadcast`), and
            mismatching numbers of blocks.
    """
    unknown_blocks = [t for t in block_shapes if t not in shapes]
    if unknown_blocks:
        raise ValueError(
            f"block shape specified for unknown tensors: {unknown_blocks}."
        )

    if not block_shapes:
        block_shapes = shapes

    if not broadcast:
        missing_blocks = [t for t in shapes if t not in block_shapes]
        if missing_blocks:
            raise ValueError(
                f"no block shape specified for {missing_blocks}."
                " Set `broadcast` to True if these tensors should be repeated"
                " as a whole for each block."
            )

    extra_halo = [t for t in halo if t not in block_shapes]
    if extra_halo:
        raise ValueError(
            f"`halo` specified for tensors without block shape: {extra_halo}."
        )

    strides = {} if strides is None else strides

    unknown_block = [t for t in strides if t not in block_shapes]
    assert (
        not unknown_block
    ), f"`stride` specified for tensors without block shape: {unknown_block}"

    blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
    n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
    for member in block_shapes:
        count, member_blocks = split_shape_into_blocks(
            shape=shapes[member],
            block_shape=block_shapes[member],
            halo=halo.get(member, {}),
            stride=strides.get(member),
        )
        n_blocks[member] = count
        blocks[member] = member_blocks
        assert n_blocks[member] > 0, n_blocks

    assert len(blocks) > 0, blocks
    assert len(n_blocks) > 0, n_blocks
    distinct_counts = set(n_blocks.values())
    n = max(distinct_counts)
    if len(distinct_counts) > 1:
        # only "n blocks" mixed with "single block" members can be reconciled
        if len(distinct_counts) != 2 or 1 not in distinct_counts:
            raise ValueError(f"Mismatch for total number of blocks: {n_blocks}")

        if not broadcast:
            raise ValueError(
                "Mismatch for total number of blocks due to unsplit (single block)"
                f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
                " repeat unsplit (single block) tensors."
            )

        repeated: Dict[MemberId, Iterable[BlockMeta]] = {}
        for member, member_blocks in blocks.items():
            if n_blocks[member] == 1:
                repeated[member] = _repeat_single_block(member_blocks, n)
            else:
                repeated[member] = member_blocks
        blocks = repeated

    return n, _aligned_blocks_generator(n, blocks)