Coverage for src/bioimageio/core/tensor.py: 83%
245 statements
coverage.py v7.13.5, created at 2026-03-27 22:06 +0000
1from __future__ import annotations
3import collections.abc
4from itertools import permutations
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 Callable,
9 Dict,
10 Iterator,
11 Mapping,
12 Optional,
13 Sequence,
14 Tuple,
15 Union,
16 cast,
17 get_args,
18)
20import numpy as np
21import xarray as xr
22from loguru import logger
23from numpy.typing import DTypeLike, NDArray
24from typing_extensions import Self, assert_never
26from bioimageio.spec.model import v0_5
28from ._magic_tensor_ops import MagicTensorOpsMixin
29from .axis import AxisId, AxisInfo, AxisLike, PerAxis
30from .common import (
31 CropWhere,
32 DTypeStr,
33 PadMode,
34 PadWhere,
35 PadWidth,
36 PadWidthLike,
37 QuantileMethod,
38 SliceInfo,
39)
41if TYPE_CHECKING:
42 from numpy.typing import ArrayLike, NDArray
# Scalar and array-like operand types; part of `Tensor._Compatible`, the type of
# the `other` operand of the binary operators.
_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    # operand types accepted by `_binary_op` / `_inplace_binary_op`
    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` named by `dims`.

        Args:
            array: the data to wrap.
            dims: one entry per array dimension; entries that are not already an
                `AxisId` are converted with `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None, copy: Optional[bool] = None):
        """Support `np.asarray(tensor)` / `np.array(tensor)`.

        Args:
            dtype: optional target dtype.
            copy: passed by NumPy >= 2. If true, a copy is always returned;
                otherwise the data is converted with `np.asarray`, which
                copies only when needed.
        """
        if copy:
            return np.array(self._data, dtype=dtype, copy=True)

        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo` keys (top level or per axis) are converted to `slice`s and
        a `Tensor` key is unwrapped to its `xr.DataArray` before indexing.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice, int]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection `key` (in place).

        Per-axis keys may be `int`, `slice` or `SliceInfo` (converted to a
        `slice`); `Tensor` keys/values are unwrapped to their `xr.DataArray`.
        """
        if isinstance(key, Tensor):
            key = key._data
        elif not isinstance(key, xr.DataArray):
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        """Length of the first dimension (as for the wrapped `xr.DataArray`)."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        # yields `self[0]`, `self[1]`, ... along the first dimension
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the first dimension; raises TypeError for 0-d tensors."""
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply binary operator `f` via xarray and wrap the result in a new tensor."""
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply binary operator `f` in place on the wrapped data and return `self`."""
        # Unwrap objects whose `data` is an `xr.DataArray` (e.g. `Tensor`).
        # The `None` default matters: scalars such as `1` or `2.0` have no
        # `data` attribute and must be passed through unchanged.
        other_d = getattr(other, "data", None)
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            other_d if isinstance(other_d, xr.DataArray) else other,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply unary operator `f` via xarray and wrap the result in a new tensor."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Self:
        """create a `Tensor` from a numpy array

        The array may be transposed, and singleton dimensions may be removed
        or inserted, to find a view of `array` that matches `dims`.

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)
                Note that a `str` is a sequence, so it is split into its
                characters (e.g. "yx" -> "y", "x").

        Raises:
            ValueError: if `dims` is None and dims guessing fails.
            ValueError: if the array shape cannot be mapped to `dims`.
        """
        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls`, not `Tensor`: subclasses get instances of themselves, like
        # from `from_xarray` and `_interprete_array_wo_known_axes`
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """Data type name of the tensor (asserted to be a valid `DTypeStr`)."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths (same as `shape`)"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def to_numpy(self) -> NDArray[Any]:
        """Return the data of this tensor as a numpy array."""
        return self.data.to_numpy()  # pyright: ignore[reportUnknownVariableType]

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the index of the maximum value as a mapping from axis id to position."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes missing from `sizes` or already of the requested size are kept.
        Requests to "crop" to a larger size are ignored with a warning.

        Args:
            sizes: target size per axis.
            crop_where: where to remove elements ("left", "right" or
                "left_and_right"), for all axes or per axis.

        Raises:
            ValueError: if an axis needs cropping but has no `crop_where` entry.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # drop elements at the start, keep the last `sizes[a]`
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    # drop elements at the end, keep the first `sizes[a]`
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    # centered crop; an odd excess removes one more on the right
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with new axes `dims` (with sizes if given a mapping)."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce by taking the mean along `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Pad each axis by `pad_width` (normalized with `PadWidth.create`) using `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes missing from `sizes` or already of the requested size are not
        padded. Requests to "pad" to a smaller size are ignored with a warning.

        Args:
            sizes: target size per axis.
            pad_where: where to add elements ("left", "right" or
                "left_and_right"), for all axes or per axis.
            mode: padding mode passed on to `pad`.

        Raises:
            ValueError: if an axis needs padding but has no `pad_where` entry.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    # an odd amount puts the extra element on the right
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
        method: QuantileMethod = "linear",
    ) -> Self:
        """Compute the quantile(s) `q` (each within [0, 1]) along `dim`.

        `dim` must not contain an axis named "quantile". Presumably this is
        because xarray names its added quantile dimension that way; verify
        against the xarray version in use.
        """
        qs = [q] if isinstance(q, (float, int)) else q
        assert all(qq >= 0.0 for qq in qs)
        assert all(qq <= 1.0 for qq in qs)
        if dim is not None:
            # a single axis id is a string: wrap it rather than iterating
            # over its characters
            if isinstance(dim, collections.abc.Sequence) and not isinstance(
                dim, str
            ):
                dim_list = list(dim)
            else:
                dim_list = [dim]

            assert AxisId("quantile") not in dim_list

        return self.__class__.from_xarray(
            self._data.quantile(q, dim=dim, method=method)
        )

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Axes that are too large are cropped (`crop_to`) and axes that are too
        small are padded (`pad_to`). Axes in `sizes` that the tensor does not
        have are added with `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce by taking the standard deviation along `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Axes listed in `axes` but missing from the tensor are added as
        singleton dimensions first.

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce by taking the variance along `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes from `array.ndim` alone.

        2d: y, x; 3d: channel, y, x if any dimension is <= 3, else z, y, x;
        4d: channel, z, y, x; 5d: batch, channel, z, y, x.
        The channel axis is always assumed to come first (after batch).

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Remove or insert singleton dimensions so `arr` has `len(axis_infos)` dims.

    If `arr` has too many dimensions, size-1 dimensions are dropped from left
    to right until the counts match (or no size-1 dimension is left).
    If it has too few, a singleton dimension is inserted at position `i` for
    each `axis_infos[i]` with `maybe_singleton` set, from left to right, until
    the counts match.

    The returned array may still have a different number of dimensions than
    `axis_infos`; callers must check.
    """
    n_excess = len(arr.shape) - len(axis_infos)
    if n_excess > 0:
        # remove singletons: select the leftmost needed size-1 axes (indices
        # refer to the unmodified array) and drop them in one call. Dropping
        # them one at a time while enumerating the original shape would shift
        # the remaining indices and drop the wrong axes.
        singleton_axes = tuple(i for i, s in enumerate(arr.shape) if s == 1)[
            :n_excess
        ]
        if singleton_axes:
            arr = np.squeeze(arr, axis=singleton_axes)

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find a transposed view of `original_array` that fits `axis_infos`.

    Every axis permutation is tried, the identity and the full reversal
    first. Each candidate goes through `_add_singletons`. A candidate is
    accepted when its size-1 dimensions only fall on axes with
    `maybe_singleton` set.

    Returns:
        The first matching view, or None if no permutation fits or the number
        of dimensions cannot be reconciled with `axis_infos`.
    """
    candidate_orders = list(permutations(range(len(original_array.shape))))
    # the last permutation is the full reversal (A.T); move it to second place
    # so the array as-is and its transpose are tried first
    candidate_orders.insert(1, candidate_orders.pop())

    for order in candidate_orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != len(axis_infos):
            return None

        fits = all(
            size != 1 or info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if fits:
            return candidate

    return None