Coverage for src / bioimageio / core / tensor.py: 83%
257 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from __future__ import annotations
3import collections.abc
4from itertools import permutations
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 Callable,
9 Dict,
10 Iterator,
11 Mapping,
12 Optional,
13 Sequence,
14 Tuple,
15 Union,
16 cast,
17 get_args,
18)
20import numpy as np
21import xarray as xr
22from loguru import logger
23from numpy.typing import DTypeLike, NDArray
24from typing_extensions import Self, assert_never
26from bioimageio.spec.model import v0_5
28from ._magic_tensor_ops import MagicTensorOpsMixin
29from .axis import AxisId, AxisInfo, AxisLike, PerAxis
30from .common import (
31 CropWhere,
32 DTypeStr,
33 PadMode,
34 PadWhere,
35 PadWidth,
36 PadWidthLike,
37 QuantileMethod,
38 SliceInfo,
39)
41if TYPE_CHECKING:
42 from numpy.typing import ArrayLike, NDArray
45_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray"
def _resolve_pad_mode(mode: PadMode):
    """Split a pad mode into its mode name and an optional constant fill value.

    Args:
        mode: Either the mode name as a plain string or one of the
            `v0_5` padding description objects.

    Returns:
        A `(mode_name, constant_value)` tuple. `constant_value` is only taken
        from a `v0_5.ConstantPadding` description; it is `None` otherwise.
    """
    # plain strings already are the mode name
    if isinstance(mode, str):
        return mode, None

    # constant padding is the only description that carries a fill value
    if isinstance(mode, v0_5.ConstantPadding):
        return mode.mode, mode.value

    if isinstance(
        mode, (v0_5.EdgePadding, v0_5.ReflectPadding, v0_5.SymmetricPadding)
    ):
        return mode.mode, None

    assert_never(mode)
65# TODO: complete docstrings
66# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations.

    The wrapped `xr.DataArray` is stored in `_data` (exposed read-only as `data`);
    its dimension names are the tensor's axis ids.
    Methods returning a tensor create a new instance via `from_xarray`.
    """

    # operand types accepted by the binary operator hooks below
    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` with one dimension per entry in `dims`.

        Args:
            array: the tensor data
            dims: One entry per array dimension. `AxisId` entries are used as is,
                any other entry is converted with `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __repr__(self) -> str:
        return f"<Tensor {repr(self._data)}>"

    def __array__(self, dtype: DTypeLike = None):
        """Return the tensor data as a numpy array (optionally cast to `dtype`)."""
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        Args:
            key: A `SliceInfo` (converted to a `slice`), a `slice`, an `int`,
                a mapping from axis id to any of these, or a `Tensor`/`xr.DataArray`
                that is forwarded to the wrapped data array's indexing.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            # normalize per-axis entries: ints and slices pass through,
            # anything else is unpacked into a slice
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection described by `key` (in place).

        Args:
            key: A `Tensor`/`xr.DataArray` or a mapping from axis id to
                `slice`/`SliceInfo` (the latter is unpacked into a `slice`).
            value: The value to assign; a `Tensor` is unwrapped to its data array.
        """
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        """Yield `self[n]` for every integer index `n` below `len(self)`."""
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over integer indices of the tensor.

        Raises:
            TypeError: if the tensor is 0-dimensional.
        """
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operation `f` via the wrapped data array.

        A `Tensor` operand is unwrapped to its data array first;
        the result is returned as a new tensor.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the binary operation `f` in place via the wrapped data array.

        If `other` exposes an `xr.DataArray` as its `data` attribute
        (as `Tensor` does), that data array is used as operand;
        any other operand is passed on unchanged.

        Returns:
            `self` (modified in place)
        """
        # `None` default: plain scalars (e.g. `tensor += 1.0`) have no `data`
        # attribute and must not raise an AttributeError here.
        other_d = getattr(other, "data", None)
        operand = other_d if isinstance(other_d, xr.DataArray) else other
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            operand,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operation `f` via the wrapped data array."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
            for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes.
                If None axes are guessed (which might fail and raise a ValueError.)
                If dims do not match array shape, permutations and singleton dimensions are tried to find a match.
        Raises:
            ValueError: if `dims` is None and dims guessing fails,
                or if no view of `array` maps to the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            # a single axis description
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # use `cls` (like the other factory methods) so subclasses get their own type
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """The data type as string (asserted to be one of the `DTypeStr` values)."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths (same value as `shape`)"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def to_numpy(self) -> NDArray[Any]:
        """Return the data of this tensor as a numpy array."""
        return self.data.to_numpy()  # pyright: ignore[reportUnknownVariableType]

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the per-axis index of the maximum over all dimensions."""
        # passing `...` requests the argmax over all dimensions as a dict
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Args:
            sizes: Target length per axis. Axes not listed keep their length.
                A target larger than the current length only logs a warning.
            crop_where: Where to crop; either one value applied to all axes
                or a value per axis.

        Raises:
            ValueError: if an axis needs cropping, but is missing in a per-axis `crop_where`.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # remove from the left: keep the last `sizes[a]` elements
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    # remove from the right: keep the first `sizes[a]` elements
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    # keep a centered window (the extra element is cropped on the right)
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with the additional dimensions `dims`."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it.

        Args:
            key: Optional selection applied (via indexing) before extracting the element.
        """
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying mean along some dimension(s)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Return a padded tensor.

        Args:
            pad_width: Padding per axis; each entry is converted with `PadWidth.create`.
            mode: The padding mode; a constant padding description also provides
                the fill value.
        """
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        mode_name, constant_value = _resolve_pad_mode(mode)
        return self.__class__.from_xarray(
            self._data.pad(
                pad_width=pad_width, mode=mode_name, constant_values=constant_value
            )
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Args:
            sizes: Target length per axis. Axes not listed keep their length.
                A target smaller than the current length only logs a warning.
            pad_where: Where to pad; either one value applied to all axes
                or a value per axis.
            mode: The padding mode passed on to `pad`.

        Raises:
            ValueError: if an axis needs padding, but is missing in a per-axis `pad_where`.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is  # total number of elements to add
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    # an odd remainder goes to the right
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
        method: QuantileMethod = "linear",
    ) -> Self:
        """Compute the quantile(s) `q` along `dim`.

        Args:
            q: Quantile or sequence of quantiles, each asserted to be in [0, 1].
            dim: Dimension(s) to reduce; asserted not to be/contain "quantile".
            method: The interpolation method passed on to the data array's `quantile`.
        """
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(
            self._data.quantile(q, dim=dim, method=method)
        )

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Existing axes that are too long are cropped first, then axes that are
        too short are padded. Axes in `sizes` that the tensor does not have
        are added with `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)  # what remains after the loop are axes to add
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying std along some dimension(s)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
                (axes the tensor does not have yet are added first)
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying var along some dimension(s)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes for `array` from its number of dimensions.

        Guesses: 2d -> yx; 3d -> cyx if any length is <= 3, else zyx;
        4d -> czyx; 5d -> (batch)czyx.

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # NOTE: the channel axis is taken to be the first one,
            # regardless of which axis has length <= 3
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Remove or add singleton dimensions so `arr` has one dimension per axis info.

    If `arr` has more dimensions than `axis_infos` has entries, length-1
    dimensions are removed (leftmost first) until the counts match or no
    singleton is left. If `arr` has fewer dimensions, a length-1 dimension is
    inserted at the position of each axis whose minimum size is 1 until the
    counts match.

    Args:
        arr: the array to adjust
        axis_infos: the target axes; only `size.min` of each entry is read here

    Returns:
        The adjusted array. The number of dimensions may still differ from
        `len(axis_infos)` if not enough singletons could be removed/added.
    """
    n_surplus = arr.ndim - len(axis_infos)
    if n_surplus > 0:
        # remove singletons
        # Select all axes to drop up front and remove them in a single call:
        # removing them one by one shifts the positions of the remaining axes.
        singleton_axes = [i for i, s in enumerate(arr.shape) if s == 1]
        drop_axes = tuple(singleton_axes[:n_surplus])
        if drop_axes:
            arr = np.squeeze(arr, axis=drop_axes)

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if arr.ndim >= len(axis_infos):
            break

        if a.size.min == 1:
            arr = np.expand_dims(arr, i)

    return arr
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Search for an arrangement of `original_array` that fits `axis_infos`.

    Every permutation of the array's dimensions is transposed and passed
    through `_add_singletons`. The first candidate whose axis lengths all
    satisfy the min/max/step constraints of the corresponding axis info is
    returned.

    Returns:
        The matching array, or `None` if there is none. The search stops with
        `None` as soon as a candidate's number of dimensions differs from the
        number of axis infos.
    """

    def violates(length: int, axis: AxisInfo) -> bool:
        # True if `length` is outside the size constraints of `axis`
        return (
            length < axis.size.min
            or (axis.size.max is not None and length > axis.size.max)
            or (
                axis.size.step is not None
                and (length - axis.size.min) % axis.size.step != 0
            )
        )

    n_axes = len(axis_infos)
    for order in permutations(range(len(original_array.shape))):
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != n_axes:
            return None

        if not any(violates(s, a) for s, a in zip(candidate.shape, axis_infos)):
            return candidate

    return None