Coverage for src/bioimageio/core/tensor.py: 82%
242 statements
coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_5

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray

_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

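    # Usage sketch (illustrative): construct a tensor directly from a numpy array
    # and explicit axis ids.
    #
    #     import numpy as np
    #     t = Tensor(np.zeros((2, 3)), dims=(AxisId("y"), AxisId("x")))
    #     t.tagged_shape  # maps axis ids to lengths: y -> 2, x -> 3
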
    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                if (other_d := getattr(other, "data")) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

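    # Usage sketch (illustrative): build a tensor from an existing xr.DataArray;
    # the dimension names are converted to `AxisId`s.
    #
    #     da = xr.DataArray(np.zeros((2, 3)), dims=("y", "x"))
    #     t = Tensor.from_xarray(da)
    #     # t.dims gives the axis ids ("y", "x") as `AxisId`s
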
    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes;
                if None, axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axes guessing fails.
        """
        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

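    # Usage sketch (illustrative): with `dims=None` the axes are guessed by
    # `_interprete_array_wo_known_axes`; a 3d array with a small first dimension
    # is interpreted as (channel, y, x).
    #
    #     t = Tensor.from_numpy(np.zeros((3, 256, 256)), dims=None)
    #     # expected tagged shape: channel -> 3, y -> 256, x -> 256
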
    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if `dtype` is already satisfied, a copy is only made if `copy` is True"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

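    # Usage sketch (illustrative): crop a 256x256 tensor to 224x224, taking the
    # center region because `crop_where` defaults to "left_and_right".
    #
    #     t = Tensor(np.zeros((256, 256)), dims=(AxisId("y"), AxisId("x")))
    #     cropped = t.crop_to({AxisId("y"): 224, AxisId("x"): 224})
    #     # cropped.tagged_shape: y -> 224, x -> 224
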
    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

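    # Usage sketch (illustrative): pad a 250x250 tensor up to 256x256; with the
    # default `pad_where="left_and_right"` the missing 6 pixels per axis are split
    # as d // 2 on the left and the remainder on the right.
    #
    #     t = Tensor(np.zeros((250, 250)), dims=(AxisId("y"), AxisId("x")))
    #     padded = t.pad_to({AxisId("y"): 256, AxisId("x"): 256})
    #     # padded.tagged_shape: y -> 256, x -> 256
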
    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

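    # Usage sketch (illustrative): `resize_to` crops oversized axes, pads
    # undersized ones, and expands axes that are not present yet.
    #
    #     t = Tensor(np.zeros((300, 200)), dims=(AxisId("y"), AxisId("x")))
    #     resized = t.resize_to(
    #         {AxisId("batch"): 1, AxisId("y"): 256, AxisId("x"): 256}
    #     )
    #     # resized.tagged_shape: batch -> 1, y -> 256, x -> 256
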
    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

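    # Usage sketch (illustrative): reorder axes; requested axes that are missing
    # on the tensor are inserted as singleton dimensions before transposing.
    #
    #     t = Tensor(np.zeros((3, 4)), dims=(AxisId("y"), AxisId("x")))
    #     t2 = t.transpose([AxisId("channel"), AxisId("x"), AxisId("y")])
    #     # t2.tagged_shape: channel -> 1, x -> 4, y -> 3
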
    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None

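# Behaviour sketch (illustrative): `_get_array_view` tries axis permutations of the
# input (identity and reversed order first) and uses `_add_singletons` to reconcile
# the number of dimensions; it returns the first view whose singleton dimensions are
# all allowed by the corresponding `AxisInfo.maybe_singleton`.
#
#     infos = [
#         AxisInfo.create(v0_5.SpaceInputAxis(id=AxisId("y"), size=4)),
#         AxisInfo.create(v0_5.SpaceInputAxis(id=AxisId("x"), size=5)),
#     ]
#     view = _get_array_view(np.zeros((4, 5)), infos)
#     # the identity permutation already matches: view.shape == (4, 5)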