bioimageio.core.tensor
1from __future__ import annotations 2 3import collections.abc 4from itertools import permutations 5from typing import ( 6 TYPE_CHECKING, 7 Any, 8 Callable, 9 Dict, 10 Iterator, 11 Mapping, 12 Optional, 13 Sequence, 14 Tuple, 15 Union, 16 cast, 17 get_args, 18) 19 20import numpy as np 21import xarray as xr 22from loguru import logger 23from numpy.typing import DTypeLike, NDArray 24from typing_extensions import Self, assert_never 25 26from bioimageio.spec.model import v0_5 27 28from ._magic_tensor_ops import MagicTensorOpsMixin 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 30from .common import ( 31 CropWhere, 32 DTypeStr, 33 PadMode, 34 PadWhere, 35 PadWidth, 36 PadWidthLike, 37 SliceInfo, 38) 39 40if TYPE_CHECKING: 41 from numpy.typing import ArrayLike, NDArray 42 43 44_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" 45 46 47# TODO: complete docstrings 48# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray 49class Tensor(MagicTensorOpsMixin): 50 """A wrapper around an xr.DataArray for better integration with bioimageio.spec 51 and improved type annotations.""" 52 53 _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray] 54 55 def __init__( 56 self, 57 array: NDArray[Any], 58 dims: Sequence[Union[AxisId, AxisLike]], 59 ) -> None: 60 super().__init__() 61 axes = tuple( 62 a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims 63 ) 64 self._data = xr.DataArray(array, dims=axes) 65 66 def __array__(self, dtype: DTypeLike = None): 67 return np.asarray(self._data, dtype=dtype) 68 69 def __getitem__( 70 self, 71 key: Union[ 72 SliceInfo, 73 slice, 74 int, 75 PerAxis[Union[SliceInfo, slice, int]], 76 Tensor, 77 xr.DataArray, 78 ], 79 ) -> Self: 80 if isinstance(key, SliceInfo): 81 key = slice(*key) 82 elif isinstance(key, collections.abc.Mapping): 83 key = { 84 a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) 85 for a, 
s in key.items() 86 } 87 elif isinstance(key, Tensor): 88 key = key._data 89 90 return self.__class__.from_xarray(self._data[key]) 91 92 def __setitem__( 93 self, 94 key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray], 95 value: Union[Tensor, xr.DataArray, float, int], 96 ) -> None: 97 if isinstance(key, Tensor): 98 key = key._data 99 elif isinstance(key, xr.DataArray): 100 pass 101 else: 102 key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} 103 104 if isinstance(value, Tensor): 105 value = value._data 106 107 self._data[key] = value 108 109 def __len__(self) -> int: 110 return len(self.data) 111 112 def _iter(self: Any) -> Iterator[Any]: 113 for n in range(len(self)): 114 yield self[n] 115 116 def __iter__(self: Any) -> Iterator[Any]: 117 if self.ndim == 0: 118 raise TypeError("iteration over a 0-d array") 119 return self._iter() 120 121 def _binary_op( 122 self, 123 other: _Compatible, 124 f: Callable[[Any, Any], Any], 125 reflexive: bool = False, 126 ) -> Self: 127 data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] 128 (other._data if isinstance(other, Tensor) else other), 129 f, 130 reflexive, 131 ) 132 return self.__class__.from_xarray(data) 133 134 def _inplace_binary_op( 135 self, 136 other: _Compatible, 137 f: Callable[[Any, Any], Any], 138 ) -> Self: 139 _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] 140 ( 141 other_d 142 if (other_d := getattr(other, "data")) is not None 143 and isinstance( 144 other_d, 145 xr.DataArray, 146 ) 147 else other 148 ), 149 f, 150 ) 151 return self 152 153 def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: 154 data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] 155 f, *args, **kwargs 156 ) 157 return self.__class__.from_xarray(data) 158 159 @classmethod 160 def from_xarray(cls, data_array: xr.DataArray) -> Self: 161 """create a `Tensor` from an xarray data array 162 163 note for internal use: this 
factory method is round-trip save 164 for any `Tensor`'s `data` property (an xarray.DataArray). 165 """ 166 return cls( 167 array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) 168 ) 169 170 @classmethod 171 def from_numpy( 172 cls, 173 array: NDArray[Any], 174 *, 175 dims: Optional[Union[AxisLike, Sequence[AxisLike]]], 176 ) -> Tensor: 177 """create a `Tensor` from a numpy array 178 179 Args: 180 array: the nd numpy array 181 axes: A description of the array's axes, 182 if None axes are guessed (which might fail and raise a ValueError.) 183 184 Raises: 185 ValueError: if `axes` is None and axes guessing fails. 186 """ 187 188 if dims is None: 189 return cls._interprete_array_wo_known_axes(array) 190 elif isinstance(dims, (str, Axis, v0_5.AxisBase)): 191 dims = [dims] 192 193 axis_infos = [AxisInfo.create(a) for a in dims] 194 original_shape = tuple(array.shape) 195 196 successful_view = _get_array_view(array, axis_infos) 197 if successful_view is None: 198 raise ValueError( 199 f"Array shape {original_shape} does not map to axes {dims}" 200 ) 201 202 return Tensor(successful_view, dims=tuple(a.id for a in axis_infos)) 203 204 @property 205 def data(self): 206 return self._data 207 208 @property 209 def dims(self): # TODO: rename to `axes`? 
210 """Tuple of dimension names associated with this tensor.""" 211 return cast(Tuple[AxisId, ...], self._data.dims) 212 213 @property 214 def dtype(self) -> DTypeStr: 215 dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] 216 assert dt in get_args(DTypeStr) 217 return dt # pyright: ignore[reportReturnType] 218 219 @property 220 def ndim(self): 221 """Number of tensor dimensions.""" 222 return self._data.ndim 223 224 @property 225 def shape(self): 226 """Tuple of tensor axes lengths""" 227 return self._data.shape 228 229 @property 230 def shape_tuple(self): 231 """Tuple of tensor axes lengths""" 232 return self._data.shape 233 234 @property 235 def size(self): 236 """Number of elements in the tensor. 237 238 Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. 239 """ 240 return self._data.size 241 242 @property 243 def sizes(self): 244 """Ordered, immutable mapping from axis ids to axis lengths.""" 245 return cast(Mapping[AxisId, int], self.data.sizes) 246 247 @property 248 def tagged_shape(self): 249 """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" 250 return self.sizes 251 252 def argmax(self) -> Mapping[AxisId, int]: 253 ret = self._data.argmax(...) 254 assert isinstance(ret, dict) 255 return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()} 256 257 def astype(self, dtype: DTypeStr, *, copy: bool = False): 258 """Return tensor cast to `dtype` 259 260 note: if dtype is already satisfied copy if `copy`""" 261 return self.__class__.from_xarray(self._data.astype(dtype, copy=copy)) 262 263 def clip(self, min: Optional[float] = None, max: Optional[float] = None): 264 """Return a tensor whose values are limited to [min, max]. 
265 At least one of max or min must be given.""" 266 return self.__class__.from_xarray(self._data.clip(min, max)) 267 268 def crop_to( 269 self, 270 sizes: PerAxis[int], 271 crop_where: Union[ 272 CropWhere, 273 PerAxis[CropWhere], 274 ] = "left_and_right", 275 ) -> Self: 276 """crop to match `sizes`""" 277 if isinstance(crop_where, str): 278 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} 279 else: 280 crop_axis_where = crop_where 281 282 slices: Dict[AxisId, SliceInfo] = {} 283 284 for a, s_is in self.sizes.items(): 285 if a not in sizes or sizes[a] == s_is: 286 pass 287 elif sizes[a] > s_is: 288 logger.warning( 289 "Cannot crop axis {} of size {} to larger size {}", 290 a, 291 s_is, 292 sizes[a], 293 ) 294 elif a not in crop_axis_where: 295 raise ValueError( 296 f"Don't know where to crop axis {a}, `crop_where`={crop_where}" 297 ) 298 else: 299 crop_this_axis_where = crop_axis_where[a] 300 if crop_this_axis_where == "left": 301 slices[a] = SliceInfo(s_is - sizes[a], s_is) 302 elif crop_this_axis_where == "right": 303 slices[a] = SliceInfo(0, sizes[a]) 304 elif crop_this_axis_where == "left_and_right": 305 slices[a] = SliceInfo( 306 start := (s_is - sizes[a]) // 2, sizes[a] + start 307 ) 308 else: 309 assert_never(crop_this_axis_where) 310 311 return self[slices] 312 313 def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: 314 return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) 315 316 def item( 317 self, 318 key: Union[ 319 None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]] 320 ] = None, 321 ): 322 """Copy a tensor element to a standard Python scalar and return it.""" 323 if key is None: 324 ret = self._data.item() 325 else: 326 ret = self[key]._data.item() 327 328 assert isinstance(ret, (bool, float, int)) 329 return ret 330 331 def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 332 return self.__class__.from_xarray(self._data.mean(dim=dim)) 333 334 def 
pad( 335 self, 336 pad_width: PerAxis[PadWidthLike], 337 mode: PadMode = "symmetric", 338 ) -> Self: 339 pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} 340 return self.__class__.from_xarray( 341 self._data.pad(pad_width=pad_width, mode=mode) 342 ) 343 344 def pad_to( 345 self, 346 sizes: PerAxis[int], 347 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", 348 mode: PadMode = "symmetric", 349 ) -> Self: 350 """pad `tensor` to match `sizes`""" 351 if isinstance(pad_where, str): 352 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} 353 else: 354 pad_axis_where = pad_where 355 356 pad_width: Dict[AxisId, PadWidth] = {} 357 for a, s_is in self.sizes.items(): 358 if a not in sizes or sizes[a] == s_is: 359 pad_width[a] = PadWidth(0, 0) 360 elif s_is > sizes[a]: 361 pad_width[a] = PadWidth(0, 0) 362 logger.warning( 363 "Cannot pad axis {} of size {} to smaller size {}", 364 a, 365 s_is, 366 sizes[a], 367 ) 368 elif a not in pad_axis_where: 369 raise ValueError( 370 f"Don't know where to pad axis {a}, `pad_where`={pad_where}" 371 ) 372 else: 373 pad_this_axis_where = pad_axis_where[a] 374 d = sizes[a] - s_is 375 if pad_this_axis_where == "left": 376 pad_width[a] = PadWidth(d, 0) 377 elif pad_this_axis_where == "right": 378 pad_width[a] = PadWidth(0, d) 379 elif pad_this_axis_where == "left_and_right": 380 pad_width[a] = PadWidth(left := d // 2, d - left) 381 else: 382 assert_never(pad_this_axis_where) 383 384 return self.pad(pad_width, mode) 385 386 def quantile( 387 self, 388 q: Union[float, Sequence[float]], 389 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, 390 ) -> Self: 391 assert ( 392 isinstance(q, (float, int)) 393 and q >= 0.0 394 or not isinstance(q, (float, int)) 395 and all(qq >= 0.0 for qq in q) 396 ) 397 assert ( 398 isinstance(q, (float, int)) 399 and q <= 1.0 400 or not isinstance(q, (float, int)) 401 and all(qq <= 1.0 for qq in q) 402 ) 403 assert dim is None or ( 404 (quantile_dim := 
AxisId("quantile")) != dim and quantile_dim not in set(dim) 405 ) 406 return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) 407 408 def resize_to( 409 self, 410 sizes: PerAxis[int], 411 *, 412 pad_where: Union[ 413 PadWhere, 414 PerAxis[PadWhere], 415 ] = "left_and_right", 416 crop_where: Union[ 417 CropWhere, 418 PerAxis[CropWhere], 419 ] = "left_and_right", 420 pad_mode: PadMode = "symmetric", 421 ): 422 """return cropped/padded tensor with `sizes`""" 423 crop_to_sizes: Dict[AxisId, int] = {} 424 pad_to_sizes: Dict[AxisId, int] = {} 425 new_axes = dict(sizes) 426 for a, s_is in self.sizes.items(): 427 a = AxisId(str(a)) 428 _ = new_axes.pop(a, None) 429 if a not in sizes or sizes[a] == s_is: 430 pass 431 elif s_is > sizes[a]: 432 crop_to_sizes[a] = sizes[a] 433 else: 434 pad_to_sizes[a] = sizes[a] 435 436 tensor = self 437 if crop_to_sizes: 438 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) 439 440 if pad_to_sizes: 441 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) 442 443 if new_axes: 444 tensor = tensor.expand_dims(new_axes) 445 446 return tensor 447 448 def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 449 return self.__class__.from_xarray(self._data.std(dim=dim)) 450 451 def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 452 """Reduce this Tensor's data by applying sum along some dimension(s).""" 453 return self.__class__.from_xarray(self._data.sum(dim=dim)) 454 455 def transpose( 456 self, 457 axes: Sequence[AxisId], 458 ) -> Self: 459 """return a transposed tensor 460 461 Args: 462 axes: the desired tensor axes 463 """ 464 # expand missing tensor axes 465 missing_axes = tuple(a for a in axes if a not in self.dims) 466 array = self._data 467 if missing_axes: 468 array = array.expand_dims(missing_axes) 469 470 # transpose to the correct axis order 471 return self.__class__.from_xarray(array.transpose(*axes)) 472 473 def var(self, dim: 
Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 474 return self.__class__.from_xarray(self._data.var(dim=dim)) 475 476 @classmethod 477 def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): 478 ndim = array.ndim 479 if ndim == 2: 480 current_axes = ( 481 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), 482 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), 483 ) 484 elif ndim == 3 and any(s <= 3 for s in array.shape): 485 current_axes = ( 486 v0_5.ChannelAxis( 487 channel_names=[ 488 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) 489 ] 490 ), 491 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), 492 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), 493 ) 494 elif ndim == 3: 495 current_axes = ( 496 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), 497 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), 498 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), 499 ) 500 elif ndim == 4: 501 current_axes = ( 502 v0_5.ChannelAxis( 503 channel_names=[ 504 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) 505 ] 506 ), 507 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), 508 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), 509 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), 510 ) 511 elif ndim == 5: 512 current_axes = ( 513 v0_5.BatchAxis(), 514 v0_5.ChannelAxis( 515 channel_names=[ 516 v0_5.Identifier(f"channel{i}") for i in range(array.shape[1]) 517 ] 518 ), 519 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), 520 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), 521 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), 522 ) 523 else: 524 raise ValueError(f"Could not guess an axis mapping for {array.shape}") 525 526 return cls(array, dims=tuple(a.id for a in current_axes)) 527 528 529def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]): 530 if len(arr.shape) > len(axis_infos): 531 # remove singletons 532 for 
i, s in enumerate(arr.shape): 533 if s == 1: 534 arr = np.take(arr, 0, axis=i) 535 if len(arr.shape) == len(axis_infos): 536 break 537 538 # add singletons if nececsary 539 for i, a in enumerate(axis_infos): 540 if len(arr.shape) >= len(axis_infos): 541 break 542 543 if a.maybe_singleton: 544 arr = np.expand_dims(arr, i) 545 546 return arr 547 548 549def _get_array_view( 550 original_array: NDArray[Any], axis_infos: Sequence[AxisInfo] 551) -> Optional[NDArray[Any]]: 552 perms = list(permutations(range(len(original_array.shape)))) 553 perms.insert(1, perms.pop()) # try A and A.T first 554 555 for perm in perms: 556 view = original_array.transpose(perm) 557 view = _add_singletons(view, axis_infos) 558 if len(view.shape) != len(axis_infos): 559 return None 560 561 for s, a in zip(view.shape, axis_infos): 562 if s == 1 and not a.maybe_singleton: 563 break 564 else: 565 return view 566 567 return None
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array`, naming its dimensions by `dims`.

        Entries of `dims` that are not already an `AxisId` are converted via
        `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None, copy: Optional[bool] = None):
        """Return the tensor data as a numpy array (used by `np.asarray`).

        Args:
            dtype: optional dtype to convert to.
            copy: passed by NumPy >= 2. If True, a copy is always returned;
                otherwise a copy is only made if needed for `dtype`.
                Accepting this keyword avoids NumPy 2's DeprecationWarning
                for `__array__` implementations without it.
        """
        arr = np.asarray(self._data, dtype=dtype)
        if copy:
            arr = arr.copy()
        return arr

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        """Index the tensor; `SliceInfo`s are converted to `slice`s and a
        `Tensor` key is unwrapped to its xarray data."""
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, (int, slice)) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        """Assign `value` in place at `key`.

        Per-axis `SliceInfo`s are converted to `slice`s; `Tensor` keys and
        values are unwrapped to their xarray data.
        """
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply the binary operator `f` via xarray; returns a new tensor."""
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply the in-place binary operator `f` via xarray and return `self`."""
        # `getattr` needs a default: plain Python scalars (e.g. `tensor += 1`)
        # have no `data` attribute and must be passed through unchanged.
        other_d = getattr(other, "data", None)
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            other_d if isinstance(other_d, xr.DataArray) else other,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip save
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails,
                or if the array's shape cannot be mapped to the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls` (not `Tensor`) so subclasses get instances of their own type,
        # consistent with `from_xarray` and the `dims is None` branch above
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return, per axis, the index of the maximum over all axes."""
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        """Compute the `q`-th quantile(s) along `dim` (all axes if None).

        Every value of `q` must lie in [0, 1]. `dim` must not include an axis
        named "quantile" (presumably because xarray uses that name for the
        output dimension of a sequence-valued `q`).
        """
        qs = [q] if isinstance(q, (float, int)) else q
        assert all(0.0 <= qq <= 1.0 for qq in qs)
        if dim is not None:
            # wrap a single axis id instead of iterating over it
            dims = [dim] if isinstance(dim, (str, AxisId)) else dim
            assert AxisId("quantile") not in dims

        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axis ids for `array` from its number of dimensions.

        2d: (y, x); 3d: (channel, y, x) if any axis has length <= 3,
        otherwise (z, y, x); 4d: (channel, z, y, x);
        5d: (batch, channel, z, y, x). Channel/batch ids are those of
        `v0_5.ChannelAxis` / `v0_5.BatchAxis`.

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # NOTE(review): axis 0 is taken as channel even if the short axis
            # is another one, e.g. shape (256, 256, 3) -- confirm intended
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.
@classmethod
def from_xarray(cls, data_array: xr.DataArray) -> Self:
    """Build a `Tensor` from an `xarray.DataArray`.

    Each dimension name of `data_array` is converted to an `AxisId`; the
    underlying array (`data_array.data`) is handed over unchanged.

    Internal note: passing a `Tensor`'s `data` property back through this
    factory reproduces an equivalent tensor (round-trip safe).
    """
    axis_ids = tuple(map(AxisId, data_array.dims))
    return cls(array=data_array.data, dims=axis_ids)
@classmethod
def from_numpy(
    cls,
    array: NDArray[Any],
    *,
    dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
) -> Tensor:
    """create a `Tensor` from a numpy array

    Args:
        array: the nd numpy array
        dims: A description of the array's axes,
            if None axes are guessed (which might fail and raise a ValueError.)

    Raises:
        ValueError: if `dims` is None and axes guessing fails,
            or if the array's shape cannot be mapped to the given `dims`.
    """

    if dims is None:
        return cls._interprete_array_wo_known_axes(array)
    elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
        dims = [dims]

    axis_infos = [AxisInfo.create(a) for a in dims]
    original_shape = tuple(array.shape)

    successful_view = _get_array_view(array, axis_infos)
    if successful_view is None:
        raise ValueError(
            f"Array shape {original_shape} does not map to axes {dims}"
        )

    # `cls` (not `Tensor`) so subclasses get instances of their own type,
    # consistent with `from_xarray` and the `dims is None` branch above
    return cls(successful_view, dims=tuple(a.id for a in axis_infos))
Create a `Tensor` from a numpy array.
Arguments:
- array: the nd numpy array
- dims: A description of the array's axes; if None, the axes are guessed (which might fail and raise a ValueError).
Raises:
- ValueError: if `dims` is None and axes guessing fails, or if the array's shape cannot be mapped to the given `dims`.
@property
def dims(self):  # TODO: rename to `axes`?
    """The tensor's dimension names (axis ids), in order, as a tuple."""
    dim_names = self._data.dims
    return cast(Tuple[AxisId, ...], dim_names)
Tuple of dimension names associated with this tensor.
@property
def shape(self):
    """Lengths of the tensor's axes, as a tuple."""
    wrapped = self._data
    return wrapped.shape
Tuple of the tensor's axis lengths.
@property
def shape_tuple(self):
    """Lengths of the tensor's axes, as a tuple (same as `shape`)."""
    wrapped = self._data
    return wrapped.shape
Tuple of the tensor's axis lengths.
@property
def size(self):
    """Total count of elements in the tensor.

    This equals math.prod(tensor.shape), the product of all axis lengths.
    """
    wrapped = self._data
    return wrapped.size
Number of elements in the tensor.
Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
243 @property 244 def sizes(self): 245 """Ordered, immutable mapping from axis ids to axis lengths.""" 246 return cast(Mapping[AxisId, int], self.data.sizes)
Ordered, immutable mapping from axis ids to axis lengths.
248 @property 249 def tagged_shape(self): 250 """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" 251 return self.sizes
(Alias for sizes.) Ordered, immutable mapping from axis ids to axis lengths.
258 def astype(self, dtype: DTypeStr, *, copy: bool = False): 259 """Return tensor cast to `dtype` 260 261 note: if dtype is already satisfied copy if `copy`""" 262 return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
Return the tensor cast to dtype.
Note: if the tensor already has this dtype, it is copied only if copy is True.
264 def clip(self, min: Optional[float] = None, max: Optional[float] = None): 265 """Return a tensor whose values are limited to [min, max]. 266 At least one of max or min must be given.""" 267 return self.__class__.from_xarray(self._data.clip(min, max))
Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.
269 def crop_to( 270 self, 271 sizes: PerAxis[int], 272 crop_where: Union[ 273 CropWhere, 274 PerAxis[CropWhere], 275 ] = "left_and_right", 276 ) -> Self: 277 """crop to match `sizes`""" 278 if isinstance(crop_where, str): 279 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} 280 else: 281 crop_axis_where = crop_where 282 283 slices: Dict[AxisId, SliceInfo] = {} 284 285 for a, s_is in self.sizes.items(): 286 if a not in sizes or sizes[a] == s_is: 287 pass 288 elif sizes[a] > s_is: 289 logger.warning( 290 "Cannot crop axis {} of size {} to larger size {}", 291 a, 292 s_is, 293 sizes[a], 294 ) 295 elif a not in crop_axis_where: 296 raise ValueError( 297 f"Don't know where to crop axis {a}, `crop_where`={crop_where}" 298 ) 299 else: 300 crop_this_axis_where = crop_axis_where[a] 301 if crop_this_axis_where == "left": 302 slices[a] = SliceInfo(s_is - sizes[a], s_is) 303 elif crop_this_axis_where == "right": 304 slices[a] = SliceInfo(0, sizes[a]) 305 elif crop_this_axis_where == "left_and_right": 306 slices[a] = SliceInfo( 307 start := (s_is - sizes[a]) // 2, sizes[a] + start 308 ) 309 else: 310 assert_never(crop_this_axis_where) 311 312 return self[slices]
Crop the tensor to match sizes.
317 def item( 318 self, 319 key: Union[ 320 None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]] 321 ] = None, 322 ): 323 """Copy a tensor element to a standard Python scalar and return it.""" 324 if key is None: 325 ret = self._data.item() 326 else: 327 ret = self[key]._data.item() 328 329 assert isinstance(ret, (bool, float, int)) 330 return ret
Copy a tensor element to a standard Python scalar and return it.
345 def pad_to( 346 self, 347 sizes: PerAxis[int], 348 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", 349 mode: PadMode = "symmetric", 350 ) -> Self: 351 """pad `tensor` to match `sizes`""" 352 if isinstance(pad_where, str): 353 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} 354 else: 355 pad_axis_where = pad_where 356 357 pad_width: Dict[AxisId, PadWidth] = {} 358 for a, s_is in self.sizes.items(): 359 if a not in sizes or sizes[a] == s_is: 360 pad_width[a] = PadWidth(0, 0) 361 elif s_is > sizes[a]: 362 pad_width[a] = PadWidth(0, 0) 363 logger.warning( 364 "Cannot pad axis {} of size {} to smaller size {}", 365 a, 366 s_is, 367 sizes[a], 368 ) 369 elif a not in pad_axis_where: 370 raise ValueError( 371 f"Don't know where to pad axis {a}, `pad_where`={pad_where}" 372 ) 373 else: 374 pad_this_axis_where = pad_axis_where[a] 375 d = sizes[a] - s_is 376 if pad_this_axis_where == "left": 377 pad_width[a] = PadWidth(d, 0) 378 elif pad_this_axis_where == "right": 379 pad_width[a] = PadWidth(0, d) 380 elif pad_this_axis_where == "left_and_right": 381 pad_width[a] = PadWidth(left := d // 2, d - left) 382 else: 383 assert_never(pad_this_axis_where) 384 385 return self.pad(pad_width, mode)
Pad the tensor to match sizes.
387 def quantile( 388 self, 389 q: Union[float, Sequence[float]], 390 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, 391 ) -> Self: 392 assert ( 393 isinstance(q, (float, int)) 394 and q >= 0.0 395 or not isinstance(q, (float, int)) 396 and all(qq >= 0.0 for qq in q) 397 ) 398 assert ( 399 isinstance(q, (float, int)) 400 and q <= 1.0 401 or not isinstance(q, (float, int)) 402 and all(qq <= 1.0 for qq in q) 403 ) 404 assert dim is None or ( 405 (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim) 406 ) 407 return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
409 def resize_to( 410 self, 411 sizes: PerAxis[int], 412 *, 413 pad_where: Union[ 414 PadWhere, 415 PerAxis[PadWhere], 416 ] = "left_and_right", 417 crop_where: Union[ 418 CropWhere, 419 PerAxis[CropWhere], 420 ] = "left_and_right", 421 pad_mode: PadMode = "symmetric", 422 ): 423 """return cropped/padded tensor with `sizes`""" 424 crop_to_sizes: Dict[AxisId, int] = {} 425 pad_to_sizes: Dict[AxisId, int] = {} 426 new_axes = dict(sizes) 427 for a, s_is in self.sizes.items(): 428 a = AxisId(str(a)) 429 _ = new_axes.pop(a, None) 430 if a not in sizes or sizes[a] == s_is: 431 pass 432 elif s_is > sizes[a]: 433 crop_to_sizes[a] = sizes[a] 434 else: 435 pad_to_sizes[a] = sizes[a] 436 437 tensor = self 438 if crop_to_sizes: 439 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) 440 441 if pad_to_sizes: 442 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) 443 444 if new_axes: 445 tensor = tensor.expand_dims(new_axes) 446 447 return tensor
Return the tensor cropped/padded to match sizes.
452 def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 453 """Reduce this Tensor's data by applying sum along some dimension(s).""" 454 return self.__class__.from_xarray(self._data.sum(dim=dim))
Reduce this Tensor's data by applying sum along some dimension(s).
456 def transpose( 457 self, 458 axes: Sequence[AxisId], 459 ) -> Self: 460 """return a transposed tensor 461 462 Args: 463 axes: the desired tensor axes 464 """ 465 # expand missing tensor axes 466 missing_axes = tuple(a for a in axes if a not in self.dims) 467 array = self._data 468 if missing_axes: 469 array = array.expand_dims(missing_axes) 470 471 # transpose to the correct axis order 472 return self.__class__.from_xarray(array.transpose(*axes))
return a transposed tensor
Arguments:
- axes: the desired tensor axes