Coverage for bioimageio/core/sample.py: 93%
118 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from math import ceil, floor
5from typing import (
6 Any,
7 Callable,
8 Dict,
9 Generic,
10 Iterable,
11 Optional,
12 Tuple,
13 TypeVar,
14 Union,
15)
17import numpy as np
18from numpy.typing import NDArray
19from typing_extensions import Self
21from .axis import AxisId, PerAxis
22from .block import Block
23from .block_meta import (
24 BlockMeta,
25 LinearAxisTransform,
26 split_multiple_shapes_into_blocks,
27)
28from .common import (
29 BlockIndex,
30 Halo,
31 HaloLike,
32 MemberId,
33 PadMode,
34 PerMember,
35 SampleId,
36 SliceInfo,
37 TotalNumberOfBlocks,
38)
39from .stat_measures import Stat
40from .tensor import Tensor
42# TODO: allow for lazy samples to read/write to disk
@dataclass
class Sample:
    """A dataset sample.

    A `Sample` has `members`, which allows to combine multiple tensors into a single
    sample.
    For example a `Sample` from a dataset with masked images may contain a
    `MemberId("raw")` and `MemberId("mask")` image.
    """

    members: Dict[MemberId, Tensor]
    """The sample's tensors"""

    stat: Stat
    """Sample and dataset statistics"""

    id: SampleId
    """Identifies the `Sample` within the dataset -- typically a number or a string."""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """Axis sizes (`Tensor.sizes`) of every member, keyed by member id."""
        return {tid: t.sizes for tid, t in self.members.items()}

    def as_arrays(self) -> Dict[str, NDArray[Any]]:
        """Return sample as dictionary of arrays."""
        return {str(m): t.data.to_numpy() for m, t in self.members.items()}

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: PadMode,
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into sample blocks.

        Args:
            block_shapes: block shape per member and axis; every key must be a
                member of this sample.
            halo: halo per member and axis; every key must also be a key of
                `block_shapes`.
            pad_mode: padding mode forwarded to `sample_block_generator`.
            broadcast: forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and a lazy iterable yielding the blocks.

        Raises:
            AssertionError: if `block_shapes` or `halo` reference unknown members.
        """
        assert not (
            missing := [m for m in block_shapes if m not in self.members]
        ), f"`block_shapes` specified for unknown members: {missing}"
        assert not (
            missing := [m for m in halo if m not in block_shapes]
        ), f"`halo` specified for members without `block_shape`: {missing}"

        n_blocks, blocks = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode)

    def as_single_block(
        self, halo: Optional[PerMember[PerAxis[Halo]]] = None
    ) -> SampleBlockWithOrigin:
        """Wrap the whole sample into a single `SampleBlockWithOrigin`.

        Every member becomes one `Block` whose inner slice spans the member's full
        extent (`SliceInfo(0, size)` per axis).

        Args:
            halo: optional halo per member; members without an entry get `{}`.

        Returns:
            A sample block with `block_index=0` and `blocks_in_sample=1`.
        """
        if halo is None:
            halo = {}

        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks={
                m: Block(
                    sample_shape=self.shape[m],
                    data=data,
                    inner_slice={
                        a: SliceInfo(0, s) for a, s in data.tagged_shape.items()
                    },
                    halo=halo.get(m, {}),
                    block_index=0,
                    blocks_in_sample=1,
                )
                for m, data in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Reassemble a sample from its sample blocks.

        For each member, an array of the full sample shape is allocated with
        `fill_value` (using the block data's dtype). Every block's `inner_data` is
        then written into its `inner_slice`.

        Args:
            sample_blocks: blocks that must all share the same `sample_id`.
            fill_value: value for regions not covered by any block.
                NOTE(review): a NaN fill value cannot be represented in an integer
                dtype; confirm the intended result for integer data.

        Returns:
            The merged sample. Its `stat` is taken from the last block. If
            `sample_blocks` is empty, the result has no members, an empty `stat`
            and `id=None`.

        Raises:
            AssertionError: if the blocks belong to different samples.
            NotImplementedError: if a member's sample shape contains a
                data-dependent axis (size -1).
        """
        members: PerMember[Tensor] = {}
        stat: Stat = {}
        sample_id = None
        for sample_block in sample_blocks:
            assert sample_id is None or sample_id == sample_block.sample_id, (
                "cannot merge blocks of different samples: got sample ids "
                f"{sample_id!r} and {sample_block.sample_id!r}"
            )
            sample_id = sample_block.sample_id
            stat = sample_block.stat
            for m, block in sample_block.blocks.items():
                if m not in members:
                    # the full-size target array can only be allocated when
                    # every axis size of the sample is known
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    members[m] = Tensor(
                        np.full(
                            tuple(block.sample_shape[a] for a in block.data.dims),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=block.data.dims,
                    )

                members[m][block.inner_slice] = block.inner_data

        return cls(members=members, stat=stat, id=sample_id)
# Constrained type variable for the per-member blocks of a `SampleBlockBase`:
# either `Block` (which carries `data`) or `BlockMeta`.
BlockT = TypeVar("BlockT", Block, BlockMeta)
@dataclass
class SampleBlockBase(Generic[BlockT]):
    """Shared fields and shape helpers of `SampleBlockMeta` and `SampleBlock`.

    A sample block bundles one block per member (`blocks`) together with the
    shape and id of the sample it was cut from and its position among all blocks
    of that sample.
    """

    sample_shape: PerMember[PerAxis[int]]
    """the sample shape this block represents a part of"""

    sample_id: SampleId
    """identifier for the sample within its dataset"""

    blocks: Dict[MemberId, BlockT]
    """Individual tensor blocks comprising this sample block"""

    block_index: BlockIndex
    """the n-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """`shape` of every member block, keyed by member id."""
        return {
            member_id: member_block.shape
            for member_id, member_block in self.blocks.items()
        }

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """`inner_shape` of every member block, keyed by member id."""
        return {
            member_id: member_block.inner_shape
            for member_id, member_block in self.blocks.items()
        }
@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` whose input axis belongs to a specific sample member.

    `SampleBlockMeta.get_transformed` reads `member` (together with `axis`) to look
    up the source sample size, halo and inner slice that the transform is applied to.
    """

    member: MemberId  # member whose `axis` is the input of this transform
@dataclass
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Meta data of a dataset sample block"""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the meta data of a sample block whose axes are given by `new_axes`.

        Each entry of `new_axes` is either a fixed axis size (`int`) or a
        `LinearSampleAxisTransform`. A transform maps axis `trf.axis` of this
        block's member `trf.member` onto the new axis, as follows:

        - Sample size and inner-slice bounds are mapped through `trf.compute`.
        - Halos are multiplied by `trf.scale`.

        Fixed-size axes get the inner slice `[0, size)` and no halo.

        Args:
            new_axes: axis definitions per new member and axis.

        Returns:
            A new instance of this class with one `BlockMeta` per member of
            `new_axes`. `sample_id`, `block_index` and `blocks_in_sample` are
            carried over.

        Raises:
            ValueError: if a scaled halo is not integral, i.e. rounding it down
                and rounding it up give different results.
        """
        sample_shape = {
            m: {
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.member][trf.axis])
                )
                for a, trf in new_axes[m].items()
            }
            for m in new_axes
        }

        def get_member_halo(m: MemberId, rounding: Callable[[float], int]):
            # `rounding` (floor or ceil) turns the scaled, possibly fractional halo
            # back into integers; named so it does not shadow the builtin `round`.
            return {
                a: (
                    Halo(0, 0)
                    if isinstance(trf, int)
                    or trf.axis not in self.blocks[trf.member].halo
                    else Halo(
                        rounding(
                            self.blocks[trf.member].halo[trf.axis].left * trf.scale
                        ),
                        rounding(
                            self.blocks[trf.member].halo[trf.axis].right * trf.scale
                        ),
                    )
                )
                for a, trf in new_axes[m].items()
            }

        halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for m in new_axes:
            halo[m] = get_member_halo(m, floor)
            # an integral scaled halo rounds identically in both directions
            if halo[m] != get_member_halo(m, ceil):
                raise ValueError(
                    f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}"
                    + f" for {m}."
                )

        inner_slice = {
            m: {
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(
                            self.blocks[trf.member].inner_slice[trf.axis].start
                        ),
                        trf.compute(self.blocks[trf.member].inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes[m].items()
            }
            for m in new_axes
        }
        return self.__class__(
            blocks={
                m: BlockMeta(
                    sample_shape=sample_shape[m],
                    inner_slice=inner_slice[m],
                    halo=halo[m],
                    block_index=self.block_index,
                    blocks_in_sample=self.blocks_in_sample,
                )
                for m in new_axes
            },
            sample_shape=sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Attach tensor data to this meta data, producing a `SampleBlock`.

        Axis sizes of -1 in `sample_shape` (data dependent) are replaced by the
        corresponding axis size of the provided tensor.

        Args:
            data: a tensor for every member of `self.sample_shape` and `self.blocks`.
            stat: statistics attached to the resulting sample block.
        """
        return SampleBlock(
            sample_shape={
                m: {
                    a: data[m].tagged_shape[a] if s == -1 else s
                    for a, s in member_shape.items()
                }
                for m, member_shape in self.sample_shape.items()
            },
            sample_id=self.sample_id,
            blocks={
                m: Block.from_meta(b, data=data[m]) for m, b in self.blocks.items()
            },
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
@dataclass
class SampleBlock(SampleBlockBase[Block]):
    """A block of a dataset sample, i.e. per-member `Block`s plus statistics."""

    stat: Stat
    """computed statistics"""

    @property
    def members(self) -> PerMember[Tensor]:
        """The `data` tensor of every member block, keyed by member id."""
        return {
            member_id: member_block.data
            for member_id, member_block in self.blocks.items()
        }

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Meta data of this block after applying `new_axes`.

        Builds a `SampleBlockMeta` mirroring this block, then delegates to its
        `get_transformed`.
        """
        # NOTE(review): `Block` instances are passed where `BlockMeta` is expected;
        # presumably `Block` is a `BlockMeta` subclass — confirm.
        own_meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=dict(self.blocks),
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return own_meta.get_transformed(new_axes)
@dataclass
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` with a reference (`origin`) to the whole `Sample`"""

    # Full sample this block was cut from; e.g. `Sample.as_single_block` and
    # `sample_block_generator` pass the originating `Sample` here.
    origin: Sample
    """the sample this sample block was taken from"""
316class _ConsolidatedMemberBlocks:
317 def __init__(self, blocks: PerMember[BlockMeta]):
318 super().__init__()
319 block_indices = {b.block_index for b in blocks.values()}
320 assert len(block_indices) == 1
321 self.block_index = block_indices.pop()
322 blocks_in_samples = {b.blocks_in_sample for b in blocks.values()}
323 assert len(blocks_in_samples) == 1
324 self.blocks_in_sample = blocks_in_samples.pop()
def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
) -> Iterable[SampleBlockMeta]:
    """Lazily wrap per-member `BlockMeta` mappings into `SampleBlockMeta`s.

    Args:
        blocks: one mapping of member id to `BlockMeta` per sample block; all
            entries of a mapping must agree on `block_index` and `blocks_in_sample`.
        sample_shape: full sample shape attached to every yielded block.
        sample_id: sample id attached to every yielded block.

    Yields:
        One `SampleBlockMeta` per element of `blocks`.
    """
    for per_member in blocks:
        shared = _ConsolidatedMemberBlocks(per_member)
        yield SampleBlockMeta(
            blocks=dict(per_member),
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )
def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: PadMode,
) -> Iterable[SampleBlockWithOrigin]:
    """Lazily cut `origin` into `SampleBlockWithOrigin`s.

    For every mapping in `blocks`, each member tensor of `origin` is cut with
    `Block.from_sample_member`, using that member's `BlockMeta` and `pad_mode`.

    Args:
        blocks: one mapping of member id to `BlockMeta` per sample block; each
            mapping needs an entry for every member of `origin`.
        origin: the sample to cut.
        pad_mode: padding mode forwarded to `Block.from_sample_member`.

    Yields:
        One sample block per element of `blocks`, carrying `origin`'s shape, id
        and statistics.
    """
    for per_member in blocks:
        shared = _ConsolidatedMemberBlocks(per_member)
        member_blocks = {
            member_id: Block.from_sample_member(
                tensor, block=per_member[member_id], pad_mode=pad_mode
            )
            for member_id, tensor in origin.members.items()
        }
        yield SampleBlockWithOrigin(
            blocks=member_blocks,
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )