Coverage for bioimageio/core/tensor.py: 85%
242 statements
from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_5

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray


_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"

# TODO: complete docstrings
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)
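
    # Construction sketch (hypothetical values, not part of the original module):
    # wrap a numpy array and name its axes explicitly.
    #
    #   t = Tensor(np.zeros((2, 3)), dims=(AxisId("y"), AxisId("x")))
    #   t.shape  # (2, 3), with axis ids "y" and "x"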

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])
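
    # Indexing sketch (hypothetical tensor `t` with axes "y" and "x"):
    # a per-axis mapping accepts ints, slices, or SliceInfo (start, stop) pairs.
    #
    #   row = t[{AxisId("y"): 0}]  # drops the "y" axis
    #   tile = t[{AxisId("y"): SliceInfo(0, 2), AxisId("x"): slice(1, 3)}]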

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                if (other_d := getattr(other, "data")) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )
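
    # Round-trip sketch (hypothetical values): `from_xarray` reconstructs an
    # equivalent Tensor from any Tensor's `data` property.
    #
    #   da = xr.DataArray(np.ones((1, 4)), dims=("batch", "x"))
    #   t = Tensor.from_xarray(da)
    #   assert dict(Tensor.from_xarray(t.data).sizes) == dict(t.sizes)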

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
176 """create a `Tensor` from a numpy array
178 Args:
179 array: the nd numpy array
180 axes: A description of the array's axes,
181 if None axes are guessed (which might fail and raise a ValueError.)
183 Raises:
184 ValueError: if `axes` is None and axes guessing fails.
185 """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
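
    # Guessing sketch (hypothetical arrays): with `dims=None` the axes are
    # inferred by `_interprete_array_wo_known_axes` below.
    #
    #   t2d = Tensor.from_numpy(np.zeros((256, 256)), dims=None)     # y, x
    #   t3d = Tensor.from_numpy(np.zeros((3, 256, 256)), dims=None)  # channel, y, x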

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes
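
    # Sketch: for `Tensor(np.zeros((2, 3)), dims=(AxisId("y"), AxisId("x")))`
    # both `sizes` and `tagged_shape` map axis ids to lengths, i.e. y -> 2, x -> 3.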

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if `dtype` is already satisfied, a copy is made only if `copy` is True"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]
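
    # Cropping sketch (hypothetical sizes): from a tensor with y=4, x=6 keep the
    # last 2 rows ("left" crops away the left side) and the central 4 columns.
    #
    #   cropped = t.crop_to(
    #       {AxisId("y"): 2, AxisId("x"): 4},
    #       crop_where={AxisId("y"): "left", AxisId("x"): "left_and_right"},
    #   )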

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)
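
    # Padding sketch (hypothetical sizes): grow y from 4 to 8; with
    # "left_and_right" the 4 extra rows are split evenly (2 left, 2 right).
    #
    #   padded = t.pad_to({AxisId("y"): 8}, pad_where="left_and_right", mode="symmetric")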

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
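
    # Quantile sketch (hypothetical call): `q` must lie in [0, 1] and `dim` must
    # not clash with the reserved axis id "quantile".
    #
    #   median_over_x = t.quantile(0.5, dim=AxisId("x"))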

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor
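
    # Resizing sketch (hypothetical sizes): axes larger than requested are
    # cropped, smaller ones are padded, and axes not present yet are added
    # via `expand_dims`.
    #
    #   resized = t.resize_to({AxisId("batch"): 1, AxisId("y"): 2, AxisId("x"): 8})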

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))
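
    # Transpose sketch (hypothetical axes): axes missing from the tensor are
    # first added as singleton dimensions, then the data is reordered.
    #
    #   t_bxy = t.transpose([AxisId("batch"), AxisId("x"), AxisId("y")])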

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
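
    # Guessing summary (hypothetical shapes):
    #   (512, 512)           -> y, x
    #   (3, 512, 512)        -> channel, y, x   (any extent <= 3 suggests channels first)
    #   (16, 512, 512)       -> z, y, x
    #   (1, 3, 16, 512, 512) -> batch, channel, z, y, x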


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
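

# Minimal demo sketch (hypothetical data, not part of the original module;
# run manually to see the axis guessing in action):
if __name__ == "__main__":
    _arr = np.zeros((3, 256, 256), dtype="float32")
    _t = Tensor.from_numpy(_arr, dims=None)
    print(_t.tagged_shape)  # expected axes: channel, y, x with sizes 3, 256, 256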