Coverage for src/bioimageio/core/sample.py: 93%
159 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from __future__ import annotations
3import collections.abc
4from dataclasses import dataclass
5from math import ceil, floor
6from types import MappingProxyType
7from typing import (
8 Any,
9 Callable,
10 Dict,
11 Generic,
12 Iterable,
13 Mapping,
14 Optional,
15 Tuple,
16 TypeVar,
17 Union,
18)
20import numpy as np
21import pydantic
22import xarray as xr
23from numpy.typing import NDArray
24from typing_extensions import Self
26from ._common_annotations import PerMemberAnno
27from .axis import AxisId, PerAxis
28from .block import Block
29from .block_meta import (
30 BlockMeta,
31 LinearAxisTransform,
32 split_multiple_shapes_into_blocks,
33)
34from .common import (
35 BlockIndex,
36 Halo,
37 HaloLike,
38 MemberId,
39 PadMode,
40 PadWidthLike,
41 PerMember,
42 SampleId,
43 SliceInfo,
44 TotalNumberOfBlocks,
45)
46from .stat_measures import Stat
47from .tensor import Tensor
49# TODO: allow for lazy samples to read/write to disk
@dataclass
class Sample:
    """A dataset sample.

    A `Sample` has `members`, which allows to combine multiple tensors into a single
    sample.
    For example a `Sample` from a dataset with masked images may contain a
    `MemberId("raw")` and `MemberId("mask")` image.
    """

    members: Dict[MemberId, Tensor]
    """The sample's tensors"""

    stat: Stat
    """Sample and dataset statistics"""

    id: SampleId
    """Identifies the `Sample` within the dataset -- typically a number or a string."""

    def __getitem__(
        self,
        key: PerMember[
            Union[
                SliceInfo,
                slice,
                int,
                PerAxis[Union[SliceInfo, slice, int]],
                Tensor,
                xr.DataArray,
            ]
        ],
    ) -> Self:
        """Index the members listed in `key` and return them as a new sample.

        Each member `m` present in `key` is indexed with `key[m]`; members that
        are not in `key` are dropped from the result. `stat` and `id` are
        passed on unchanged (not copied).
        """
        return self.__class__(
            members={m: t[key[m]] for m, t in self.members.items() if m in key},
            stat=self.stat,
            id=self.id,
        )

    def set_block(self, block: SampleBlock) -> None:
        """Set values of `block`.

        Note:
        - Updates only existing sample members (extra block members are ignored)
        - Ignores missing block members (i.e. members in the sample but not in the block are not modified)

        Raises:
            ValueError if block and sample members do not overlap at all.
        """
        no_overlap = True
        for m, tensor in self.members.items():
            if m not in block.blocks:
                continue

            b = block.blocks[m]
            # write only the block's inner region (as exposed by the block itself)
            tensor[b.inner_slice] = b.inner_data
            no_overlap = False

        if no_overlap:
            raise ValueError(
                f"block with members {list(block.blocks)} does not overlap with sample members {list(self.members)}"
            )

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """The `sizes` of every member tensor, keyed by member id."""
        return {tid: t.sizes for tid, t in self.members.items()}

    def as_arrays(self) -> Dict[MemberId, NDArray[Any]]:
        """Return sample as dictionary of arrays."""
        return {m: t.to_numpy() for m, t in self.members.items()}

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: Union[PadMode, PerMember[PadMode]],
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into blocks.

        Args:
            block_shapes: Requested block shape per member; every member listed
                must be a member of this sample.
            halo: Halo per member; may only be given for members that also
                have an entry in `block_shapes`.
            pad_mode: A single pad mode or one per member, forwarded to
                `sample_block_generator`.
            broadcast: Forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and an iterable of the sample blocks,
            each referencing this sample as its origin.

        Raises:
            AssertionError: if `block_shapes` or `halo` name unknown members
                (only checked when assertions are enabled).
        """
        assert not (missing := [m for m in block_shapes if m not in self.members]), (
            f"`block_shapes` specified for unknown members: {missing}"
        )
        assert not (missing := [m for m in halo if m not in block_shapes]), (
            f"`halo` specified for members without `block_shape`: {missing}"
        )

        n_blocks, blocks = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode)

    def as_single_block(
        self, halo: Optional[PerMember[PerAxis[Halo]]] = None
    ) -> SampleBlockWithOrigin:
        """Wrap the whole sample in one `SampleBlockWithOrigin`.

        Every member becomes a single `Block` (block 0 of 1) whose inner slice
        spans the member's full tagged shape.

        Args:
            halo: Optional halo per member; members without an entry get an
                empty halo mapping.
        """
        if halo is None:
            halo = {}

        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks={
                m: Block(
                    sample_shape=self.shape[m],
                    data=data,
                    inner_slice={
                        a: SliceInfo(0, s) for a, s in data.tagged_shape.items()
                    },
                    halo=halo.get(m, {}),
                    block_index=0,
                    blocks_in_sample=1,
                )
                for m, data in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Create a `Sample` from an iterable of `SampleBlock`s.

        Note:
            All sample blocks must have the same `sample_id`.

        Args:
            sample_blocks: The blocks to create the sample from.
            fill_value: The value to fill missing values with (default: `nan`).

        Raises:
            ValueError: if `sample_blocks` is empty.
        """
        # `from_blocks_yield_intermediates` always yields at least once (it
        # yields again after its loop), so an empty input cannot be detected
        # from its output. Peek at the first block instead.
        block_iter = iter(sample_blocks)
        try:
            first_block = next(block_iter)
        except StopIteration:
            raise ValueError("no sample blocks provided") from None

        def _all_blocks() -> Iterable[SampleBlock]:
            # re-attach the block consumed by the emptiness check
            yield first_block
            yield from block_iter

        output: Optional[Self] = None
        for output in cls.from_blocks_yield_intermediates(
            _all_blocks(), fill_value=fill_value
        ):
            pass

        if output is None:  # defensive; not expected to happen
            raise ValueError("no sample blocks provided")

        return output

    @classmethod
    def from_blocks_yield_intermediates(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Iterable[Self]:
        """Create a `Sample` from an iterable of `SampleBlock`s, yielding the intermediate sample after each block.

        Note:
            - The same `Sample` object is yielded every time; it is updated in
              place as blocks arrive.
            - After the last block the sample is yielded once more. For an
              empty `sample_blocks` this is the only yield and the sample has
              no members, an empty `stat` and `id=None`.
            - `stat` is taken from the most recently processed block.

        Args:
            sample_blocks: The blocks to create the sample from.
            fill_value: The value to fill missing values with (default: `nan`).

        Raises:
            AssertionError: if the sample id changes between blocks
                (only checked when assertions are enabled).
            NotImplementedError: if a block's `sample_shape` contains `-1`
                (data dependent axis).
        """
        output = cls(members={}, stat={}, id=None)
        for sample_block in sample_blocks:
            if output.id is None:
                output.id = sample_block.sample_id
            else:
                assert output.id == sample_block.sample_id, (
                    "sample id changed between sample blocks"
                )

            output.stat = sample_block.stat

            for m, block in sample_block.blocks.items():
                if m not in output.members:
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    # allocate the full-size member, pre-filled with `fill_value`
                    output.members[m] = Tensor(
                        np.full(
                            tuple(block.sample_shape[a] for a in block.data.dims),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=block.data.dims,
                    )

                output.members[m][block.inner_slice] = block.inner_data

            yield output

        yield output

    def pad(
        self,
        pad_width: PerMember[PerAxis[Union[int, PadWidthLike]]],
        mode: Union[PerMember[PadMode], PadMode],
    ) -> Self:
        """Convenience method to pad sample members.

        Args:
            pad_width: Pad width per member and axis; members without an entry
                are padded with an empty pad width mapping.
            mode: Either one pad mode for all members or a mapping per member;
                members missing from the mapping use `"symmetric"`.

        Returns:
            A new sample with padded members and the same `stat` and `id`.
        """
        if isinstance(mode, collections.abc.Mapping):
            mode_per_member: Mapping[MemberId, PadMode] = mode
            default_mode: PadMode = "symmetric"
        else:
            mode_per_member = {}
            default_mode = mode

        return self.__class__(
            members={
                m: t.pad(
                    pad_width=pad_width.get(m, {}),
                    mode=mode_per_member.get(m, default_mode),
                )
                for m, t in self.members.items()
            },
            stat=self.stat,
            id=self.id,
        )
# Type of the per-member blocks held by a `SampleBlockBase` (a `BlockMeta` or subclass).
BlockT = TypeVar("BlockT", bound=BlockMeta)


@pydantic.dataclasses.dataclass(frozen=True)
class SampleBlockBase(Generic[BlockT]):
    """base class for `SampleBlockMeta` and `SampleBlock`"""

    sample_shape: PerMemberAnno[PerAxis[int]]
    """the sample shape this block represents a part of"""

    sample_id: SampleId
    """identifier for the sample within its dataset"""

    blocks: PerMemberAnno[BlockT]
    """Individual tensor blocks comprising this sample block"""

    block_index: BlockIndex
    """the n-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """Read-only mapping from member id to that member block's `shape`."""
        per_member = {}
        for member_id, member_block in self.blocks.items():
            per_member[member_id] = member_block.shape

        return MappingProxyType(per_member)

    @property
    def inner_shape(self) -> PerMember[PerAxis[int]]:
        """Read-only mapping from member id to that member block's `inner_shape`."""
        per_member = {}
        for member_id, member_block in self.blocks.items():
            per_member[member_id] = member_block.inner_shape

        return MappingProxyType(per_member)
@dataclass
class LinearSampleAxisTransform(LinearAxisTransform):
    """A `LinearAxisTransform` that additionally names a sample member."""

    # Member whose axis the transform refers to, e.g. it is used to look up
    # `sample_shape[member][axis]` when transforming sample block meta data.
    member: MemberId
@pydantic.dataclasses.dataclass(frozen=True)
class SampleBlockMeta(SampleBlockBase[BlockMeta]):
    """Meta data of a dataset sample block"""

    def get_transformed(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> Self:
        """Derive the meta data of a new sample block from this one.

        Args:
            new_axes: For every member and axis of the new sample block either
                a fixed integer size or a transform referring to a member/axis
                of this sample block.

        Returns:
            A new instance with transformed sample shape, halo and inner slice
            per member; block index and number of blocks are kept.

        Raises:
            ValueError: if rounding a scaled halo down and up gives different
                results, i.e. the halo cannot be scaled unambiguously.
        """

        def new_size(spec: Union[LinearSampleAxisTransform, int]) -> int:
            # plain integers are taken as the new size as is
            if isinstance(spec, int):
                return spec

            return spec.compute(self.sample_shape[spec.member][spec.axis])

        def new_halo(
            spec: Union[LinearSampleAxisTransform, int],
            to_int: Callable[[float], int],
        ) -> Halo:
            # fixed-size axes and axes without a source halo get no halo
            if isinstance(spec, int) or spec.axis not in self.blocks[spec.member].halo:
                return Halo(0, 0)

            return Halo(
                to_int(self.blocks[spec.member].halo[spec.axis].left * spec.scale),
                to_int(self.blocks[spec.member].halo[spec.axis].right * spec.scale),
            )

        def new_inner_slice(spec: Union[LinearSampleAxisTransform, int]) -> SliceInfo:
            if isinstance(spec, int):
                return SliceInfo(0, spec)

            return SliceInfo(
                spec.compute(self.blocks[spec.member].inner_slice[spec.axis].start),
                spec.compute(self.blocks[spec.member].inner_slice[spec.axis].stop),
            )

        # phase 1: sample shapes of all members
        sample_shape = {
            member: {axis: new_size(spec) for axis, spec in new_axes[member].items()}
            for member in new_axes
        }

        # phase 2: halos; rounding down and up must agree
        halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
        for member in new_axes:
            halo[member] = {
                axis: new_halo(spec, floor) for axis, spec in new_axes[member].items()
            }
            rounded_up = {
                axis: new_halo(spec, ceil) for axis, spec in new_axes[member].items()
            }
            if halo[member] != rounded_up:
                raise ValueError(
                    f"failed to unambiguously scale halo {halo[member]} with {new_axes[member]}"
                    + f" for {member}."
                )

        # phase 3: inner slices of all members
        inner_slice = {
            member: {
                axis: new_inner_slice(spec) for axis, spec in new_axes[member].items()
            }
            for member in new_axes
        }

        transformed_blocks = {
            member: BlockMeta(
                sample_shape=sample_shape[member],
                inner_slice=inner_slice[member],
                halo=halo[member],
                block_index=self.block_index,
                blocks_in_sample=self.blocks_in_sample,
            )
            for member in new_axes
        }
        return self.__class__(
            blocks=transformed_blocks,
            sample_shape=sample_shape,
            sample_id=self.sample_id,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock:
        """Combine this meta data with tensors into a `SampleBlock`.

        Sample shape entries of `-1` are replaced by the corresponding size in
        the tagged shape of the given data.
        """
        resolved_shape = {}
        for member, member_shape in self.sample_shape.items():
            resolved_shape[member] = {
                axis: (size if size != -1 else data[member].tagged_shape[axis])
                for axis, size in member_shape.items()
            }

        member_blocks = {
            member: Block.from_meta(meta, data=data[member])
            for member, meta in self.blocks.items()
        }
        return SampleBlock(
            sample_shape=resolved_shape,
            sample_id=self.sample_id,
            blocks=member_blocks,
            stat=stat,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
@dataclass(frozen=True)
class SampleBlock(SampleBlockBase[Block]):
    """A block of a dataset sample"""

    blocks: Dict[MemberId, Block]
    """Individual tensor blocks comprising this sample block"""

    stat: Stat
    """computed statistics"""

    @property
    def members(self) -> PerMember[Tensor]:
        """the sample block's tensors"""
        tensors = {}
        for member_id, member_block in self.blocks.items():
            tensors[member_id] = member_block.data

        return tensors

    def get_transformed_meta(
        self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]
    ) -> SampleBlockMeta:
        """Build a `SampleBlockMeta` from this block and transform it with `new_axes`."""
        own_meta = SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=dict(self.blocks),
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
        return own_meta.get_transformed(new_axes)

    @classmethod
    def from_meta(
        cls, meta: SampleBlockMeta, data: PerMember[Tensor], stat: Stat
    ) -> Self:
        """Create a sample block from meta data, one tensor per member in `meta.blocks`, and `stat`."""
        member_blocks = {}
        for member_id, block_meta in meta.blocks.items():
            member_blocks[member_id] = Block.from_meta(block_meta, data=data[member_id])

        return cls(
            sample_shape=meta.sample_shape,
            sample_id=meta.sample_id,
            blocks=member_blocks,
            stat=stat,
            block_index=meta.block_index,
            blocks_in_sample=meta.blocks_in_sample,
        )

    def get_meta(self) -> SampleBlockMeta:
        """Return the meta data of this sample block (using each block's `get_meta()`)."""
        member_metas = {}
        for member_id, member_block in self.blocks.items():
            member_metas[member_id] = member_block.get_meta()

        return SampleBlockMeta(
            sample_id=self.sample_id,
            blocks=member_metas,
            sample_shape=self.sample_shape,
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )

    def as_sample(self) -> Sample:
        """Convert this sample block to a `Sample` with the shape of this block.

        Note:
            If you want to convert one or more sample block to a sample with the shape of the original, whole sample,
            use `Sample.from_blocks()` instead.
        """
        # shallow copies, so the new sample gets its own `members`/`stat` dicts
        member_tensors = dict(self.members)
        statistics = dict(self.stat)
        return Sample(members=member_tensors, stat=statistics, id=self.sample_id)
@dataclass(frozen=True)
class SampleBlockWithOrigin(SampleBlock):
    """A `SampleBlock` with a reference (`origin`) to the whole `Sample`"""

    # only the reference to the source sample is added to `SampleBlock`
    origin: Sample
    """the sample this sample block was taken from"""
468class _ConsolidatedMemberBlocks:
469 def __init__(self, blocks: PerMember[BlockMeta]):
470 super().__init__()
471 block_indices = {b.block_index for b in blocks.values()}
472 assert len(block_indices) == 1
473 self.block_index = block_indices.pop()
474 blocks_in_samples = {b.blocks_in_sample for b in blocks.values()}
475 assert len(blocks_in_samples) == 1
476 self.blocks_in_sample = blocks_in_samples.pop()
def sample_block_meta_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    sample_shape: PerMember[PerAxis[int]],
    sample_id: SampleId,
):
    """Yield one `SampleBlockMeta` per entry of `blocks`.

    Args:
        blocks: Per-member block meta data; the member blocks of one entry
            must share `block_index` and `blocks_in_sample`.
        sample_shape: Sample shape attached to every yielded meta object.
        sample_id: Sample id attached to every yielded meta object.
    """
    for per_member_meta in blocks:
        # validates that the members agree and extracts the common values
        consolidated = _ConsolidatedMemberBlocks(per_member_meta)
        block_copy = dict(per_member_meta)
        yield SampleBlockMeta(
            blocks=block_copy,
            sample_shape=sample_shape,
            sample_id=sample_id,
            block_index=consolidated.block_index,
            blocks_in_sample=consolidated.blocks_in_sample,
        )
def sample_block_generator(
    blocks: Iterable[PerMember[BlockMeta]],
    *,
    origin: Sample,
    pad_mode: Union[PadMode, PerMember[PadMode]],
) -> Iterable[SampleBlockWithOrigin]:
    """Yield one `SampleBlockWithOrigin` per entry of `blocks`, taken from `origin`.

    Args:
        blocks: Per-member block meta data; indexed with every member of
            `origin` and required to agree on `block_index` and
            `blocks_in_sample`.
        origin: The sample the block data is taken from.
        pad_mode: One pad mode for all members or a mapping per member;
            members missing from the mapping use `"symmetric"`.
    """

    def pad_mode_for(member: MemberId) -> PadMode:
        if isinstance(pad_mode, collections.abc.Mapping):
            return pad_mode.get(member, "symmetric")

        return pad_mode

    for per_member_meta in blocks:
        consolidated = _ConsolidatedMemberBlocks(per_member_meta)
        data_blocks = {
            member: Block.from_sample_member(
                origin.members[member],
                block=per_member_meta[member],
                pad_mode=pad_mode_for(member),
            )
            for member in origin.members
        }
        yield SampleBlockWithOrigin(
            blocks=data_blocks,
            sample_shape=origin.shape,
            origin=origin,
            stat=origin.stat,
            sample_id=origin.id,
            block_index=consolidated.block_index,
            blocks_in_sample=consolidated.blocks_in_sample,
        )