Coverage for src/bioimageio/core/tensor.py: 82% (243 statements)
coverage.py v7.10.7, created at 2025-10-14 08:35 +0000

from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from bioimageio.spec.model import v0_5
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray

_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value
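
    # Illustrative usage sketch (comment only; not part of the original module):
    # per-axis indexing with `SliceInfo`/`slice` values. The axis ids and shapes
    # below are made up for demonstration.
    #
    #   t = Tensor(np.zeros((2, 8, 8)), dims=(AxisId("c"), AxisId("y"), AxisId("x")))
    #   patch = t[{AxisId("y"): SliceInfo(0, 4), AxisId("x"): slice(0, 4)}]
    #   t[{AxisId("y"): SliceInfo(0, 4), AxisId("x"): SliceInfo(0, 4)}] = patch + 1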

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                if (other_d := getattr(other, "data")) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )
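
    # Illustrative sketch (comment only; not part of the original module):
    # `from_xarray` round-trips the `data` property of any `Tensor`.
    #
    #   t = Tensor(np.zeros((3, 4)), dims=(AxisId("y"), AxisId("x")))
    #   t2 = Tensor.from_xarray(t.data)
    #   assert t2.dims == t.dims and t2.shape == t.shape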

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes;
                if None, axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axes guessing fails.
        """
        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
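
    # Illustrative sketch (comment only; not part of the original module), assuming
    # `AxisId` values are accepted as `AxisLike`:
    #
    #   explicit = Tensor.from_numpy(
    #       np.zeros((1, 3, 256, 256)),
    #       dims=[AxisId("b"), AxisId("c"), AxisId("y"), AxisId("x")],
    #   )
    #   guessed = Tensor.from_numpy(np.zeros((3, 256, 256)), dims=None)
    #   # dims=None defers to `_interprete_array_wo_known_axes` -> channel, y, x here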

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes
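
    # Illustrative sketch (comment only; not part of the original module):
    # `tagged_shape`/`sizes` expose the shape keyed by axis id instead of by position.
    #
    #   t = Tensor(np.zeros((2, 5)), dims=(AxisId("c"), AxisId("x")))
    #   t.shape         # (2, 5)
    #   t.tagged_shape  # mapping {"c": 2, "x": 5} (insertion-ordered)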

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if `dtype` is already satisfied, a copy is made only if `copy` is True."""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]
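
    # Illustrative sketch (comment only; not part of the original module): cropping
    # only affects the axes listed in `sizes`; `crop_where` selects which side(s)
    # are dropped.
    #
    #   t = Tensor(np.zeros((1, 8, 8)), dims=(AxisId("c"), AxisId("y"), AxisId("x")))
    #   cropped = t.crop_to({AxisId("y"): 4, AxisId("x"): 4})
    #   cropped.tagged_shape  # c: 1, y: 4, x: 4 (centered crop by default)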

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)
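
    # Illustrative sketch (comment only; not part of the original module): padding
    # to a larger size; axes already at (or above) the target size are left alone.
    #
    #   t = Tensor(np.zeros((1, 6, 6)), dims=(AxisId("c"), AxisId("y"), AxisId("x")))
    #   padded = t.pad_to({AxisId("y"): 8, AxisId("x"): 8}, pad_where="right")
    #   padded.tagged_shape  # c: 1, y: 8, x: 8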

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor
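
    # Illustrative sketch (comment only; not part of the original module):
    # `resize_to` crops axes that are too large, pads axes that are too small,
    # and expands axes that are missing entirely.
    #
    #   t = Tensor(np.zeros((10, 6)), dims=(AxisId("y"), AxisId("x")))
    #   r = t.resize_to({AxisId("b"): 1, AxisId("y"): 8, AxisId("x"): 8})
    #   r.tagged_shape  # b: 1, y: 8, x: 8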

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))
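
    # Illustrative sketch (comment only; not part of the original module): transpose
    # to a requested axis order; requested axes missing from the tensor are inserted
    # as singletons first.
    #
    #   t = Tensor(np.zeros((8, 8)), dims=(AxisId("y"), AxisId("x")))
    #   t2 = t.transpose([AxisId("b"), AxisId("x"), AxisId("y")])
    #   t2.tagged_shape  # b: 1, x: 8, y: 8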

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
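
    # Illustrative sketch (comment only; not part of the original module): the
    # guessed mapping depends on `ndim` and, for 3d arrays, on whether any axis
    # length is <= 3 (the first axis is then treated as channel):
    #
    #   (256, 256)            -> y, x
    #   (3, 256, 256)         -> channel, y, x
    #   (64, 256, 256)        -> z, y, x
    #   (3, 64, 256, 256)     -> channel, z, y, x
    #   (1, 3, 64, 256, 256)  -> batch, channel, z, y, x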


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
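
# Illustrative sketch (comment only; not part of the original module):
# `_get_array_view` tries axis permutations (identity and fully reversed order
# first) and relies on `_add_singletons` to reconcile the number of dimensions,
# e.g. mapping a (256, 256) array onto (b, y, x) by inserting a singleton batch
# axis, assuming the batch axis reports `maybe_singleton`:
#
#   infos = [AxisInfo.create(a) for a in (AxisId("b"), AxisId("y"), AxisId("x"))]
#   view = _get_array_view(np.zeros((256, 256)), infos)  # shape (1, 256, 256)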