Coverage for bioimageio/core/tensor.py: 85% (234 statements)
coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_5

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray


_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)
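
    # Illustrative usage sketch (editor's addition, not part of the original module;
    # assumes plain axis names such as "y"/"x" are accepted as AxisLike):
    #   t = Tensor(np.zeros((2, 3), dtype="float32"), dims=("y", "x"))
    #   t.tagged_shape  # mapping of axis id to length, here y -> 2 and x -> 3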

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        return self.__class__.from_xarray(self._data[key])
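
    # Example sketch (editor's addition): keys may be given per axis, e.g.
    #   t[{AxisId("y"): SliceInfo(0, 2), AxisId("x"): slice(1, 3)}]
    # returns a new Tensor restricted to those ranges.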

    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
        self._data[key] = value._data

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                if (other_d := getattr(other, "data", None)) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """Create a `Tensor` from an xarray data array.

        Note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )
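
    # Example sketch (editor's addition): `from_xarray` and the `data` property
    # round-trip, e.g.
    #   t2 = Tensor.from_xarray(t.data)  # t2 has the same values and axis ids as t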

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """Create a `Tensor` from a numpy array.

        Args:
            array: the nd numpy array
            dims: A description of the array's axes.
                If None, the axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axes guessing fails.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
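
    # Example sketch (editor's addition):
    #   Tensor.from_numpy(np.zeros((3, 64, 64)), dims=None)  # guesses channel, y, x
    #   Tensor.from_numpy(np.zeros((64, 64)), dims=(AxisId("y"), AxisId("x")))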

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return this tensor cast to `dtype`.

        Note: if `dtype` is already satisfied, a copy is only made if `copy` is True."""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]
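
    # Example sketch (editor's addition, assuming an axis x of length 10):
    #   t.crop_to({AxisId("x"): 8})                     # symmetric crop ("left_and_right")
    #   t.crop_to({AxisId("x"): 8}, crop_where="left")  # drop values at the start of x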

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)
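
    # Example sketch (editor's addition, assuming an axis x of length 10):
    #   t.pad_to({AxisId("x"): 16})                     # pads x symmetrically to length 16
    #   t.pad_to({AxisId("x"): 16}, pad_where="right")  # pads only at the end of x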

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
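
    # Example sketch (editor's addition): q must lie in [0, 1] and `dim` must not
    # be the reserved axis id "quantile", e.g.
    #   t.quantile([0.05, 0.95], dim=AxisId("x"))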

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor
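
    # Example sketch (editor's addition): axes larger than requested are cropped,
    # smaller ones are padded, and axes only present in `sizes` are added, e.g.
    #   t.resize_to({AxisId("y"): 32, AxisId("x"): 128})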

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))
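
    # Example sketch (editor's addition): axes missing from the tensor are inserted
    # as singleton dimensions before transposing, e.g.
    #   t.transpose([AxisId("batch"), AxisId("y"), AxisId("x")])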

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
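
    # Summary (editor's addition): the guessing heuristic above maps
    #   2d -> (y, x), 3d -> (channel, y, x) if any axis length is <= 3 else (z, y, x),
    #   4d -> (channel, z, y, x), 5d -> (batch, channel, z, y, x).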


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr
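

# Editor's note: `_get_array_view` below searches for an axis permutation of the
# input that is compatible with the requested axes. It tries the identity and the
# fully reversed order first, drops/adds singleton dimensions via `_add_singletons`,
# and accepts the first view whose singleton dimensions belong to axes that may be
# singletons.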


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
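

if __name__ == "__main__":
    # Editor's illustrative sketch, not part of the original module: a minimal
    # smoke test of axis guessing and resizing; all sizes below are arbitrary.
    _t = Tensor.from_numpy(np.zeros((3, 64, 64), dtype="float32"), dims=None)
    print(dict(_t.tagged_shape))  # guessed axes: channel=3, y=64, x=64
    _resized = _t.resize_to({AxisId("y"): 32, AxisId("x"): 128})
    print(dict(_resized.tagged_shape))  # y cropped to 32, x padded to 128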