bioimageio.core.tensor
1from __future__ import annotations 2 3import collections.abc 4from itertools import permutations 5from typing import ( 6 TYPE_CHECKING, 7 Any, 8 Callable, 9 Dict, 10 Iterator, 11 Mapping, 12 Optional, 13 Sequence, 14 Tuple, 15 Union, 16 cast, 17 get_args, 18) 19 20import numpy as np 21import xarray as xr 22from loguru import logger 23from numpy.typing import DTypeLike, NDArray 24from typing_extensions import Self, assert_never 25 26from bioimageio.spec.model import v0_5 27 28from ._magic_tensor_ops import MagicTensorOpsMixin 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 30from .common import ( 31 CropWhere, 32 DTypeStr, 33 PadMode, 34 PadWhere, 35 PadWidth, 36 PadWidthLike, 37 SliceInfo, 38) 39 40if TYPE_CHECKING: 41 from numpy.typing import ArrayLike, NDArray 42 43 44_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" 45 46 47# TODO: complete docstrings 48class Tensor(MagicTensorOpsMixin): 49 """A wrapper around an xr.DataArray for better integration with bioimageio.spec 50 and improved type annotations.""" 51 52 _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray] 53 54 def __init__( 55 self, 56 array: NDArray[Any], 57 dims: Sequence[Union[AxisId, AxisLike]], 58 ) -> None: 59 super().__init__() 60 axes = tuple( 61 a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims 62 ) 63 self._data = xr.DataArray(array, dims=axes) 64 65 def __array__(self, dtype: DTypeLike = None): 66 return np.asarray(self._data, dtype=dtype) 67 68 def __getitem__( 69 self, 70 key: Union[ 71 SliceInfo, 72 slice, 73 int, 74 PerAxis[Union[SliceInfo, slice, int]], 75 Tensor, 76 xr.DataArray, 77 ], 78 ) -> Self: 79 if isinstance(key, SliceInfo): 80 key = slice(*key) 81 elif isinstance(key, collections.abc.Mapping): 82 key = { 83 a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) 84 for a, s in key.items() 85 } 86 elif isinstance(key, Tensor): 87 key = key._data 88 89 return 
self.__class__.from_xarray(self._data[key]) 90 91 def __setitem__( 92 self, 93 key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray], 94 value: Union[Tensor, xr.DataArray, float, int], 95 ) -> None: 96 if isinstance(key, Tensor): 97 key = key._data 98 elif isinstance(key, xr.DataArray): 99 pass 100 else: 101 key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} 102 103 if isinstance(value, Tensor): 104 value = value._data 105 106 self._data[key] = value 107 108 def __len__(self) -> int: 109 return len(self.data) 110 111 def _iter(self: Any) -> Iterator[Any]: 112 for n in range(len(self)): 113 yield self[n] 114 115 def __iter__(self: Any) -> Iterator[Any]: 116 if self.ndim == 0: 117 raise TypeError("iteration over a 0-d array") 118 return self._iter() 119 120 def _binary_op( 121 self, 122 other: _Compatible, 123 f: Callable[[Any, Any], Any], 124 reflexive: bool = False, 125 ) -> Self: 126 data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] 127 (other._data if isinstance(other, Tensor) else other), 128 f, 129 reflexive, 130 ) 131 return self.__class__.from_xarray(data) 132 133 def _inplace_binary_op( 134 self, 135 other: _Compatible, 136 f: Callable[[Any, Any], Any], 137 ) -> Self: 138 _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] 139 ( 140 other_d 141 if (other_d := getattr(other, "data")) is not None 142 and isinstance( 143 other_d, 144 xr.DataArray, 145 ) 146 else other 147 ), 148 f, 149 ) 150 return self 151 152 def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: 153 data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] 154 f, *args, **kwargs 155 ) 156 return self.__class__.from_xarray(data) 157 158 @classmethod 159 def from_xarray(cls, data_array: xr.DataArray) -> Self: 160 """create a `Tensor` from an xarray data array 161 162 note for internal use: this factory method is round-trip save 163 for any `Tensor`'s `data` property (an 
xarray.DataArray). 164 """ 165 return cls( 166 array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) 167 ) 168 169 @classmethod 170 def from_numpy( 171 cls, 172 array: NDArray[Any], 173 *, 174 dims: Optional[Union[AxisLike, Sequence[AxisLike]]], 175 ) -> Tensor: 176 """create a `Tensor` from a numpy array 177 178 Args: 179 array: the nd numpy array 180 axes: A description of the array's axes, 181 if None axes are guessed (which might fail and raise a ValueError.) 182 183 Raises: 184 ValueError: if `axes` is None and axes guessing fails. 185 """ 186 187 if dims is None: 188 return cls._interprete_array_wo_known_axes(array) 189 elif isinstance(dims, (str, Axis, v0_5.AxisBase)): 190 dims = [dims] 191 192 axis_infos = [AxisInfo.create(a) for a in dims] 193 original_shape = tuple(array.shape) 194 195 successful_view = _get_array_view(array, axis_infos) 196 if successful_view is None: 197 raise ValueError( 198 f"Array shape {original_shape} does not map to axes {dims}" 199 ) 200 201 return Tensor(successful_view, dims=tuple(a.id for a in axis_infos)) 202 203 @property 204 def data(self): 205 return self._data 206 207 @property 208 def dims(self): # TODO: rename to `axes`? 209 """Tuple of dimension names associated with this tensor.""" 210 return cast(Tuple[AxisId, ...], self._data.dims) 211 212 @property 213 def dtype(self) -> DTypeStr: 214 dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] 215 assert dt in get_args(DTypeStr) 216 return dt # pyright: ignore[reportReturnType] 217 218 @property 219 def ndim(self): 220 """Number of tensor dimensions.""" 221 return self._data.ndim 222 223 @property 224 def shape(self): 225 """Tuple of tensor axes lengths""" 226 return self._data.shape 227 228 @property 229 def shape_tuple(self): 230 """Tuple of tensor axes lengths""" 231 return self._data.shape 232 233 @property 234 def size(self): 235 """Number of elements in the tensor. 
236 237 Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. 238 """ 239 return self._data.size 240 241 @property 242 def sizes(self): 243 """Ordered, immutable mapping from axis ids to axis lengths.""" 244 return cast(Mapping[AxisId, int], self.data.sizes) 245 246 @property 247 def tagged_shape(self): 248 """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" 249 return self.sizes 250 251 def argmax(self) -> Mapping[AxisId, int]: 252 ret = self._data.argmax(...) 253 assert isinstance(ret, dict) 254 return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()} 255 256 def astype(self, dtype: DTypeStr, *, copy: bool = False): 257 """Return tensor cast to `dtype` 258 259 note: if dtype is already satisfied copy if `copy`""" 260 return self.__class__.from_xarray(self._data.astype(dtype, copy=copy)) 261 262 def clip(self, min: Optional[float] = None, max: Optional[float] = None): 263 """Return a tensor whose values are limited to [min, max]. 
264 At least one of max or min must be given.""" 265 return self.__class__.from_xarray(self._data.clip(min, max)) 266 267 def crop_to( 268 self, 269 sizes: PerAxis[int], 270 crop_where: Union[ 271 CropWhere, 272 PerAxis[CropWhere], 273 ] = "left_and_right", 274 ) -> Self: 275 """crop to match `sizes`""" 276 if isinstance(crop_where, str): 277 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} 278 else: 279 crop_axis_where = crop_where 280 281 slices: Dict[AxisId, SliceInfo] = {} 282 283 for a, s_is in self.sizes.items(): 284 if a not in sizes or sizes[a] == s_is: 285 pass 286 elif sizes[a] > s_is: 287 logger.warning( 288 "Cannot crop axis {} of size {} to larger size {}", 289 a, 290 s_is, 291 sizes[a], 292 ) 293 elif a not in crop_axis_where: 294 raise ValueError( 295 f"Don't know where to crop axis {a}, `crop_where`={crop_where}" 296 ) 297 else: 298 crop_this_axis_where = crop_axis_where[a] 299 if crop_this_axis_where == "left": 300 slices[a] = SliceInfo(s_is - sizes[a], s_is) 301 elif crop_this_axis_where == "right": 302 slices[a] = SliceInfo(0, sizes[a]) 303 elif crop_this_axis_where == "left_and_right": 304 slices[a] = SliceInfo( 305 start := (s_is - sizes[a]) // 2, sizes[a] + start 306 ) 307 else: 308 assert_never(crop_this_axis_where) 309 310 return self[slices] 311 312 def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: 313 return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) 314 315 def item( 316 self, 317 key: Union[ 318 None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]] 319 ] = None, 320 ): 321 """Copy a tensor element to a standard Python scalar and return it.""" 322 if key is None: 323 ret = self._data.item() 324 else: 325 ret = self[key]._data.item() 326 327 assert isinstance(ret, (bool, float, int)) 328 return ret 329 330 def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 331 return self.__class__.from_xarray(self._data.mean(dim=dim)) 332 333 def 
pad( 334 self, 335 pad_width: PerAxis[PadWidthLike], 336 mode: PadMode = "symmetric", 337 ) -> Self: 338 pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} 339 return self.__class__.from_xarray( 340 self._data.pad(pad_width=pad_width, mode=mode) 341 ) 342 343 def pad_to( 344 self, 345 sizes: PerAxis[int], 346 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", 347 mode: PadMode = "symmetric", 348 ) -> Self: 349 """pad `tensor` to match `sizes`""" 350 if isinstance(pad_where, str): 351 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} 352 else: 353 pad_axis_where = pad_where 354 355 pad_width: Dict[AxisId, PadWidth] = {} 356 for a, s_is in self.sizes.items(): 357 if a not in sizes or sizes[a] == s_is: 358 pad_width[a] = PadWidth(0, 0) 359 elif s_is > sizes[a]: 360 pad_width[a] = PadWidth(0, 0) 361 logger.warning( 362 "Cannot pad axis {} of size {} to smaller size {}", 363 a, 364 s_is, 365 sizes[a], 366 ) 367 elif a not in pad_axis_where: 368 raise ValueError( 369 f"Don't know where to pad axis {a}, `pad_where`={pad_where}" 370 ) 371 else: 372 pad_this_axis_where = pad_axis_where[a] 373 d = sizes[a] - s_is 374 if pad_this_axis_where == "left": 375 pad_width[a] = PadWidth(d, 0) 376 elif pad_this_axis_where == "right": 377 pad_width[a] = PadWidth(0, d) 378 elif pad_this_axis_where == "left_and_right": 379 pad_width[a] = PadWidth(left := d // 2, d - left) 380 else: 381 assert_never(pad_this_axis_where) 382 383 return self.pad(pad_width, mode) 384 385 def quantile( 386 self, 387 q: Union[float, Sequence[float]], 388 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, 389 ) -> Self: 390 assert ( 391 isinstance(q, (float, int)) 392 and q >= 0.0 393 or not isinstance(q, (float, int)) 394 and all(qq >= 0.0 for qq in q) 395 ) 396 assert ( 397 isinstance(q, (float, int)) 398 and q <= 1.0 399 or not isinstance(q, (float, int)) 400 and all(qq <= 1.0 for qq in q) 401 ) 402 assert dim is None or ( 403 (quantile_dim := 
AxisId("quantile")) != dim and quantile_dim not in set(dim) 404 ) 405 return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) 406 407 def resize_to( 408 self, 409 sizes: PerAxis[int], 410 *, 411 pad_where: Union[ 412 PadWhere, 413 PerAxis[PadWhere], 414 ] = "left_and_right", 415 crop_where: Union[ 416 CropWhere, 417 PerAxis[CropWhere], 418 ] = "left_and_right", 419 pad_mode: PadMode = "symmetric", 420 ): 421 """return cropped/padded tensor with `sizes`""" 422 crop_to_sizes: Dict[AxisId, int] = {} 423 pad_to_sizes: Dict[AxisId, int] = {} 424 new_axes = dict(sizes) 425 for a, s_is in self.sizes.items(): 426 a = AxisId(str(a)) 427 _ = new_axes.pop(a, None) 428 if a not in sizes or sizes[a] == s_is: 429 pass 430 elif s_is > sizes[a]: 431 crop_to_sizes[a] = sizes[a] 432 else: 433 pad_to_sizes[a] = sizes[a] 434 435 tensor = self 436 if crop_to_sizes: 437 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) 438 439 if pad_to_sizes: 440 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) 441 442 if new_axes: 443 tensor = tensor.expand_dims(new_axes) 444 445 return tensor 446 447 def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 448 return self.__class__.from_xarray(self._data.std(dim=dim)) 449 450 def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 451 """Reduce this Tensor's data by applying sum along some dimension(s).""" 452 return self.__class__.from_xarray(self._data.sum(dim=dim)) 453 454 def transpose( 455 self, 456 axes: Sequence[AxisId], 457 ) -> Self: 458 """return a transposed tensor 459 460 Args: 461 axes: the desired tensor axes 462 """ 463 # expand missing tensor axes 464 missing_axes = tuple(a for a in axes if a not in self.dims) 465 array = self._data 466 if missing_axes: 467 array = array.expand_dims(missing_axes) 468 469 # transpose to the correct axis order 470 return self.__class__.from_xarray(array.transpose(*axes)) 471 472 def var(self, dim: 
Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 473 return self.__class__.from_xarray(self._data.var(dim=dim)) 474 475 @classmethod 476 def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): 477 ndim = array.ndim 478 if ndim == 2: 479 current_axes = ( 480 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), 481 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), 482 ) 483 elif ndim == 3 and any(s <= 3 for s in array.shape): 484 current_axes = ( 485 v0_5.ChannelAxis( 486 channel_names=[ 487 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) 488 ] 489 ), 490 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), 491 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), 492 ) 493 elif ndim == 3: 494 current_axes = ( 495 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), 496 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), 497 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), 498 ) 499 elif ndim == 4: 500 current_axes = ( 501 v0_5.ChannelAxis( 502 channel_names=[ 503 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) 504 ] 505 ), 506 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), 507 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), 508 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), 509 ) 510 elif ndim == 5: 511 current_axes = ( 512 v0_5.BatchAxis(), 513 v0_5.ChannelAxis( 514 channel_names=[ 515 v0_5.Identifier(f"channel{i}") for i in range(array.shape[1]) 516 ] 517 ), 518 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), 519 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), 520 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), 521 ) 522 else: 523 raise ValueError(f"Could not guess an axis mapping for {array.shape}") 524 525 return cls(array, dims=tuple(a.id for a in current_axes)) 526 527 528def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]): 529 if len(arr.shape) > len(axis_infos): 530 # remove singletons 531 for 
i, s in enumerate(arr.shape): 532 if s == 1: 533 arr = np.take(arr, 0, axis=i) 534 if len(arr.shape) == len(axis_infos): 535 break 536 537 # add singletons if nececsary 538 for i, a in enumerate(axis_infos): 539 if len(arr.shape) >= len(axis_infos): 540 break 541 542 if a.maybe_singleton: 543 arr = np.expand_dims(arr, i) 544 545 return arr 546 547 548def _get_array_view( 549 original_array: NDArray[Any], axis_infos: Sequence[AxisInfo] 550) -> Optional[NDArray[Any]]: 551 perms = list(permutations(range(len(original_array.shape)))) 552 perms.insert(1, perms.pop()) # try A and A.T first 553 554 for perm in perms: 555 view = original_array.transpose(perm) 556 view = _add_singletons(view, axis_infos) 557 if len(view.shape) != len(axis_infos): 558 return None 559 560 for s, a in zip(view.shape, axis_infos): 561 if s == 1 and not a.maybe_singleton: 562 break 563 else: 564 return view 565 566 return None
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` with one named dimension per entry of `dims`.

        Args:
            array: the nd array to wrap
            dims: axis descriptions in array order; entries that are not already
                an `AxisId` are converted with `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        """Return the tensor values as a numpy array (optionally cast to `dtype`)."""
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo` keys (also as values of a per-axis mapping) are converted to
        `slice` objects; a `Tensor` key is replaced by its `xr.DataArray`.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection given by `key` (modifies this tensor)."""
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        """Return `len()` of the underlying `xr.DataArray`."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        # yields self[0], self[1], ... (see `__iter__`)
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the integer-indexed sub-tensors `self[0]`, `self[1]`, ...

        Raises:
            TypeError: for 0-d tensors.
        """
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operator `f` to this tensor and `other`.

        `Tensor` operands are unwrapped to their `xr.DataArray`; `reflexive` is
        forwarded to xarray unchanged.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the binary operator `f` in place and return `self`."""
        # Unwrap `Tensor` operands exactly like `_binary_op`; scalars, arrays and
        # DataArrays are passed on unchanged (plain Python scalars have no
        # `.data` attribute, so attribute probing must not be used here).
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operator `f` and return the result as a new tensor."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Self:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails,
                or if the array's shape cannot be mapped to `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls` (not `Tensor`) so that subclasses construct their own type
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """The data type as a string; asserted to be one of `DTypeStr`."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's axis lengths.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the position of the maximum value as {axis id: index}."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if `dtype` is already satisfied, the data is only copied if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes not in `sizes` are kept. Axes that are smaller than requested are
        not cropped; a warning is logged instead.

        Raises:
            ValueError: if an axis needs cropping but has no `crop_where` entry.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with the axes `dims` added (via `xr.DataArray.expand_dims`)."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce by taking the mean along `dim` (passed to xarray as-is)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Return a padded tensor.

        Args:
            pad_width: padding per axis; each entry is normalized with
                `PadWidth.create`.
            mode: padding mode forwarded to `xr.DataArray.pad`.
        """
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes not in `sizes` are not padded. Axes that are larger than requested
        are left unchanged; a warning is logged instead.

        Raises:
            ValueError: if an axis needs padding but has no `pad_where` entry.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        """Compute the `q`-th quantile(s) along `dim` via `xr.DataArray.quantile`.

        `q` (or each entry of it) must lie in [0, 1], and `dim` must not name an
        axis "quantile" (both checked with asserts only).
        """
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Axes in `sizes` that this tensor does not have are added with
        `expand_dims` using the requested length.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce by taking the standard deviation along `dim`."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes; axes the tensor does not have yet are
                added via `expand_dims` before transposing.
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce by taking the variance along `dim`."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes from `array.ndim` (2 to 5 dimensions supported).

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # NOTE: assumes channel-first whenever any axis has length <= 3,
            # even if the short axis is not the first one
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.
@classmethod
def from_xarray(cls, data_array: xr.DataArray) -> Self:
    """Build a `Tensor` from an `xarray.DataArray`.

    Internal note: feeding the `data` property of any `Tensor` (an
    `xarray.DataArray`) through this factory reproduces an equivalent tensor.
    """
    axis_ids = tuple(map(AxisId, data_array.dims))
    return cls(array=data_array.data, dims=axis_ids)
@classmethod
def from_numpy(
    cls,
    array: NDArray[Any],
    *,
    dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
) -> Self:
    """create a `Tensor` from a numpy array

    Args:
        array: the nd numpy array
        dims: A description of the array's axes,
            if None axes are guessed (which might fail and raise a ValueError.)

    Raises:
        ValueError: if `dims` is None and axes guessing fails,
            or if the array's shape cannot be mapped to `dims`.
    """

    if dims is None:
        return cls._interprete_array_wo_known_axes(array)
    elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
        dims = [dims]

    axis_infos = [AxisInfo.create(a) for a in dims]
    original_shape = tuple(array.shape)

    successful_view = _get_array_view(array, axis_infos)
    if successful_view is None:
        raise ValueError(
            f"Array shape {original_shape} does not map to axes {dims}"
        )

    # `cls` (not `Tensor`) so that subclasses construct their own type
    return cls(successful_view, dims=tuple(a.id for a in axis_infos))
create a `Tensor` from a numpy array
Arguments:
- array: the nd numpy array
- dims: A description of the array's axes, if None axes are guessed (which might fail and raise a ValueError.)
Raises:
- ValueError: if `dims` is None and axes guessing fails, or if the array shape does not map to `dims`.
@property
def dims(self):  # TODO: rename to `axes`?
    """Axis ids (dimension names) of this tensor, in axis order."""
    dim_names = self._data.dims
    return cast(Tuple[AxisId, ...], dim_names)
Tuple of dimension names associated with this tensor.
@property
def shape(self):
    """Lengths of the tensor's axes, as a tuple in axis order."""
    underlying = self._data
    return underlying.shape
Tuple of tensor axes lengths
@property
def shape_tuple(self):
    """Lengths of the tensor's axes, as a tuple in axis order."""
    underlying = self._data
    return underlying.shape
Tuple of tensor axes lengths
@property
def size(self):
    """Total number of elements in the tensor.

    This equals ``math.prod(tensor.shape)``, the product of all axis lengths.
    """
    data_array = self._data
    return data_array.size
Number of elements in the tensor.
Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
@property
def sizes(self):
    """Ordered, immutable mapping from each axis id to that axis' length."""
    axis_lengths = self.data.sizes
    return cast(Mapping[AxisId, int], axis_lengths)
Ordered, immutable mapping from axis ids to axis lengths.
@property
def tagged_shape(self):
    """Alias of `sizes`: ordered, immutable mapping from axis ids to lengths."""
    axis_lengths = self.sizes
    return axis_lengths
(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.
def astype(self, dtype: DTypeStr, *, copy: bool = False):
    """Return this tensor cast to `dtype`.

    If the data already has the requested `dtype`, a copy is only made when
    `copy` is True (this flag is forwarded to `xarray.DataArray.astype`).
    """
    converted = self._data.astype(dtype, copy=copy)
    return self.__class__.from_xarray(converted)
Return tensor cast to `dtype`.
Note: if `dtype` is already satisfied, a copy is only made if `copy` is True.
def clip(self, min: Optional[float] = None, max: Optional[float] = None):
    """Return a tensor whose values are limited to the range [min, max].

    At least one of `min` or `max` must be given. (The parameter names shadow
    the builtins inside this method; they are kept for API compatibility.)
    """
    clipped = self._data.clip(min, max)
    return self.__class__.from_xarray(clipped)
Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.
def crop_to(
    self,
    sizes: PerAxis[int],
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
) -> Self:
    """Crop the tensor so that the axes listed in `sizes` get those lengths.

    Axes that are not in `sizes`, or already have the requested length, are
    left untouched. A request for a larger length is ignored with a warning.

    Args:
        sizes: target length per axis id.
        crop_where: where to remove elements ("left", "right" or
            "left_and_right"), either one value for all axes or per axis.

    Raises:
        ValueError: if an axis needs cropping but `crop_where` is a mapping
            without an entry for it.
    """
    crop_axis_where: PerAxis[CropWhere] = (
        {a: crop_where for a in self.dims}
        if isinstance(crop_where, str)
        else crop_where
    )

    slices: Dict[AxisId, SliceInfo] = {}
    for a, current in self.sizes.items():
        if a not in sizes:
            continue
        target = sizes[a]
        if target == current:
            continue
        if target > current:
            logger.warning(
                "Cannot crop axis {} of size {} to larger size {}",
                a,
                current,
                target,
            )
            continue
        if a not in crop_axis_where:
            raise ValueError(
                f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
            )

        where = crop_axis_where[a]
        excess = current - target
        if where == "left":
            # drop the first `excess` elements
            slices[a] = SliceInfo(excess, current)
        elif where == "right":
            # keep the first `target` elements
            slices[a] = SliceInfo(0, target)
        elif where == "left_and_right":
            # split the excess; an odd remainder is removed on the right
            first = excess // 2
            slices[a] = SliceInfo(first, first + target)
        else:
            assert_never(where)

    return self[slices]
Crop to match `sizes`.
def item(
    self,
    key: Union[
        None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
    ] = None,
):
    """Copy a tensor element to a standard Python scalar and return it.

    Args:
        key: optional index; if given, the element is taken from `self[key]`,
            otherwise from the tensor itself.
    """
    source = self if key is None else self[key]
    value = source._data.item()
    assert isinstance(value, (bool, float, int))
    return value
Copy a tensor element to a standard Python scalar and return it.
def pad_to(
    self,
    sizes: PerAxis[int],
    pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
    mode: PadMode = "symmetric",
) -> Self:
    """Pad the tensor so that the axes listed in `sizes` get those lengths.

    Every axis of the tensor gets a pad width; axes not in `sizes`, already of
    the requested length, or longer than requested (logged as a warning) are
    padded by (0, 0). The result is produced by `self.pad(pad_width, mode)`.

    Args:
        sizes: target length per axis id.
        pad_where: where to add elements ("left", "right" or
            "left_and_right"), either one value for all axes or per axis.
        mode: padding mode forwarded to `self.pad`.

    Raises:
        ValueError: if an axis needs padding but `pad_where` is a mapping
            without an entry for it.
    """
    if isinstance(pad_where, str):
        pad_axis_where: PerAxis[PadWhere] = dict.fromkeys(self.dims, pad_where)
    else:
        pad_axis_where = pad_where

    pad_width: Dict[AxisId, PadWidth] = {}
    for a, current in self.sizes.items():
        if a not in sizes:
            pad_width[a] = PadWidth(0, 0)
            continue
        target = sizes[a]
        if target == current:
            pad_width[a] = PadWidth(0, 0)
            continue
        if current > target:
            pad_width[a] = PadWidth(0, 0)
            logger.warning(
                "Cannot pad axis {} of size {} to smaller size {}",
                a,
                current,
                target,
            )
            continue
        if a not in pad_axis_where:
            raise ValueError(
                f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
            )

        where = pad_axis_where[a]
        missing = target - current
        if where == "left":
            pad_width[a] = PadWidth(missing, 0)
        elif where == "right":
            pad_width[a] = PadWidth(0, missing)
        elif where == "left_and_right":
            # split the padding; an odd remainder goes to the right
            before = missing // 2
            pad_width[a] = PadWidth(before, missing - before)
        else:
            assert_never(where)

    return self.pad(pad_width, mode)
Pad `tensor` to match `sizes`.
def quantile(
    self,
    q: Union[float, Sequence[float]],
    dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
) -> Self:
    """Compute the `q`-th quantile(s) of the tensor's data.

    Args:
        q: quantile, or sequence of quantiles, each within [0, 1].
        dim: axis id(s) to reduce over; None reduces over all axes. An axis
            named "quantile" is not allowed here.

    Returns:
        A tensor wrapping the result of `xarray.DataArray.quantile(q, dim=dim)`.
    """
    # np.asarray handles python scalars, numpy scalars (e.g. np.float32) and
    # sequences alike; NaN fails both comparisons and is therefore rejected.
    q_values = np.asarray(q, dtype=float)
    assert bool(np.all(q_values >= 0.0)), f"quantiles must be >= 0, got {q}"
    assert bool(np.all(q_values <= 1.0)), f"quantiles must be <= 1, got {q}"

    if dim is not None:
        # A single axis id must not be iterated: `set(dim)` on a str-like id
        # would split it into characters instead of treating it as one axis.
        dim_seq = (dim,) if isinstance(dim, (str, AxisId)) else tuple(dim)
        # presumably disallowed to avoid clashing with the "quantile"
        # dimension xarray adds for multiple quantiles
        assert AxisId("quantile") not in dim_seq, (
            f"cannot compute quantiles along an axis named 'quantile': {dim}"
        )

    return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to(
    self,
    sizes: PerAxis[int],
    *,
    pad_where: Union[
        PadWhere,
        PerAxis[PadWhere],
    ] = "left_and_right",
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
    pad_mode: PadMode = "symmetric",
):
    """Return a tensor cropped and/or padded to match `sizes`.

    Axes longer than requested are cropped with `crop_to`, shorter ones are
    padded with `pad_to`, and axes in `sizes` that the tensor does not have
    are added with `expand_dims`. Axes not mentioned in `sizes` are kept.
    """
    crop_to_sizes: Dict[AxisId, int] = {}
    pad_to_sizes: Dict[AxisId, int] = {}
    # whatever remains here after the loop is an axis the tensor lacks
    new_axes = dict(sizes)
    for raw_axis, current in self.sizes.items():
        axis = AxisId(str(raw_axis))
        new_axes.pop(axis, None)
        if axis not in sizes:
            continue
        target = sizes[axis]
        if target == current:
            continue
        if current > target:
            crop_to_sizes[axis] = target
        else:
            pad_to_sizes[axis] = target

    result = self
    if crop_to_sizes:
        result = result.crop_to(crop_to_sizes, crop_where=crop_where)
    if pad_to_sizes:
        result = result.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
    if new_axes:
        result = result.expand_dims(new_axes)
    return result
Return cropped/padded tensor with `sizes`.
def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
    """Reduce this Tensor's data by summing along the axis/axes `dim`.

    `dim` is forwarded to `xarray.DataArray.sum` (None reduces over all axes).
    """
    summed = self._data.sum(dim=dim)
    return self.__class__.from_xarray(summed)
Reduce this Tensor's data by applying sum along some dimension(s).
def transpose(
    self,
    axes: Sequence[AxisId],
) -> Self:
    """Return a tensor with its axes reordered to match `axes`.

    Axes listed in `axes` that the tensor does not have yet are first added
    with `expand_dims`, then the data is transposed into the order of `axes`.

    Args:
        axes: the desired tensor axes, in order
    """
    current_dims = self.dims
    missing_axes = tuple(ax for ax in axes if ax not in current_dims)
    expanded = self._data.expand_dims(missing_axes) if missing_axes else self._data
    return self.__class__.from_xarray(expanded.transpose(*axes))
return a transposed tensor
Arguments:
- axes: the desired tensor axes