bioimageio.core.tensor
from __future__ import annotations

import collections.abc
from itertools import permutations
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    get_args,
)

import numpy as np
import xarray as xr
from bioimageio.spec.model import v0_5
from loguru import logger
from numpy.typing import DTypeLike, NDArray
from typing_extensions import Self, assert_never

from ._magic_tensor_ops import MagicTensorOpsMixin
from .axis import AxisId, AxisInfo, AxisLike, PerAxis
from .common import (
    CropWhere,
    DTypeStr,
    PadMode,
    PadWhere,
    PadWidth,
    PadWidthLike,
    SliceInfo,
)

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray


_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"


# TODO: complete docstrings
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` whose dimensions are named by `dims`.

        Args:
            array: the data to wrap
            dims: one axis description per array dimension; entries that are not
                already an `AxisId` are converted via `AxisInfo.create(...).id`
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        """Return the wrapped data as a numpy array (optionally cast to `dtype`)."""
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo` keys (also as mapping values) are converted to `slice`
        objects and `Tensor` keys are unwrapped before indexing the underlying
        data array.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection described by `key` (in place)."""
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            # per-axis mapping: normalize every `SliceInfo` to a plain `slice`
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        """Return `len()` of the underlying data array."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        """Yield `self[n]` for every `n` in `range(len(self))`."""
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the tensor by integer index.

        Raises:
            TypeError: if the tensor is 0-dimensional.
        """
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operation `f` via the wrapped data array.

        A `Tensor` operand is unwrapped first; the result is a new tensor.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the binary operation `f` in place and return `self`.

        If `other` exposes an `xr.DataArray` as its `data` attribute (as
        `Tensor` does), that data array is used as the operand; otherwise
        `other` is passed on unchanged.
        """
        # `getattr` needs a default here: plain Python scalars (int, float, ...)
        # have no `data` attribute and would otherwise raise an AttributeError.
        other_data = getattr(other, "data", None)
        operand = other_data if isinstance(other_data, xr.DataArray) else other
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            operand,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operation `f` via the wrapped data array."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
            for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails, or if no
                view of `array` could be mapped to the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # use `cls` (like the other factory methods) so subclasses are preserved
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """The data type of the wrapped array as a string."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return a mapping from axis id to the integer index reported by
        the data array's `argmax(...)` over all dimensions."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes missing from `sizes` or already of the requested size are left
        untouched. Requesting a size larger than the current one only logs a
        warning.

        Raises:
            ValueError: if an axis needs cropping but `crop_where` is a mapping
                without an entry for it.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # remove elements on the left, i.e. keep the right end
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with the additional axes `dims`."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying mean along some dimension(s)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Return a tensor padded by `pad_width` per axis using padding `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes missing from `sizes` or already of the requested size are not
        padded. Requesting a size smaller than the current one only logs a
        warning.

        Raises:
            ValueError: if an axis needs padding but `pad_where` is a mapping
                without an entry for it.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is  # total number of elements to add
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        """Compute the quantile(s) `q` (each within [0, 1]) along `dim`.

        `dim` must not be or contain an axis named "quantile".
        """
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Existing axes are cropped first, then padded; axes in `sizes` that the
        tensor does not have yet are added via `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)  # whatever remains after the loop is a new axis
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying std along some dimension(s)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying var along some dimension(s)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes for `array` from its number of dimensions.

        2d -> (y, x); 3d -> (channel, y, x) if any axis has length <= 3, else
        (z, y, x); 4d -> (channel, z, y, x); 5d -> (batch, channel, z, y, x).

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))


def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Adapt the rank of `arr` towards `len(axis_infos)`.

    If `arr` has too many dimensions, its first singleton axes are removed
    (only as many as needed). If it has too few, singleton axes are inserted
    at the positions of axes flagged `maybe_singleton`. The returned array may
    still have a different rank if neither is sufficient.
    """
    n_axes = len(axis_infos)
    if arr.ndim > n_axes:
        # remove singletons
        # Select all axes to drop up front: removing them one at a time while
        # iterating the original shape would shift the remaining indices.
        surplus = arr.ndim - n_axes
        drop = tuple(i for i, s in enumerate(arr.shape) if s == 1)[:surplus]
        if drop:
            arr = np.squeeze(arr, axis=drop)

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if arr.ndim >= n_axes:
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr


def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find an axis permutation of `original_array` that matches `axis_infos`.

    Returns:
        The first transposed (and singleton-adjusted) array in which every axis
        of length 1 corresponds to an axis that may be a singleton, or None if
        there is no such arrangement.
    """
    perms = list(permutations(range(len(original_array.shape))))
    perms.insert(1, perms.pop())  # try A and A.T first

    for perm in perms:
        view = original_array.transpose(perm)
        view = _add_singletons(view, axis_infos)
        if len(view.shape) != len(axis_infos):
            # the adjusted rank does not depend on the axis order,
            # so no other permutation can succeed either
            return None

        for s, a in zip(view.shape, axis_infos):
            if s == 1 and not a.maybe_singleton:
                break
        else:
            return view

    return None
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` whose dimensions are named by `dims`.

        Args:
            array: the data to wrap
            dims: one axis description per array dimension; entries that are not
                already an `AxisId` are converted via `AxisInfo.create(...).id`
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        """Return the wrapped data as a numpy array (optionally cast to `dtype`)."""
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor and return the selection as a new tensor.

        `SliceInfo` keys (also as mapping values) are converted to `slice`
        objects and `Tensor` keys are unwrapped before indexing the underlying
        data array.
        """
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` to the selection described by `key` (in place)."""
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            # per-axis mapping: normalize every `SliceInfo` to a plain `slice`
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        """Return `len()` of the underlying data array."""
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        """Yield `self[n]` for every `n` in `range(len(self))`."""
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the tensor by integer index.

        Raises:
            TypeError: if the tensor is 0-dimensional.
        """
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operation `f` via the wrapped data array.

        A `Tensor` operand is unwrapped first; the result is a new tensor.
        """
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the binary operation `f` in place and return `self`.

        If `other` exposes an `xr.DataArray` as its `data` attribute (as
        `Tensor` does), that data array is used as the operand; otherwise
        `other` is passed on unchanged.
        """
        # `getattr` needs a default here: plain Python scalars (int, float, ...)
        # have no `data` attribute and would otherwise raise an AttributeError.
        other_data = getattr(other, "data", None)
        operand = other_data if isinstance(other_data, xr.DataArray) else other
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            operand,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply the unary operation `f` via the wrapped data array."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
            for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails, or if no
                view of `array` could be mapped to the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, collections.abc.Sequence):
            dim_seq = list(dims)
        else:
            dim_seq = [dims]

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # use `cls` (like the other factory methods) so subclasses are preserved
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """The data type of the wrapped array as a string."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return a mapping from axis id to the integer index reported by
        the data array's `argmax(...)` over all dimensions."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes missing from `sizes` or already of the requested size are left
        untouched. Requesting a size larger than the current one only logs a
        warning.

        Raises:
            ValueError: if an axis needs cropping but `crop_where` is a mapping
                without an entry for it.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    # remove elements on the left, i.e. keep the right end
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Return a tensor with the additional axes `dims`."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying mean along some dimension(s)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Return a tensor padded by `pad_width` per axis using padding `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes missing from `sizes` or already of the requested size are not
        padded. Requesting a size smaller than the current one only logs a
        warning.

        Raises:
            ValueError: if an axis needs padding but `pad_where` is a mapping
                without an entry for it.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is  # total number of elements to add
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        """Compute the quantile(s) `q` (each within [0, 1]) along `dim`.

        `dim` must not be or contain an axis named "quantile".
        """
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Existing axes are cropped first, then padded; axes in `sizes` that the
        tensor does not have yet are added via `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)  # whatever remains after the loop is a new axis
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying std along some dimension(s)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying var along some dimension(s)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes for `array` from its number of dimensions.

        2d -> (y, x); 3d -> (channel, y, x) if any axis has length <= 3, else
        (z, y, x); 4d -> (channel, z, y, x); 5d -> (batch, channel, z, y, x).

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.
@classmethod
def from_xarray(cls, data_array: xr.DataArray) -> Self:
    """Build a `Tensor` from an xarray data array.

    The dimension names of `data_array` are converted to `AxisId`s and its
    underlying array is passed on as is.

    note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
    """
    axis_ids = tuple(AxisId(name) for name in data_array.dims)
    return cls(array=data_array.data, dims=axis_ids)
@classmethod
def from_numpy(
    cls,
    array: NDArray[Any],
    *,
    dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
) -> Tensor:
    """create a `Tensor` from a numpy array

    Args:
        array: the nd numpy array
        dims: A description of the array's axes,
            if None axes are guessed (which might fail and raise a ValueError.)

    Raises:
        ValueError: if `dims` is None and axes guessing fails, or if no
            view of `array` could be mapped to the given `dims`.
    """

    if dims is None:
        return cls._interprete_array_wo_known_axes(array)
    elif isinstance(dims, collections.abc.Sequence):
        dim_seq = list(dims)
    else:
        # a single axis description was given
        dim_seq = [dims]

    axis_infos = [AxisInfo.create(a) for a in dim_seq]
    original_shape = tuple(array.shape)

    successful_view = _get_array_view(array, axis_infos)
    if successful_view is None:
        raise ValueError(
            f"Array shape {original_shape} does not map to axes {dims}"
        )

    # use `cls` (as done when `dims` is None) so subclasses are preserved
    return cls(successful_view, dims=tuple(a.id for a in axis_infos))
Create a `Tensor` from a numpy array.
Arguments:
- array: the nd numpy array
- dims: A description of the array's axes; if None, the axes are guessed (which might fail and raise a ValueError).
Raises:
- ValueError: if `dims` is None and axes guessing fails, or if the array shape does not map to the given `dims`.
@property
def dims(self):  # TODO: rename to `axes`?
    """The dimension names of this tensor as a tuple of axis ids."""
    axis_ids = self._data.dims
    return cast(Tuple[AxisId, ...], axis_ids)
Tuple of dimension names associated with this tensor.
@property
def shape(self):
    """The lengths of all tensor axes as a tuple."""
    axis_lengths = self._data.shape
    return axis_lengths
Tuple of tensor axes lengths
@property
def shape_tuple(self):
    """The lengths of all tensor axes as a tuple."""
    axis_lengths = self._data.shape
    return axis_lengths
Tuple of tensor axes lengths
@property
def size(self):
    """Number of elements in the tensor.

    Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
    """
    wrapped = self._data
    return wrapped.size
Number of elements in the tensor.
Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
@property
def sizes(self):
    """Ordered, immutable mapping from axis ids to axis lengths."""
    axis_lengths = self.data.sizes
    # typing-only cast: keys are treated as axis ids, values as lengths
    return cast(Mapping[AxisId, int], axis_lengths)
Ordered, immutable mapping from axis ids to axis lengths.
@property
def tagged_shape(self):
    """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
    axis_lengths = self.sizes
    return axis_lengths
(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.
def astype(self, dtype: DTypeStr, *, copy: bool = False):
    """Return tensor cast to `dtype`.

    Note: `copy` is forwarded to the wrapped data's `astype`; if `dtype` is
    already satisfied the data is copied if `copy`.

    Args:
        dtype: the target data type
        copy: forwarded as the `copy` keyword of the underlying `astype` call
    """
    build = self.__class__.from_xarray
    converted = self._data.astype(dtype, copy=copy)
    return build(converted)
Return tensor cast to `dtype`.
Note: if `dtype` is already satisfied, the data is copied only if `copy` is set.
def clip(self, min: Optional[float] = None, max: Optional[float] = None):
    """Return a tensor whose values are limited to [min, max].

    At least one of max or min must be given.

    Args:
        min: lower bound, forwarded to the wrapped data's `clip`
        max: upper bound, forwarded to the wrapped data's `clip`
    """
    build = self.__class__.from_xarray
    bounded = self._data.clip(min, max)
    return build(bounded)
Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.
def crop_to(
    self,
    sizes: PerAxis[int],
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
) -> Self:
    """Crop to match `sizes`.

    Axes that are missing from `sizes` or already have the requested length
    are left untouched. An axis that is shorter than requested cannot be
    cropped: a warning is logged and the axis is kept as is.

    Args:
        sizes: requested length per axis
        crop_where: where to remove elements, either one choice applied to
            all axes or a per-axis mapping. "left" keeps the last `sizes[a]`
            elements, "right" keeps the first `sizes[a]` elements and
            "left_and_right" removes from both ends.

    Raises:
        ValueError: if an axis has to be cropped but `crop_where` has no
            entry for it.
    """
    crop_axis_where: PerAxis[CropWhere]
    if isinstance(crop_where, str):
        # a single choice applies to every axis of this tensor
        crop_axis_where = dict.fromkeys(self.dims, crop_where)
    else:
        crop_axis_where = crop_where

    slices: Dict[AxisId, SliceInfo] = {}
    for a, s_is in self.sizes.items():
        if a not in sizes or sizes[a] == s_is:
            continue  # nothing requested or nothing to do for this axis

        target = sizes[a]
        if target > s_is:
            logger.warning(
                "Cannot crop axis {} of size {} to larger size {}",
                a,
                s_is,
                target,
            )
            continue

        if a not in crop_axis_where:
            raise ValueError(
                f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
            )

        where = crop_axis_where[a]
        if where == "left":
            slices[a] = SliceInfo(s_is - target, s_is)
        elif where == "right":
            slices[a] = SliceInfo(0, target)
        elif where == "left_and_right":
            # split the surplus; the left side gets the rounded-down half
            start = (s_is - target) // 2
            slices[a] = SliceInfo(start, target + start)
        else:
            assert_never(where)

    return self[slices]
crop to match sizes
def item(
    self,
    key: Union[
        None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
    ] = None,
):
    """Copy a tensor element to a standard Python scalar and return it.

    Args:
        key: optional selection applied (via indexing) before the element
            is extracted; if None the element is taken from the whole tensor.
    """
    source = self if key is None else self[key]
    value = source._data.item()

    assert isinstance(value, (bool, float, int))
    return value
Copy a tensor element to a standard Python scalar and return it.
def pad_to(
    self,
    sizes: PerAxis[int],
    pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
    mode: PadMode = "symmetric",
) -> Self:
    """Pad `tensor` to match `sizes`.

    Axes that are missing from `sizes` or already have the requested length
    get a zero pad width. An axis that is longer than requested cannot be
    padded to it: a warning is logged and the axis gets a zero pad width.

    Args:
        sizes: requested length per axis
        pad_where: where to add elements, either one choice applied to all
            axes or a per-axis mapping. With "left_and_right" the left side
            gets the rounded-down half of the difference.
        mode: padding mode forwarded to `self.pad`

    Raises:
        ValueError: if an axis has to be padded but `pad_where` has no
            entry for it.
    """
    pad_axis_where: PerAxis[PadWhere]
    if isinstance(pad_where, str):
        # a single choice applies to every axis of this tensor
        pad_axis_where = dict.fromkeys(self.dims, pad_where)
    else:
        pad_axis_where = pad_where

    pad_width: Dict[AxisId, PadWidth] = {}
    for a, s_is in self.sizes.items():
        if a not in sizes or sizes[a] == s_is:
            pad_width[a] = PadWidth(0, 0)
            continue

        target = sizes[a]
        if s_is > target:
            pad_width[a] = PadWidth(0, 0)
            logger.warning(
                "Cannot pad axis {} of size {} to smaller size {}",
                a,
                s_is,
                target,
            )
            continue

        if a not in pad_axis_where:
            raise ValueError(
                f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
            )

        where = pad_axis_where[a]
        d = target - s_is  # number of elements to add along this axis
        if where == "left":
            pad_width[a] = PadWidth(d, 0)
        elif where == "right":
            pad_width[a] = PadWidth(0, d)
        elif where == "left_and_right":
            left = d // 2
            pad_width[a] = PadWidth(left, d - left)
        else:
            assert_never(where)

    return self.pad(pad_width, mode)
Pad `tensor` to match `sizes`.
def quantile(
    self,
    q: Union[float, Sequence[float]],
    dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
) -> Self:
    """Return the `q`-quantile(s) as computed by the wrapped data's `quantile`.

    Args:
        q: a quantile or a collection of quantiles, each within [0, 1]
        dim: axis or axes forwarded as `dim` to the underlying `quantile`
            call. Must neither be nor contain the axis id "quantile"
            (presumably because the result may carry an axis of that
            name -- confirm against xarray).
    """
    # treat a scalar like a one-element collection for the range checks
    quantiles = (q,) if isinstance(q, (float, int)) else q
    assert all(qq >= 0.0 for qq in quantiles)
    assert all(qq <= 1.0 for qq in quantiles)

    if dim is not None:
        quantile_dim = AxisId("quantile")
        assert quantile_dim != dim and quantile_dim not in set(dim)

    return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to(
    self,
    sizes: PerAxis[int],
    *,
    pad_where: Union[
        PadWhere,
        PerAxis[PadWhere],
    ] = "left_and_right",
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
    pad_mode: PadMode = "symmetric",
):
    """Return cropped/padded tensor with `sizes`.

    Existing axes that are too long are cropped, those that are too short
    are padded; axes in `sizes` that this tensor does not have are added
    with `expand_dims`.

    Args:
        sizes: requested length per axis
        pad_where: forwarded to `pad_to`
        crop_where: forwarded to `crop_to`
        pad_mode: forwarded to `pad_to` as `mode`
    """
    shrink: Dict[AxisId, int] = {}
    grow: Dict[AxisId, int] = {}
    # whatever remains here after the loop are axes this tensor lacks
    remaining = dict(sizes)
    for name, current in self.sizes.items():
        axis = AxisId(str(name))
        remaining.pop(axis, None)
        if axis not in sizes or sizes[axis] == current:
            continue

        if current > sizes[axis]:
            shrink[axis] = sizes[axis]
        else:
            grow[axis] = sizes[axis]

    result = self
    if shrink:
        result = result.crop_to(shrink, crop_where=crop_where)

    if grow:
        result = result.pad_to(grow, pad_where=pad_where, mode=pad_mode)

    if remaining:
        result = result.expand_dims(remaining)

    return result
return cropped/padded tensor with sizes
def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
    """Reduce this Tensor's data by applying sum along some dimension(s).

    Args:
        dim: axis or axes forwarded as `dim` to the wrapped data's `sum`
    """
    build = self.__class__.from_xarray
    reduced = self._data.sum(dim=dim)
    return build(reduced)
Reduce this Tensor's data by applying sum along some dimension(s).
def transpose(
    self,
    axes: Sequence[AxisId],
) -> Self:
    """Return a transposed tensor.

    Axes listed in `axes` that this tensor does not have yet are added
    with `expand_dims` before transposing.

    Args:
        axes: the desired tensor axes
    """
    # collect requested axes that are not part of this tensor yet
    to_add = []
    for axis in axes:
        if axis not in self.dims:
            to_add.append(axis)

    data = self._data
    if to_add:
        data = data.expand_dims(tuple(to_add))

    # bring the axes into the requested order
    return self.__class__.from_xarray(data.transpose(*axes))
return a transposed tensor
Arguments:
- axes: the desired tensor axes