Coverage for src / bioimageio / core / tensor.py: 83%
243 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
1from __future__ import annotations
3import collections.abc
4from itertools import permutations
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 Callable,
9 Dict,
10 Iterator,
11 Mapping,
12 Optional,
13 Sequence,
14 Tuple,
15 Union,
16 cast,
17 get_args,
18)
20import numpy as np
21import xarray as xr
22from loguru import logger
23from numpy.typing import DTypeLike, NDArray
24from typing_extensions import Self, assert_never
26from bioimageio.spec.model import v0_5
28from ._magic_tensor_ops import MagicTensorOpsMixin
29from .axis import AxisId, AxisInfo, AxisLike, PerAxis
30from .common import (
31 CropWhere,
32 DTypeStr,
33 PadMode,
34 PadWhere,
35 PadWidth,
36 PadWidthLike,
37 QuantileMethod,
38 SliceInfo,
39)
41if TYPE_CHECKING:
42 from numpy.typing import ArrayLike, NDArray
45_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray"
48# TODO: complete docstrings
49# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    # Operand types accepted by the arithmetic helpers (`_binary_op` & co).
    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` with named axes.

        Args:
            array: the array to wrap (one entry in `dims` per dimension).
            dims: axis identifiers; `AxisId`s are used as-is, anything else
                is converted with `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None, copy: Optional[bool] = None):
        """Support `np.asarray(tensor)` / `np.array(tensor)`.

        Args:
            dtype: optional dtype to convert to.
            copy: accepted for NumPy 2 compatibility (NumPy passes it to
                `__array__`). If truthy, a fresh copy is returned; otherwise
                the data is converted with `np.asarray`, which does not force
                a copy.
        """
        if copy:
            return np.array(self._data, dtype=dtype, copy=True)

        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo` keys (directly or as values of a per-axis mapping) are
        converted to `slice`s; a `Tensor` key is unwrapped to its data array.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection `key` in place.

        Per-axis `SliceInfo`s are converted to `slice`s; `Tensor` keys and
        values are unwrapped to their data arrays.
        """
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        """Return `len()` of the underlying data array (its first dimension)."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        # yields `self[0]`, `self[1]`, ... (sub-tensors along the first axis)
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the first axis.

        Raises:
            TypeError: for 0-d tensors (checked eagerly, not on first `next`).
        """
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply binary operator `f` via xarray and wrap the result."""
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply in-place binary operator `f` via xarray and return `self`."""
        # Unwrap operands whose `data` attribute is an xarray.DataArray (e.g. a
        # `Tensor`). The default is required: plain Python scalars have no
        # `data` attribute, and numpy arrays' `data` is a memoryview.
        other_d = getattr(other, "data", None)
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            other_d if isinstance(other_d, xr.DataArray) else other,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply unary operator `f` via xarray and wrap the result."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Self:
        """create a `Tensor` from a numpy array

        To match `dims`, the array may be transposed and size-1 dimensions
        may be dropped or inserted.

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and dims guessing fails,
                or if the array shape cannot be mapped to `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls` (not `Tensor`) so subclasses get instances of themselves
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped xarray.DataArray."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """Data type as string; asserted to be one of `DTypeStr`."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the position of the maximum as a mapping axis id -> index."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes missing from `sizes` (or already of the requested size) are kept.
        Axes smaller than requested are kept and a warning is logged.

        Raises:
            ValueError: if an axis must be cropped but is missing from a
                per-axis `crop_where` mapping.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # keep the last sizes[a] elements
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    # keep the first sizes[a] elements
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    # centered crop; an odd excess removes one more on the right
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Add new axes (given as ids, or as a mapping id -> size)."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Mean over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Pad per axis; each width is normalized with `PadWidth.create`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes missing from `sizes` (or already of the requested size) get no
        padding. Axes larger than requested get no padding and a warning is
        logged.

        Raises:
            ValueError: if an axis must be padded but is missing from a
                per-axis `pad_where` mapping.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    # an odd amount puts the extra element on the right
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
        method: QuantileMethod = "linear",
    ) -> Self:
        """Compute quantile(s) `q` (each in [0, 1]) over `dim`.

        `dim` must not contain the axis id "quantile" (asserted).
        """
        qs = [q] if isinstance(q, (float, int)) else q
        assert all(qq >= 0.0 for qq in qs)
        assert all(qq <= 1.0 for qq in qs)
        if dim is not None:
            # a single AxisId must not be iterated over (it would yield chars)
            dims = [dim] if isinstance(dim, (AxisId, str)) else dim
            assert AxisId("quantile") not in dims

        return self.__class__.from_xarray(
            self._data.quantile(q, dim=dim, method=method)
        )

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Existing axes are cropped if too large and padded if too small; axes
        in `sizes` that the tensor does not have are added via `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Standard deviation over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Axes listed in `axes` but missing from the tensor are added first
        (via `expand_dims`).

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Variance over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axis ids from the array rank and wrap the array.

        2D -> (y, x); 3D -> (channel, y, x) if any dimension is <= 3, else
        (z, y, x); 4D -> (channel, z, y, x); 5D -> (batch, channel, z, y, x).

        Raises:
            ValueError: for any other rank.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # NOTE(review): dimension 0 is always taken as the channel axis,
            # even if the small dimension is elsewhere (e.g. shape (H, W, 3)).
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
535def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
536 if len(arr.shape) > len(axis_infos):
537 # remove singletons
538 for i, s in enumerate(arr.shape):
539 if s == 1:
540 arr = np.take(arr, 0, axis=i)
541 if len(arr.shape) == len(axis_infos):
542 break
544 # add singletons if nececsary
545 for i, a in enumerate(axis_infos):
546 if len(arr.shape) >= len(axis_infos):
547 break
549 if a.maybe_singleton:
550 arr = np.expand_dims(arr, i)
552 return arr
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find a transposition of `original_array` (with size-1 dimensions
    dropped or inserted by `_add_singletons`) that fits `axis_infos`.

    Axis orders are tried in `itertools.permutations` order, except that the
    full reversal (plain transpose) is tried second, right after the identity.
    A candidate fits if its rank equals `len(axis_infos)` and every size-1
    dimension sits at an axis whose `maybe_singleton` is set.

    Returns:
        The first fitting candidate, or None if there is none (None is also
        returned as soon as a candidate has the wrong rank).
    """
    candidate_orders = list(permutations(range(len(original_array.shape))))
    # try A and A.T first: move the full reversal to the second position
    reversed_order = candidate_orders.pop()
    candidate_orders.insert(1, reversed_order)

    for order in candidate_orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != len(axis_infos):
            return None

        fits = all(
            size != 1 or info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if fits:
            return candidate

    return None