Coverage for src/bioimageio/core/tensor.py: 83%
247 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
1from __future__ import annotations
3import collections.abc
4from itertools import permutations
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 Callable,
9 Dict,
10 Iterator,
11 Mapping,
12 Optional,
13 Sequence,
14 Tuple,
15 Union,
16 cast,
17 get_args,
18)
20import numpy as np
21import xarray as xr
22from loguru import logger
23from numpy.typing import DTypeLike, NDArray
24from typing_extensions import Self, assert_never
26from bioimageio.spec.model import v0_5
28from ._magic_tensor_ops import MagicTensorOpsMixin
29from .axis import AxisId, AxisInfo, AxisLike, PerAxis
30from .common import (
31 CropWhere,
32 DTypeStr,
33 PadMode,
34 PadWhere,
35 PadWidth,
36 PadWidthLike,
37 QuantileMethod,
38 SliceInfo,
39)
41if TYPE_CHECKING:
42 from numpy.typing import ArrayLike, NDArray
45_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray"
48# TODO: complete docstrings
49# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
50class Tensor(MagicTensorOpsMixin):
51 """A wrapper around an xr.DataArray for better integration with bioimageio.spec
52 and improved type annotations."""
54 _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
56 def __init__(
57 self,
58 array: NDArray[Any],
59 dims: Sequence[Union[AxisId, AxisLike]],
60 ) -> None:
61 super().__init__()
62 axes = tuple(
63 a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
64 )
65 self._data = xr.DataArray(array, dims=axes)
67 def __repr__(self) -> str:
68 return f"<Tensor {repr(self._data)}>"
70 def __array__(self, dtype: DTypeLike = None):
71 return np.asarray(self._data, dtype=dtype)
73 def __getitem__(
74 self,
75 key: Union[
76 SliceInfo,
77 slice,
78 int,
79 PerAxis[Union[SliceInfo, slice, int]],
80 Tensor,
81 xr.DataArray,
82 ],
83 ) -> Self:
84 if isinstance(key, SliceInfo):
85 key = slice(*key)
86 elif isinstance(key, collections.abc.Mapping):
87 key = {
88 a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
89 for a, s in key.items()
90 }
91 elif isinstance(key, Tensor):
92 key = key._data
94 return self.__class__.from_xarray(self._data[key])
96 def __setitem__(
97 self,
98 key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
99 value: Union[Tensor, xr.DataArray, float, int],
100 ) -> None:
101 if isinstance(key, Tensor):
102 key = key._data
103 elif isinstance(key, xr.DataArray):
104 pass
105 else:
106 key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
108 if isinstance(value, Tensor):
109 value = value._data
111 self._data[key] = value
113 def __len__(self) -> int:
114 return len(self.data)
116 def _iter(self: Any) -> Iterator[Any]:
117 for n in range(len(self)):
118 yield self[n]
120 def __iter__(self: Any) -> Iterator[Any]:
121 if self.ndim == 0:
122 raise TypeError("iteration over a 0-d array")
123 return self._iter()
125 def _binary_op(
126 self,
127 other: _Compatible,
128 f: Callable[[Any, Any], Any],
129 reflexive: bool = False,
130 ) -> Self:
131 data = self._data._binary_op( # pyright: ignore[reportPrivateUsage]
132 (other._data if isinstance(other, Tensor) else other),
133 f,
134 reflexive,
135 )
136 return self.__class__.from_xarray(data)
138 def _inplace_binary_op(
139 self,
140 other: _Compatible,
141 f: Callable[[Any, Any], Any],
142 ) -> Self:
143 _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage]
144 (
145 other_d
146 if (other_d := getattr(other, "data")) is not None
147 and isinstance(
148 other_d,
149 xr.DataArray,
150 )
151 else other
152 ),
153 f,
154 )
155 return self
157 def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
158 data = self._data._unary_op( # pyright: ignore[reportPrivateUsage]
159 f, *args, **kwargs
160 )
161 return self.__class__.from_xarray(data)
163 @classmethod
164 def from_xarray(cls, data_array: xr.DataArray) -> Self:
165 """create a `Tensor` from an xarray data array
167 note for internal use: this factory method is round-trip save
168 for any `Tensor`'s `data` property (an xarray.DataArray).
169 """
170 return cls(
171 array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
172 )
174 @classmethod
175 def from_numpy(
176 cls,
177 array: NDArray[Any],
178 *,
179 dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
180 ) -> Tensor:
181 """create a `Tensor` from a numpy array
183 Args:
184 array: the nd numpy array
185 dims: A description of the array's axes,
186 if None axes are guessed (which might fail and raise a ValueError.)
188 Raises:
189 ValueError: if `dims` is None and dims guessing fails.
190 """
192 if dims is None:
193 return cls._interprete_array_wo_known_axes(array)
194 elif isinstance(dims, collections.abc.Sequence):
195 dim_seq = list(dims)
196 else:
197 dim_seq = [dims]
199 axis_infos = [AxisInfo.create(a) for a in dim_seq]
200 original_shape = tuple(array.shape)
202 successful_view = _get_array_view(array, axis_infos)
203 if successful_view is None:
204 raise ValueError(
205 f"Array shape {original_shape} does not map to axes {dims}"
206 )
208 return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
210 @property
211 def data(self):
212 return self._data
214 @property
215 def dims(self): # TODO: rename to `axes`?
216 """Tuple of dimension names associated with this tensor."""
217 return cast(Tuple[AxisId, ...], self._data.dims)
219 @property
220 def dtype(self) -> DTypeStr:
221 dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType]
222 assert dt in get_args(DTypeStr)
223 return dt # pyright: ignore[reportReturnType]
225 @property
226 def ndim(self):
227 """Number of tensor dimensions."""
228 return self._data.ndim
230 @property
231 def shape(self):
232 """Tuple of tensor axes lengths"""
233 return self._data.shape
235 @property
236 def shape_tuple(self):
237 """Tuple of tensor axes lengths"""
238 return self._data.shape
240 @property
241 def size(self):
242 """Number of elements in the tensor.
244 Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
245 """
246 return self._data.size
248 @property
249 def sizes(self):
250 """Ordered, immutable mapping from axis ids to axis lengths."""
251 return cast(Mapping[AxisId, int], self.data.sizes)
253 @property
254 def tagged_shape(self):
255 """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
256 return self.sizes
258 def to_numpy(self) -> NDArray[Any]:
259 """Return the data of this tensor as a numpy array."""
260 return self.data.to_numpy() # pyright: ignore[reportUnknownVariableType]
262 def argmax(self) -> Mapping[AxisId, int]:
263 ret = self._data.argmax(...)
264 assert isinstance(ret, dict)
265 return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
267 def astype(self, dtype: DTypeStr, *, copy: bool = False):
268 """Return tensor cast to `dtype`
270 note: if dtype is already satisfied copy if `copy`"""
271 return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
273 def clip(self, min: Optional[float] = None, max: Optional[float] = None):
274 """Return a tensor whose values are limited to [min, max].
275 At least one of max or min must be given."""
276 return self.__class__.from_xarray(self._data.clip(min, max))
278 def crop_to(
279 self,
280 sizes: PerAxis[int],
281 crop_where: Union[
282 CropWhere,
283 PerAxis[CropWhere],
284 ] = "left_and_right",
285 ) -> Self:
286 """crop to match `sizes`"""
287 if isinstance(crop_where, str):
288 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
289 else:
290 crop_axis_where = crop_where
292 slices: Dict[AxisId, SliceInfo] = {}
294 for a, s_is in self.sizes.items():
295 if a not in sizes or sizes[a] == s_is:
296 pass
297 elif sizes[a] > s_is:
298 logger.warning(
299 "Cannot crop axis {} of size {} to larger size {}",
300 a,
301 s_is,
302 sizes[a],
303 )
304 elif a not in crop_axis_where:
305 raise ValueError(
306 f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
307 )
308 else:
309 crop_this_axis_where = crop_axis_where[a]
310 if crop_this_axis_where == "left":
311 slices[a] = SliceInfo(s_is - sizes[a], s_is)
312 elif crop_this_axis_where == "right":
313 slices[a] = SliceInfo(0, sizes[a])
314 elif crop_this_axis_where == "left_and_right":
315 slices[a] = SliceInfo(
316 start := (s_is - sizes[a]) // 2, sizes[a] + start
317 )
318 else:
319 assert_never(crop_this_axis_where)
321 return self[slices]
323 def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
324 return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
326 def item(
327 self,
328 key: Union[
329 None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
330 ] = None,
331 ):
332 """Copy a tensor element to a standard Python scalar and return it."""
333 if key is None:
334 ret = self._data.item()
335 else:
336 ret = self[key]._data.item()
338 assert isinstance(ret, (bool, float, int))
339 return ret
341 def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
342 return self.__class__.from_xarray(self._data.mean(dim=dim))
344 def pad(
345 self,
346 pad_width: PerAxis[PadWidthLike],
347 mode: PadMode = "symmetric",
348 ) -> Self:
349 pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
350 return self.__class__.from_xarray(
351 self._data.pad(pad_width=pad_width, mode=mode)
352 )
354 def pad_to(
355 self,
356 sizes: PerAxis[int],
357 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
358 mode: PadMode = "symmetric",
359 ) -> Self:
360 """pad `tensor` to match `sizes`"""
361 if isinstance(pad_where, str):
362 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
363 else:
364 pad_axis_where = pad_where
366 pad_width: Dict[AxisId, PadWidth] = {}
367 for a, s_is in self.sizes.items():
368 if a not in sizes or sizes[a] == s_is:
369 pad_width[a] = PadWidth(0, 0)
370 elif s_is > sizes[a]:
371 pad_width[a] = PadWidth(0, 0)
372 logger.warning(
373 "Cannot pad axis {} of size {} to smaller size {}",
374 a,
375 s_is,
376 sizes[a],
377 )
378 elif a not in pad_axis_where:
379 raise ValueError(
380 f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
381 )
382 else:
383 pad_this_axis_where = pad_axis_where[a]
384 d = sizes[a] - s_is
385 if pad_this_axis_where == "left":
386 pad_width[a] = PadWidth(d, 0)
387 elif pad_this_axis_where == "right":
388 pad_width[a] = PadWidth(0, d)
389 elif pad_this_axis_where == "left_and_right":
390 pad_width[a] = PadWidth(left := d // 2, d - left)
391 else:
392 assert_never(pad_this_axis_where)
394 return self.pad(pad_width, mode)
396 def quantile(
397 self,
398 q: Union[float, Sequence[float]],
399 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
400 method: QuantileMethod = "linear",
401 ) -> Self:
402 assert (
403 isinstance(q, (float, int))
404 and q >= 0.0
405 or not isinstance(q, (float, int))
406 and all(qq >= 0.0 for qq in q)
407 )
408 assert (
409 isinstance(q, (float, int))
410 and q <= 1.0
411 or not isinstance(q, (float, int))
412 and all(qq <= 1.0 for qq in q)
413 )
414 assert dim is None or (
415 (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
416 )
417 return self.__class__.from_xarray(
418 self._data.quantile(q, dim=dim, method=method)
419 )
421 def resize_to(
422 self,
423 sizes: PerAxis[int],
424 *,
425 pad_where: Union[
426 PadWhere,
427 PerAxis[PadWhere],
428 ] = "left_and_right",
429 crop_where: Union[
430 CropWhere,
431 PerAxis[CropWhere],
432 ] = "left_and_right",
433 pad_mode: PadMode = "symmetric",
434 ):
435 """return cropped/padded tensor with `sizes`"""
436 crop_to_sizes: Dict[AxisId, int] = {}
437 pad_to_sizes: Dict[AxisId, int] = {}
438 new_axes = dict(sizes)
439 for a, s_is in self.sizes.items():
440 a = AxisId(str(a))
441 _ = new_axes.pop(a, None)
442 if a not in sizes or sizes[a] == s_is:
443 pass
444 elif s_is > sizes[a]:
445 crop_to_sizes[a] = sizes[a]
446 else:
447 pad_to_sizes[a] = sizes[a]
449 tensor = self
450 if crop_to_sizes:
451 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
453 if pad_to_sizes:
454 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
456 if new_axes:
457 tensor = tensor.expand_dims(new_axes)
459 return tensor
461 def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
462 return self.__class__.from_xarray(self._data.std(dim=dim))
464 def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
465 """Reduce this Tensor's data by applying sum along some dimension(s)."""
466 return self.__class__.from_xarray(self._data.sum(dim=dim))
468 def transpose(
469 self,
470 axes: Sequence[AxisId],
471 ) -> Self:
472 """return a transposed tensor
474 Args:
475 axes: the desired tensor axes
476 """
477 # expand missing tensor axes
478 missing_axes = tuple(a for a in axes if a not in self.dims)
479 array = self._data
480 if missing_axes:
481 array = array.expand_dims(missing_axes)
483 # transpose to the correct axis order
484 return self.__class__.from_xarray(array.transpose(*axes))
486 def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
487 return self.__class__.from_xarray(self._data.var(dim=dim))
489 @classmethod
490 def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
491 ndim = array.ndim
492 if ndim == 2:
493 current_axes = (
494 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
495 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
496 )
497 elif ndim == 3 and any(s <= 3 for s in array.shape):
498 current_axes = (
499 v0_5.ChannelAxis(
500 channel_names=[
501 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
502 ]
503 ),
504 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
505 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
506 )
507 elif ndim == 3:
508 current_axes = (
509 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
510 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
511 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
512 )
513 elif ndim == 4:
514 current_axes = (
515 v0_5.ChannelAxis(
516 channel_names=[
517 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
518 ]
519 ),
520 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
521 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
522 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
523 )
524 elif ndim == 5:
525 current_axes = (
526 v0_5.BatchAxis(),
527 v0_5.ChannelAxis(
528 channel_names=[
529 v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
530 ]
531 ),
532 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
533 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
534 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
535 )
536 else:
537 raise ValueError(f"Could not guess an axis mapping for {array.shape}")
539 return cls(array, dims=tuple(a.id for a in current_axes))
542def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
543 if len(arr.shape) > len(axis_infos):
544 # remove singletons
545 for i, s in enumerate(arr.shape):
546 if s == 1:
547 arr = np.take(arr, 0, axis=i)
548 if len(arr.shape) == len(axis_infos):
549 break
551 # add singletons if nececsary
552 for i, a in enumerate(axis_infos):
553 if len(arr.shape) >= len(axis_infos):
554 break
556 if a.maybe_singleton:
557 arr = np.expand_dims(arr, i)
559 return arr
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find a transposed (and singleton-adjusted) view of `original_array`
    that is compatible with `axis_infos`.

    Tries every axis permutation (identity and full reversal first) and
    returns the first view whose dimension count matches and whose
    singleton dimensions are all allowed (`maybe_singleton`); returns
    None if no permutation fits.
    """
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            # this permutation cannot be adjusted to the requested rank;
            # keep trying the remaining permutations instead of giving up
            # (previously: `return None`, aborting the whole search early)
            continue

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break  # disallowed singleton -> try next permutation
        else:
            return view

    return None