# Coverage for bioimageio/core/tensor.py: 76% (221 statements)
from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from bioimageio.spec.model import v0_5

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray


_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)
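
    # Usage sketch (added for illustration, not part of the original module):
    # wrapping a numpy array with explicit axis ids. The shape and values below
    # are assumptions.
    #
    #     import numpy as np
    #     t = Tensor(np.zeros((480, 640)), dims=(AxisId("y"), AxisId("x")))
    #     t.tagged_shape   # mapping {AxisId("y"): 480, AxisId("x"): 640}
    #     t.dtype          # "float64"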

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        return self.__class__.from_xarray(self._data[key])
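
    # Indexing sketch (added for illustration; the axis ids and sizes are
    # assumptions): items can be selected per axis with a mapping from
    # `AxisId` to an int, a `slice`, or a `SliceInfo`.
    #
    #     crop = t[{AxisId("y"): slice(0, 100), AxisId("x"): SliceInfo(0, 200)}]
    #     crop.tagged_shape   # {AxisId("y"): 100, AxisId("x"): 200}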

    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
        self._data[key] = value._data

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                # default to None so scalars without a `data` attribute
                # do not raise an AttributeError
                if (other_d := getattr(other, "data", None)) is not None
                and isinstance(other_d, xr.DataArray)
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes;
                if None, axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axes guessing fails.
        """
        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
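
    # Usage sketch (added for illustration; the shapes and short axis strings
    # are assumptions and rely on `AxisInfo.create` accepting them): with
    # `dims=None` the axes are guessed from the array's dimensionality, with
    # explicit `dims` the array is matched against the given axes.
    #
    #     Tensor.from_numpy(np.zeros((512, 512)), dims=None)   # guessed as (y, x)
    #     Tensor.from_numpy(np.zeros((3, 512, 512)), dims=("c", "y", "x"))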

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e. the product of the tensor's dimensions.
        """
        return self._data.size

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return this tensor cast to `dtype`

        note: if `dtype` is already satisfied, the data is only copied if `copy` is True."""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]
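
    # Cropping sketch (added for illustration; axis ids and sizes are
    # assumptions): crop a 480x640 tensor to 256x256; the default
    # `crop_where="left_and_right"` takes the center along both axes.
    #
    #     cropped = t.crop_to({AxisId("y"): 256, AxisId("x"): 256})
    #     cropped.tagged_shape   # {AxisId("y"): 256, AxisId("x"): 256}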

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad this tensor to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)
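
    # Padding sketch (added for illustration; `small`, its axis ids, and the
    # sizes are hypothetical): pad a 30x30 tensor symmetrically up to 32x32.
    #
    #     padded = small.pad_to({AxisId("y"): 32, AxisId("x"): 32})
    #     padded.tagged_shape   # {AxisId("y"): 32, AxisId("x"): 32}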

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor
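
    # Resizing sketch (added for illustration; axis ids and sizes are
    # assumptions): axes larger than requested are cropped, smaller ones are
    # padded, and requested axes missing from the tensor are added.
    #
    #     resized = t.resize_to({AxisId("y"): 256, AxisId("x"): 800, AxisId("c"): 1})
    #     # result has axes c (new, size 1), y (cropped to 256), x (padded to 800)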

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))
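
    # Transposition sketch (added for illustration; axis ids are assumptions):
    # axes requested but missing on the tensor are inserted as singletons
    # before reordering.
    #
    #     cyx = t.transpose([AxisId("c"), AxisId("y"), AxisId("x")])
    #     # "c" did not exist on `t`, so it is added with size 1 and ordered first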

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
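

# Axis guessing summary (added for illustration of the heuristic above, not
# part of the original module; the example shapes and the default axis ids of
# `ChannelAxis`/`BatchAxis` are assumptions):
#
#     Tensor.from_numpy(np.zeros((64, 64)), dims=None)            # y, x
#     Tensor.from_numpy(np.zeros((3, 64, 64)), dims=None)         # channel, y, x
#     Tensor.from_numpy(np.zeros((16, 64, 64)), dims=None)        # z, y, x
#     Tensor.from_numpy(np.zeros((3, 16, 64, 64)), dims=None)     # channel, z, y, x
#     Tensor.from_numpy(np.zeros((1, 3, 16, 64, 64)), dims=None)  # batch, channel, z, y, x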


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    if len(arr.shape) > len(axis_infos):
        # remove singletons
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i)
                if len(arr.shape) == len(axis_infos):
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= len(axis_infos):
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
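

# Sketch of how the helpers above combine (added for illustration; the shape
# and the `AxisInfo.create` arguments are assumptions): `_get_array_view`
# searches axis permutations (the original array first, its transpose second)
# and returns the first view whose sizes are compatible with the given axes.
#
#     infos = [AxisInfo.create(a) for a in ("y", "x")]
#     _get_array_view(np.zeros((640, 480)), infos)   # a (640, 480) view, or None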