Coverage for bioimageio/core/sample.py: 94%
115 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from math import ceil, floor
5from typing import (
6 Callable,
7 Dict,
8 Generic,
9 Iterable,
10 Optional,
11 Tuple,
12 TypeVar,
13 Union,
14)
16import numpy as np
17from typing_extensions import Self
19from .axis import AxisId, PerAxis
20from .block import Block
21from .block_meta import (
22 BlockMeta,
23 LinearAxisTransform,
24 split_multiple_shapes_into_blocks,
25)
26from .common import (
27 BlockIndex,
28 Halo,
29 HaloLike,
30 MemberId,
31 PadMode,
32 PerMember,
33 SampleId,
34 SliceInfo,
35 TotalNumberOfBlocks,
36)
37from .stat_measures import Stat
38from .tensor import Tensor
# TODO: allow lazy samples to read from/write to disk
@dataclass
class Sample:
    """A dataset sample.

    Groups the tensors of a sample (one per member) with the sample's
    statistics and its identifier.
    """

    members: Dict[MemberId, Tensor]
    """the sample's tensors"""

    stat: Stat
    """sample and dataset statistics"""

    id: SampleId
    """identifier within the sample's dataset"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `sizes` of every member tensor, keyed by member id."""
        return {tid: t.sizes for tid, t in self.members.items()}

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: PadMode,
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into blocks.

        Args:
            block_shapes: Requested block shape per member; every key must be
                a member of this sample.
            halo: Halo per member; every key must also be in `block_shapes`.
            pad_mode: Forwarded to `sample_block_generator`.
            broadcast: Forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and an iterable of the sample blocks,
            each holding a reference to this sample as its origin.

        Raises:
            AssertionError: If `block_shapes` names an unknown member or
                `halo` names a member without a block shape.
        """
        assert not (
            missing := [m for m in block_shapes if m not in self.members]
        ), f"`block_shapes` specified for unknown members: {missing}"
        assert not (
            missing := [m for m in halo if m not in block_shapes]
        ), f"`halo` specified for members without `block_shape`: {missing}"

        n_blocks, blocks = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode)

    def as_single_block(
        self, halo: Optional[PerMember[PerAxis[Halo]]] = None
    ) -> SampleBlockWithOrigin:
        """Wrap the whole sample in one sample block (block 0 of 1).

        Args:
            halo: Optional halo per member; members without an entry get an
                empty halo mapping.

        Returns:
            A sample block whose member blocks each cover the full extent of
            the corresponding tensor.
        """
        if halo is None:
            halo = {}

        sample_shape = self.shape
        return SampleBlockWithOrigin(
            sample_shape=sample_shape,
            sample_id=self.id,
            blocks={
                m: Block(
                    sample_shape=sample_shape[m],
                    data=data,
                    # the inner slice spans every axis completely
                    inner_slice={
                        a: SliceInfo(0, s) for a, s in data.tagged_shape.items()
                    },
                    halo=halo.get(m, {}),
                    block_index=0,
                    blocks_in_sample=1,
                )
                for m, data in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Assemble a sample from its sample blocks.

        A member tensor is allocated the first time the member is seen: it
        has the block's `sample_shape` (ordered like the block data's dims),
        the block data's dtype, and is pre-filled with `fill_value`. The
        inner data of each block is then written at its inner slice.

        Args:
            sample_blocks: Blocks that all carry the same sample id.
            fill_value: Initial value of the allocated member tensors.

        Returns:
            The assembled sample. Its `stat` is taken from the last block.
            With no blocks the sample has no members, empty `stat` and
            `id=None`.

        Raises:
            NotImplementedError: If a block's sample shape contains -1.
            AssertionError: If the blocks disagree on the sample id.
        """
        members: PerMember[Tensor] = {}
        stat: Stat = {}
        sample_id = None
        for sample_block in sample_blocks:
            assert sample_id is None or sample_id == sample_block.sample_id
            sample_id = sample_block.sample_id
            stat = sample_block.stat
            for m, block in sample_block.blocks.items():
                if m not in members:
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    members[m] = Tensor(
                        np.full(
                            tuple(block.sample_shape[a] for a in block.data.dims),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=block.data.dims,
                    )

                members[m][block.inner_slice] = block.inner_data

        return cls(members=members, stat=stat, id=sample_id)
# Type variable constrained to the two block flavours: `Block` and `BlockMeta`.
BlockT = TypeVar("BlockT", Block, BlockMeta)
@dataclass
class SampleBlockBase(Generic[BlockT]):
    """base class for `SampleBlockMeta` and `SampleBlock`"""

    sample_shape: PerMember[PerAxis[int]]
    """the sample shape this block represents a part of"""

    sample_id: SampleId
    """identifier for the sample within its dataset"""

    blocks: Dict[MemberId, BlockT]
    """Individual tensor blocks comprising this sample block"""

    block_index: BlockIndex
    """the n-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `shape` of every member block, keyed by member id."""
        return {mid: b.shape for mid, b in self.blocks.items()}

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """The `inner_shape` of every member block, keyed by member id."""
        return {mid: b.inner_shape for mid, b in self.blocks.items()}
@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that also names the sample member it refers to."""

    member: MemberId
    """the member whose axis this transform is computed from"""
@dataclass
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Meta data of a dataset sample block"""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the meta data of a sample block with transformed axes.

        Args:
            new_axes: For every member and axis of the result either an `int`
                (a fixed axis size) or a transform that refers to an axis of a
                member of this sample block.

        Returns:
            A new instance of this class describing the members in `new_axes`.
            It keeps this block's sample id, block index and block count.

        Raises:
            ValueError: If scaling a halo with a transform does not give the
                same result when rounded down and when rounded up.
        """
        # int entries are taken as the axis size; transforms are applied to
        # the size of the axis they refer to
        sample_shape = {
            m: {
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.member][trf.axis])
                )
                for a, trf in new_axes[m].items()
            }
            for m in new_axes
        }

        def get_member_halo(m: MemberId, rounding: Callable[[float], int]):
            """Scale the referenced halos for member `m` using `rounding`."""
            member_halo: Dict[AxisId, Halo] = {}
            for a, trf in new_axes[m].items():
                if isinstance(trf, int):
                    member_halo[a] = Halo(0, 0)
                    continue

                ref_halo = self.blocks[trf.member].halo
                if trf.axis not in ref_halo:
                    # the referenced axis has no halo entry
                    member_halo[a] = Halo(0, 0)
                else:
                    member_halo[a] = Halo(
                        rounding(ref_halo[trf.axis].left * trf.scale),
                        rounding(ref_halo[trf.axis].right * trf.scale),
                    )

            return member_halo

        halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for m in new_axes:
            halo[m] = get_member_halo(m, floor)
            # floor and ceil only agree if the scaled halo is a whole number
            if halo[m] != get_member_halo(m, ceil):
                raise ValueError(
                    f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}"
                    + f" for {m}."
                )

        inner_slice = {
            m: {
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(
                            self.blocks[trf.member].inner_slice[trf.axis].start
                        ),
                        trf.compute(self.blocks[trf.member].inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes[m].items()
            }
            for m in new_axes
        }
        return self.__class__(
            blocks={
                m: BlockMeta(
                    sample_shape=sample_shape[m],
                    inner_slice=inner_slice[m],
                    halo=halo[m],
                    block_index=self.block_index,
                    blocks_in_sample=self.blocks_in_sample,
                )
                for m in new_axes
            },
            sample_shape=sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Combine this meta data with tensors into a `SampleBlock`.

        Args:
            data: The tensor of every member in `self.blocks` (and of every
                member whose sample shape contains -1).
            stat: Statistics to attach to the sample block.

        Returns:
            A sample block in which every sample shape entry of -1 has been
            replaced by the size of that axis in the given data.
        """
        return SampleBlock(
            sample_shape={
                m: {
                    a: data[m].tagged_shape[a] if s == -1 else s
                    for a, s in member_shape.items()
                }
                for m, member_shape in self.sample_shape.items()
            },
            sample_id=self.sample_id,
            blocks={
                m: Block.from_meta(b, data=data[m]) for m, b in self.blocks.items()
            },
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
@dataclass
class SampleBlock(SampleBlockBase[Block]):
    """A block of a dataset sample"""

    stat: Stat
    """computed statistics"""

    @property
    def members(self) -> PerMember[Tensor]:
        """the sample block's tensors"""
        return {m: b.data for m, b in self.blocks.items()}

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Return the transformed meta data of this sample block.

        Builds a `SampleBlockMeta` from this block's fields (without `stat`)
        and delegates to `SampleBlockMeta.get_transformed`.

        Args:
            new_axes: Passed on unchanged to `SampleBlockMeta.get_transformed`.
        """
        return SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=dict(self.blocks),
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        ).get_transformed(new_axes)
@dataclass
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` with a reference (`origin`) to the whole `Sample`"""

    origin: Sample
    """the sample this sample block was taken from"""
class _ConsolidatedMemberBlocks:
    """The block index and block count shared by the member blocks of one sample block.

    Raises:
        AssertionError: If the member blocks do not all agree on exactly one
            block index and exactly one number of blocks (this includes an
            empty mapping).
    """

    def __init__(self, blocks: PerMember[BlockMeta]):
        super().__init__()
        # collapse the per-member values into sets to check they are unanimous
        block_indices = {b.block_index for b in blocks.values()}
        assert (
            len(block_indices) == 1
        ), f"expected exactly one block index across members, got {block_indices}"
        self.block_index = block_indices.pop()

        blocks_in_samples = {b.blocks_in_sample for b in blocks.values()}
        assert (
            len(blocks_in_samples) == 1
        ), f"expected exactly one block count across members, got {blocks_in_samples}"
        self.blocks_in_sample = blocks_in_samples.pop()
def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
) -> Iterable[SampleBlockMeta]:
    """Lazily turn per-member block meta data into `SampleBlockMeta` objects.

    Args:
        blocks: For every sample block, the block meta data of each member.
        sample_shape: Sample shape to store on every yielded object.
        sample_id: Sample id to store on every yielded object.

    Yields:
        One `SampleBlockMeta` per entry of `blocks`; its block index and block
        count are the values shared by that entry's member blocks.
    """
    for member_blocks in blocks:
        cons = _ConsolidatedMemberBlocks(member_blocks)
        yield SampleBlockMeta(
            blocks=dict(member_blocks),
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=cons.block_index,
            blocks_in_sample=cons.blocks_in_sample,
        )
def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: PadMode,
) -> Iterable[SampleBlockWithOrigin]:
    """Lazily cut the sample blocks described by `blocks` out of `origin`.

    Args:
        blocks: For every sample block, the block meta data of each member.
            Every member of `origin` is looked up in each entry.
        origin: The sample to take the data, shape, statistics and id from.
        pad_mode: Forwarded to `Block.from_sample_member`.

    Yields:
        One `SampleBlockWithOrigin` per entry of `blocks`; its block index and
        block count are the values shared by that entry's member blocks.
    """
    for member_blocks in blocks:
        cons = _ConsolidatedMemberBlocks(member_blocks)
        yield SampleBlockWithOrigin(
            blocks={
                m: Block.from_sample_member(
                    origin.members[m], block=member_blocks[m], pad_mode=pad_mode
                )
                for m in origin.members
            },
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=cons.block_index,
            blocks_in_sample=cons.blocks_in_sample,
        )