Coverage for src / bioimageio / core / sample.py: 92%
129 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from __future__ import annotations
3import collections.abc
4from dataclasses import dataclass
5from math import ceil, floor
6from typing import (
7 Any,
8 Callable,
9 Dict,
10 Generic,
11 Iterable,
12 Mapping,
13 Optional,
14 Tuple,
15 TypeVar,
16 Union,
17)
19import numpy as np
20import xarray as xr
21from numpy.typing import NDArray
22from typing_extensions import Self
24from .axis import AxisId, PerAxis
25from .block import Block
26from .block_meta import (
27 BlockMeta,
28 LinearAxisTransform,
29 split_multiple_shapes_into_blocks,
30)
31from .common import (
32 BlockIndex,
33 Halo,
34 HaloLike,
35 MemberId,
36 PadMode,
37 PadWidthLike,
38 PerMember,
39 SampleId,
40 SliceInfo,
41 TotalNumberOfBlocks,
42)
43from .stat_measures import Stat
44from .tensor import Tensor
# TODO: support lazy samples that read from and write to disk
@dataclass
class Sample:
    """One sample of a dataset, made up of named tensors (its *members*).

    Grouping tensors as members lets a single sample carry several related
    arrays; e.g. a sample of a masked-image dataset could hold one tensor under
    `MemberId("raw")` and another one under `MemberId("mask")`.
    """

    members: Dict[MemberId, Tensor]
    """Tensors of this sample, keyed by member id"""

    stat: Stat
    """Sample and dataset statistics"""

    id: SampleId
    """Identifier of this `Sample` within its dataset, usually a number or a string"""

    def __getitem__(
        self,
        key: PerMember[
            Union[
                SliceInfo,
                slice,
                int,
                PerAxis[Union[SliceInfo, slice, int]],
                Tensor,
                xr.DataArray,
            ]
        ],
    ) -> Self:
        """Index the members named in `key`.

        Each member listed in `key` is indexed with its own entry; members
        that are absent from `key` are not part of the returned sample.
        `stat` and `id` are carried over unchanged.
        """
        selected = {}
        for member_id, tensor in self.members.items():
            if member_id in key:
                selected[member_id] = tensor[key[member_id]]

        return self.__class__(members=selected, stat=self.stat, id=self.id)

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `sizes` of every member, keyed by member id."""
        shapes = {}
        for member_id, tensor in self.members.items():
            shapes[member_id] = tensor.sizes

        return shapes

    def as_arrays(self) -> Dict[MemberId, NDArray[Any]]:
        """Convert each member with its `to_numpy()` and return the results by member id."""
        arrays = {}
        for member_id, tensor in self.members.items():
            arrays[member_id] = tensor.to_numpy()

        return arrays

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: Union[PadMode, PerMember[PadMode]],
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into blocks.

        Args:
            block_shapes: Block shape per member; may only name members of
                this sample.
            halo: Halo per member; may only name members that also appear in
                `block_shapes`.
            pad_mode: One pad mode for all members or a pad mode per member;
                handed on to `sample_block_generator`.
            broadcast: Forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and an iterable yielding the blocks.
        """
        unknown_members = [m for m in block_shapes if m not in self.members]
        assert not unknown_members, (
            f"`block_shapes` specified for unknown members: {unknown_members}"
        )
        halo_without_shape = [m for m in halo if m not in block_shapes]
        assert not halo_without_shape, (
            f"`halo` specified for members without `block_shape`: {halo_without_shape}"
        )

        total, block_metas = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        sample_blocks = sample_block_generator(
            block_metas, origin=self, pad_mode=pad_mode
        )
        return total, sample_blocks

    def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None):
        """Wrap the whole sample in one `SampleBlockWithOrigin`.

        Every member becomes a `Block` whose inner slice spans the member's
        full tagged shape; block index is 0 and the block count is 1.
        Members without an entry in `halo` get an empty halo mapping.
        """
        member_halos = {} if halo is None else halo

        single_blocks = {}
        for member_id, tensor in self.members.items():
            full_extent = {
                axis: SliceInfo(0, size) for axis, size in tensor.tagged_shape.items()
            }
            single_blocks[member_id] = Block(
                sample_shape=self.shape[member_id],
                data=tensor,
                inner_slice=full_extent,
                halo=member_halos.get(member_id, {}),
                block_index=0,
                blocks_in_sample=1,
            )

        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks=single_blocks,
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Assemble a sample from sample blocks.

        For every member a tensor of the full sample shape is allocated on
        first encounter (pre-filled with `fill_value`), then each block writes
        its inner data into its inner slice. `stat` is taken from the last
        block seen; all blocks have to agree on the sample id.

        NOTE(review): the default NaN `fill_value` is not representable in
        integer dtypes -- confirm how `np.full` behaves for such members.

        Raises:
            NotImplementedError: if a block's sample shape contains -1
                (data dependent axis).
        """
        merged: PerMember[Tensor] = {}
        latest_stat: Stat = {}
        seen_id = None
        for sample_block in sample_blocks:
            assert seen_id is None or seen_id == sample_block.sample_id
            seen_id = sample_block.sample_id
            latest_stat = sample_block.stat
            for member_id, block in sample_block.blocks.items():
                if member_id not in merged:
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    dims = block.data.dims
                    full_shape = tuple(block.sample_shape[a] for a in dims)
                    merged[member_id] = Tensor(
                        np.full(full_shape, fill_value, dtype=block.data.dtype),
                        dims=dims,
                    )

                merged[member_id][block.inner_slice] = block.inner_data

        return cls(members=merged, stat=latest_stat, id=seen_id)

    def pad(
        self,
        pad_width: PerMember[PerAxis[Union[int, PadWidthLike]]],
        mode: Union[PerMember[PadMode], PadMode],
    ) -> Self:
        """Pad the members of this sample and return the result as a new sample.

        `mode` is either a single pad mode applied to all members or a mapping
        from member id to pad mode; with a mapping, members not listed in it
        are padded with "symmetric". Members missing from `pad_width` are
        padded with an empty pad width mapping.
        """
        if isinstance(mode, collections.abc.Mapping):
            per_member_mode, fallback_mode = mode, "symmetric"
        else:
            per_member_mode, fallback_mode = {}, mode

        padded = {}
        for member_id, tensor in self.members.items():
            padded[member_id] = tensor.pad(
                pad_width=pad_width.get(member_id, {}),
                mode=per_member_mode.get(member_id, fallback_mode),
            )

        return self.__class__(members=padded, stat=self.stat, id=self.id)
# Type of the per-member blocks held by a sample block: `Block` or `BlockMeta`.
BlockT = TypeVar("BlockT", Block, BlockMeta)


@dataclass
class SampleBlockBase(Generic[BlockT]):
    """Shared base of `SampleBlockMeta` and `SampleBlock`."""

    sample_shape: PerMember[PerAxis[int]]
    """Shape of the sample that this block is a part of"""

    sample_id: SampleId
    """Identifies the sample within its dataset"""

    blocks: Dict[MemberId, BlockT]
    """The per-member tensor blocks that make up this sample block"""

    block_index: BlockIndex
    """Position of this block among the blocks of the sample (the n-th block)"""

    blocks_in_sample: TotalNumberOfBlocks
    """How many blocks the sample consists of in total"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `shape` of every member block, keyed by member id."""
        shapes = {}
        for member_id, member_block in self.blocks.items():
            shapes[member_id] = member_block.shape

        return shapes

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """The `inner_shape` of every member block, keyed by member id."""
        inner_shapes = {}
        for member_id, member_block in self.blocks.items():
            inner_shapes[member_id] = member_block.inner_shape

        return inner_shapes
@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that additionally names a sample member."""

    member: MemberId
    """The sample member whose axis this transform refers to"""
@dataclass
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Describes a block of a dataset sample without holding tensor data"""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the meta data of a block described by `new_axes`.

        For every member and axis in `new_axes` the entry is either an `int`
        (a fixed axis size: the inner slice is `SliceInfo(0, size)` and the
        halo is `Halo(0, 0)`) or a `LinearSampleAxisTransform` that is applied
        to the sample size and inner slice of the referenced member axis of
        this block; the referenced halo is multiplied by the transform's
        `scale`.

        Raises:
            ValueError: if rounding a scaled halo down and up gives different
                results, i.e. the halo cannot be scaled unambiguously.
        """
        # 1) sample shape of every new member
        new_sample_shape = {}
        for member_id in new_axes:
            member_shape = {}
            for axis_id, trf in new_axes[member_id].items():
                if isinstance(trf, int):
                    member_shape[axis_id] = trf
                else:
                    reference_size = self.sample_shape[trf.member][trf.axis]
                    member_shape[axis_id] = trf.compute(reference_size)

            new_sample_shape[member_id] = member_shape

        def scaled_halo(member_id: MemberId, to_int: Callable[[float], int]):
            """Scale the referenced halos of one member, rounding with `to_int`."""
            result = {}
            for axis_id, trf in new_axes[member_id].items():
                if isinstance(trf, int):
                    result[axis_id] = Halo(0, 0)
                    continue

                reference_halo = self.blocks[trf.member].halo
                if trf.axis not in reference_halo:
                    result[axis_id] = Halo(0, 0)
                    continue

                source = reference_halo[trf.axis]
                result[axis_id] = Halo(
                    to_int(source.left * trf.scale),
                    to_int(source.right * trf.scale),
                )

            return result

        # 2) halo of every new member; floor and ceil have to agree
        new_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for member_id in new_axes:
            rounded_down = scaled_halo(member_id, floor)
            new_halo[member_id] = rounded_down
            if rounded_down != scaled_halo(member_id, ceil):
                raise ValueError(
                    f"failed to unambiguously scale halo {rounded_down} with {new_axes[member_id]}"
                    + f" for {member_id}."
                )

        # 3) inner slice of every new member
        new_inner_slice = {}
        for member_id in new_axes:
            member_slices = {}
            for axis_id, trf in new_axes[member_id].items():
                if isinstance(trf, int):
                    member_slices[axis_id] = SliceInfo(0, trf)
                else:
                    reference_slice = self.blocks[trf.member].inner_slice[trf.axis]
                    member_slices[axis_id] = SliceInfo(
                        trf.compute(reference_slice.start),
                        trf.compute(reference_slice.stop),
                    )

            new_inner_slice[member_id] = member_slices

        new_blocks = {}
        for member_id in new_axes:
            new_blocks[member_id] = BlockMeta(
                sample_shape=new_sample_shape[member_id],
                inner_slice=new_inner_slice[member_id],
                halo=new_halo[member_id],
                block_index=self.block_index,
                blocks_in_sample=self.blocks_in_sample,
            )

        return self.__class__(
            blocks=new_blocks,
            sample_shape=new_sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Attach tensors to this meta data and return a `SampleBlock`.

        Sample shape entries equal to -1 are replaced by the corresponding
        size found in the tagged shape of the given `data`; every member block
        is created with `Block.from_meta`.
        """
        resolved_shape = {}
        for member_id, member_shape in self.sample_shape.items():
            resolved_member_shape = {}
            for axis_id, size in member_shape.items():
                if size == -1:
                    resolved_member_shape[axis_id] = data[member_id].tagged_shape[
                        axis_id
                    ]
                else:
                    resolved_member_shape[axis_id] = size

            resolved_shape[member_id] = resolved_member_shape

        data_blocks = {}
        for member_id, meta in self.blocks.items():
            data_blocks[member_id] = Block.from_meta(meta, data=data[member_id])

        return SampleBlock(
            sample_shape=resolved_shape,
            sample_id=self.sample_id,
            blocks=data_blocks,
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
@dataclass
class SampleBlock(SampleBlockBase[Block]):
    """One block of a dataset sample, including its tensor data"""

    stat: Stat
    """Statistics that have been computed"""

    @property
    def members(self) -> PerMember[Tensor]:
        """The `data` tensor of every member block, keyed by member id"""
        tensors = {}
        for member_id, member_block in self.blocks.items():
            tensors[member_id] = member_block.data

        return tensors

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Return the `SampleBlockMeta` of this block transformed by `new_axes`.

        A `SampleBlockMeta` is built from this block's fields (with a shallow
        copy of `blocks`) and its `get_transformed` result is returned.
        """
        own_meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=dict(self.blocks),
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return own_meta.get_transformed(new_axes)
@dataclass
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` that keeps a reference (`origin`) to its complete `Sample`"""

    origin: Sample
    """The `Sample` from which this sample block was taken"""
class _ConsolidatedMemberBlocks:
    """Block index and block count shared by all member blocks of one sample block.

    Attributes:
        block_index: the single `block_index` common to all given blocks.
        blocks_in_sample: the single `blocks_in_sample` common to all given
            blocks.
    """

    def __init__(self, blocks: PerMember[BlockMeta]):
        """Extract the common `block_index` and `blocks_in_sample` of `blocks`.

        Args:
            blocks: per-member block meta data; every value must expose
                `block_index` and `blocks_in_sample`.

        Raises:
            AssertionError: if `blocks` is empty or its values do not agree
                on a single `block_index` / `blocks_in_sample`.
        """
        super().__init__()
        block_indices = {b.block_index for b in blocks.values()}
        # the messages make a mismatch (or an empty mapping) diagnosable
        assert len(block_indices) == 1, (
            "expected member blocks to share exactly one block_index,"
            + f" got {block_indices}"
        )
        self.block_index = block_indices.pop()
        blocks_in_samples = {b.blocks_in_sample for b in blocks.values()}
        assert len(blocks_in_samples) == 1, (
            "expected member blocks to share exactly one blocks_in_sample,"
            + f" got {blocks_in_samples}"
        )
        self.blocks_in_sample = blocks_in_samples.pop()
def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
):
    """Yield one `SampleBlockMeta` per entry of `blocks`.

    Each entry maps member ids to block meta data. The block index and the
    number of blocks in the sample are taken from that meta data via
    `_ConsolidatedMemberBlocks`; `sample_shape` and `sample_id` are passed on
    unchanged to every yielded object.
    """
    for per_member_meta in blocks:
        shared = _ConsolidatedMemberBlocks(per_member_meta)
        block_meta = SampleBlockMeta(
            blocks=dict(per_member_meta),
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )
        yield block_meta
def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: Union[PadMode, PerMember[PadMode]],
) -> Iterable[SampleBlockWithOrigin]:
    """Yield one `SampleBlockWithOrigin` per entry of `blocks`.

    For every member of `origin` a `Block` is created with
    `Block.from_sample_member` from the member tensor and the matching block
    meta data. `pad_mode` may be a single mode used for all members or a
    mapping per member, in which case unlisted members use "symmetric".
    Shape, statistics and id of the yielded blocks come from `origin`.
    """

    def mode_for(member_id):
        # a mapping selects the mode per member; anything else applies to all
        if isinstance(pad_mode, collections.abc.Mapping):
            return pad_mode.get(member_id, "symmetric")

        return pad_mode

    for per_member_meta in blocks:
        shared = _ConsolidatedMemberBlocks(per_member_meta)

        data_blocks = {}
        for member_id in origin.members:
            data_blocks[member_id] = Block.from_sample_member(
                origin.members[member_id],
                block=per_member_meta[member_id],
                pad_mode=mode_for(member_id),
            )

        yield SampleBlockWithOrigin(
            blocks=data_blocks,
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=shared.block_index,
            blocks_in_sample=shared.blocks_in_sample,
        )