bioimageio.core.tensor
1from __future__ import annotations 2 3import collections.abc 4from itertools import permutations 5from typing import ( 6 TYPE_CHECKING, 7 Any, 8 Callable, 9 Dict, 10 Iterator, 11 Mapping, 12 Optional, 13 Sequence, 14 Tuple, 15 Union, 16 cast, 17 get_args, 18) 19 20import numpy as np 21import xarray as xr 22from loguru import logger 23from numpy.typing import DTypeLike, NDArray 24from typing_extensions import Self, assert_never 25 26from bioimageio.spec.model import v0_5 27 28from ._magic_tensor_ops import MagicTensorOpsMixin 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 30from .common import ( 31 CropWhere, 32 DTypeStr, 33 PadMode, 34 PadWhere, 35 PadWidth, 36 PadWidthLike, 37 SliceInfo, 38) 39 40if TYPE_CHECKING: 41 from numpy.typing import ArrayLike, NDArray 42 43 44_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" 45 46 47# TODO: complete docstrings 48class Tensor(MagicTensorOpsMixin): 49 """A wrapper around an xr.DataArray for better integration with bioimageio.spec 50 and improved type annotations.""" 51 52 _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray] 53 54 def __init__( 55 self, 56 array: NDArray[Any], 57 dims: Sequence[Union[AxisId, AxisLike]], 58 ) -> None: 59 super().__init__() 60 axes = tuple( 61 a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims 62 ) 63 self._data = xr.DataArray(array, dims=axes) 64 65 def __array__(self, dtype: DTypeLike = None): 66 return np.asarray(self._data, dtype=dtype) 67 68 def __getitem__( 69 self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]] 70 ) -> Self: 71 if isinstance(key, SliceInfo): 72 key = slice(*key) 73 elif isinstance(key, collections.abc.Mapping): 74 key = { 75 a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) 76 for a, s in key.items() 77 } 78 return self.__class__.from_xarray(self._data[key]) 79 80 def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> 
None: 81 key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} 82 self._data[key] = value._data 83 84 def __len__(self) -> int: 85 return len(self.data) 86 87 def _iter(self: Any) -> Iterator[Any]: 88 for n in range(len(self)): 89 yield self[n] 90 91 def __iter__(self: Any) -> Iterator[Any]: 92 if self.ndim == 0: 93 raise TypeError("iteration over a 0-d array") 94 return self._iter() 95 96 def _binary_op( 97 self, 98 other: _Compatible, 99 f: Callable[[Any, Any], Any], 100 reflexive: bool = False, 101 ) -> Self: 102 data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] 103 (other._data if isinstance(other, Tensor) else other), 104 f, 105 reflexive, 106 ) 107 return self.__class__.from_xarray(data) 108 109 def _inplace_binary_op( 110 self, 111 other: _Compatible, 112 f: Callable[[Any, Any], Any], 113 ) -> Self: 114 _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] 115 ( 116 other_d 117 if (other_d := getattr(other, "data")) is not None 118 and isinstance( 119 other_d, 120 xr.DataArray, 121 ) 122 else other 123 ), 124 f, 125 ) 126 return self 127 128 def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: 129 data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] 130 f, *args, **kwargs 131 ) 132 return self.__class__.from_xarray(data) 133 134 @classmethod 135 def from_xarray(cls, data_array: xr.DataArray) -> Self: 136 """create a `Tensor` from an xarray data array 137 138 note for internal use: this factory method is round-trip save 139 for any `Tensor`'s `data` property (an xarray.DataArray). 
140 """ 141 return cls( 142 array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) 143 ) 144 145 @classmethod 146 def from_numpy( 147 cls, 148 array: NDArray[Any], 149 *, 150 dims: Optional[Union[AxisLike, Sequence[AxisLike]]], 151 ) -> Tensor: 152 """create a `Tensor` from a numpy array 153 154 Args: 155 array: the nd numpy array 156 axes: A description of the array's axes, 157 if None axes are guessed (which might fail and raise a ValueError.) 158 159 Raises: 160 ValueError: if `axes` is None and axes guessing fails. 161 """ 162 163 if dims is None: 164 return cls._interprete_array_wo_known_axes(array) 165 elif isinstance(dims, (str, Axis, v0_5.AxisBase)): 166 dims = [dims] 167 168 axis_infos = [AxisInfo.create(a) for a in dims] 169 original_shape = tuple(array.shape) 170 171 successful_view = _get_array_view(array, axis_infos) 172 if successful_view is None: 173 raise ValueError( 174 f"Array shape {original_shape} does not map to axes {dims}" 175 ) 176 177 return Tensor(successful_view, dims=tuple(a.id for a in axis_infos)) 178 179 @property 180 def data(self): 181 return self._data 182 183 @property 184 def dims(self): # TODO: rename to `axes`? 185 """Tuple of dimension names associated with this tensor.""" 186 return cast(Tuple[AxisId, ...], self._data.dims) 187 188 @property 189 def dtype(self) -> DTypeStr: 190 dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] 191 assert dt in get_args(DTypeStr) 192 return dt # pyright: ignore[reportReturnType] 193 194 @property 195 def ndim(self): 196 """Number of tensor dimensions.""" 197 return self._data.ndim 198 199 @property 200 def shape(self): 201 """Tuple of tensor axes lengths""" 202 return self._data.shape 203 204 @property 205 def shape_tuple(self): 206 """Tuple of tensor axes lengths""" 207 return self._data.shape 208 209 @property 210 def size(self): 211 """Number of elements in the tensor. 212 213 Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. 
214 """ 215 return self._data.size 216 217 @property 218 def sizes(self): 219 """Ordered, immutable mapping from axis ids to axis lengths.""" 220 return cast(Mapping[AxisId, int], self.data.sizes) 221 222 @property 223 def tagged_shape(self): 224 """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" 225 return self.sizes 226 227 def argmax(self) -> Mapping[AxisId, int]: 228 ret = self._data.argmax(...) 229 assert isinstance(ret, dict) 230 return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()} 231 232 def astype(self, dtype: DTypeStr, *, copy: bool = False): 233 """Return tensor cast to `dtype` 234 235 note: if dtype is already satisfied copy if `copy`""" 236 return self.__class__.from_xarray(self._data.astype(dtype, copy=copy)) 237 238 def clip(self, min: Optional[float] = None, max: Optional[float] = None): 239 """Return a tensor whose values are limited to [min, max]. 240 At least one of max or min must be given.""" 241 return self.__class__.from_xarray(self._data.clip(min, max)) 242 243 def crop_to( 244 self, 245 sizes: PerAxis[int], 246 crop_where: Union[ 247 CropWhere, 248 PerAxis[CropWhere], 249 ] = "left_and_right", 250 ) -> Self: 251 """crop to match `sizes`""" 252 if isinstance(crop_where, str): 253 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} 254 else: 255 crop_axis_where = crop_where 256 257 slices: Dict[AxisId, SliceInfo] = {} 258 259 for a, s_is in self.sizes.items(): 260 if a not in sizes or sizes[a] == s_is: 261 pass 262 elif sizes[a] > s_is: 263 logger.warning( 264 "Cannot crop axis {} of size {} to larger size {}", 265 a, 266 s_is, 267 sizes[a], 268 ) 269 elif a not in crop_axis_where: 270 raise ValueError( 271 f"Don't know where to crop axis {a}, `crop_where`={crop_where}" 272 ) 273 else: 274 crop_this_axis_where = crop_axis_where[a] 275 if crop_this_axis_where == "left": 276 slices[a] = SliceInfo(s_is - sizes[a], s_is) 277 elif crop_this_axis_where == "right": 278 slices[a] = 
SliceInfo(0, sizes[a]) 279 elif crop_this_axis_where == "left_and_right": 280 slices[a] = SliceInfo( 281 start := (s_is - sizes[a]) // 2, sizes[a] + start 282 ) 283 else: 284 assert_never(crop_this_axis_where) 285 286 return self[slices] 287 288 def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: 289 return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) 290 291 def item( 292 self, 293 key: Union[ 294 None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]] 295 ] = None, 296 ): 297 """Copy a tensor element to a standard Python scalar and return it.""" 298 if key is None: 299 ret = self._data.item() 300 else: 301 ret = self[key]._data.item() 302 303 assert isinstance(ret, (bool, float, int)) 304 return ret 305 306 def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 307 return self.__class__.from_xarray(self._data.mean(dim=dim)) 308 309 def pad( 310 self, 311 pad_width: PerAxis[PadWidthLike], 312 mode: PadMode = "symmetric", 313 ) -> Self: 314 pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} 315 return self.__class__.from_xarray( 316 self._data.pad(pad_width=pad_width, mode=mode) 317 ) 318 319 def pad_to( 320 self, 321 sizes: PerAxis[int], 322 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", 323 mode: PadMode = "symmetric", 324 ) -> Self: 325 """pad `tensor` to match `sizes`""" 326 if isinstance(pad_where, str): 327 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} 328 else: 329 pad_axis_where = pad_where 330 331 pad_width: Dict[AxisId, PadWidth] = {} 332 for a, s_is in self.sizes.items(): 333 if a not in sizes or sizes[a] == s_is: 334 pad_width[a] = PadWidth(0, 0) 335 elif s_is > sizes[a]: 336 pad_width[a] = PadWidth(0, 0) 337 logger.warning( 338 "Cannot pad axis {} of size {} to smaller size {}", 339 a, 340 s_is, 341 sizes[a], 342 ) 343 elif a not in pad_axis_where: 344 raise ValueError( 345 f"Don't know where to pad axis {a}, 
`pad_where`={pad_where}" 346 ) 347 else: 348 pad_this_axis_where = pad_axis_where[a] 349 d = sizes[a] - s_is 350 if pad_this_axis_where == "left": 351 pad_width[a] = PadWidth(d, 0) 352 elif pad_this_axis_where == "right": 353 pad_width[a] = PadWidth(0, d) 354 elif pad_this_axis_where == "left_and_right": 355 pad_width[a] = PadWidth(left := d // 2, d - left) 356 else: 357 assert_never(pad_this_axis_where) 358 359 return self.pad(pad_width, mode) 360 361 def quantile( 362 self, 363 q: Union[float, Sequence[float]], 364 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, 365 ) -> Self: 366 assert ( 367 isinstance(q, (float, int)) 368 and q >= 0.0 369 or not isinstance(q, (float, int)) 370 and all(qq >= 0.0 for qq in q) 371 ) 372 assert ( 373 isinstance(q, (float, int)) 374 and q <= 1.0 375 or not isinstance(q, (float, int)) 376 and all(qq <= 1.0 for qq in q) 377 ) 378 assert dim is None or ( 379 (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim) 380 ) 381 return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) 382 383 def resize_to( 384 self, 385 sizes: PerAxis[int], 386 *, 387 pad_where: Union[ 388 PadWhere, 389 PerAxis[PadWhere], 390 ] = "left_and_right", 391 crop_where: Union[ 392 CropWhere, 393 PerAxis[CropWhere], 394 ] = "left_and_right", 395 pad_mode: PadMode = "symmetric", 396 ): 397 """return cropped/padded tensor with `sizes`""" 398 crop_to_sizes: Dict[AxisId, int] = {} 399 pad_to_sizes: Dict[AxisId, int] = {} 400 new_axes = dict(sizes) 401 for a, s_is in self.sizes.items(): 402 a = AxisId(str(a)) 403 _ = new_axes.pop(a, None) 404 if a not in sizes or sizes[a] == s_is: 405 pass 406 elif s_is > sizes[a]: 407 crop_to_sizes[a] = sizes[a] 408 else: 409 pad_to_sizes[a] = sizes[a] 410 411 tensor = self 412 if crop_to_sizes: 413 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) 414 415 if pad_to_sizes: 416 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) 417 418 if new_axes: 419 tensor = 
tensor.expand_dims(new_axes) 420 421 return tensor 422 423 def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 424 return self.__class__.from_xarray(self._data.std(dim=dim)) 425 426 def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 427 """Reduce this Tensor's data by applying sum along some dimension(s).""" 428 return self.__class__.from_xarray(self._data.sum(dim=dim)) 429 430 def transpose( 431 self, 432 axes: Sequence[AxisId], 433 ) -> Self: 434 """return a transposed tensor 435 436 Args: 437 axes: the desired tensor axes 438 """ 439 # expand missing tensor axes 440 missing_axes = tuple(a for a in axes if a not in self.dims) 441 array = self._data 442 if missing_axes: 443 array = array.expand_dims(missing_axes) 444 445 # transpose to the correct axis order 446 return self.__class__.from_xarray(array.transpose(*axes)) 447 448 def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 449 return self.__class__.from_xarray(self._data.var(dim=dim)) 450 451 @classmethod 452 def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): 453 ndim = array.ndim 454 if ndim == 2: 455 current_axes = ( 456 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), 457 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), 458 ) 459 elif ndim == 3 and any(s <= 3 for s in array.shape): 460 current_axes = ( 461 v0_5.ChannelAxis( 462 channel_names=[ 463 v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) 464 ] 465 ), 466 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), 467 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), 468 ) 469 elif ndim == 3: 470 current_axes = ( 471 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), 472 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), 473 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), 474 ) 475 elif ndim == 4: 476 current_axes = ( 477 v0_5.ChannelAxis( 478 channel_names=[ 479 v0_5.Identifier(f"channel{i}") for i 
in range(array.shape[0]) 480 ] 481 ), 482 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), 483 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), 484 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), 485 ) 486 elif ndim == 5: 487 current_axes = ( 488 v0_5.BatchAxis(), 489 v0_5.ChannelAxis( 490 channel_names=[ 491 v0_5.Identifier(f"channel{i}") for i in range(array.shape[1]) 492 ] 493 ), 494 v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), 495 v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), 496 v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), 497 ) 498 else: 499 raise ValueError(f"Could not guess an axis mapping for {array.shape}") 500 501 return cls(array, dims=tuple(a.id for a in current_axes)) 502 503 504def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]): 505 if len(arr.shape) > len(axis_infos): 506 # remove singletons 507 for i, s in enumerate(arr.shape): 508 if s == 1: 509 arr = np.take(arr, 0, axis=i) 510 if len(arr.shape) == len(axis_infos): 511 break 512 513 # add singletons if nececsary 514 for i, a in enumerate(axis_infos): 515 if len(arr.shape) >= len(axis_infos): 516 break 517 518 if a.maybe_singleton: 519 arr = np.expand_dims(arr, i) 520 521 return arr 522 523 524def _get_array_view( 525 original_array: NDArray[Any], axis_infos: Sequence[AxisInfo] 526) -> Optional[NDArray[Any]]: 527 perms = list(permutations(range(len(original_array.shape)))) 528 perms.insert(1, perms.pop()) # try A and A.T first 529 530 for perm in perms: 531 view = original_array.transpose(perm) 532 view = _add_singletons(view, axis_infos) 533 if len(view.shape) != len(axis_infos): 534 return None 535 536 for s, a in zip(view.shape, axis_infos): 537 if s == 1 and not a.maybe_singleton: 538 break 539 else: 540 return view 541 542 return None
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    # operand types accepted by the (in-place) binary operators
    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        """Wrap `array` in an `xr.DataArray` whose dimensions are named by `dims`.

        Entries of `dims` that are not already `AxisId`s are converted via
        `AxisInfo.create(...).id`.
        """
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
    ) -> Self:
        """Index the tensor; `SliceInfo` keys are converted to `slice`s first."""
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        return self.__class__.from_xarray(self._data[key])

    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
        """Assign `value` to the region selected by the per-axis `key`."""
        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
        self._data[key] = value._data

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        """Iterate over the first axis."""
        # separate generator so the 0-d TypeError is raised eagerly, on `iter()`
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Apply binary operator `f` via xarray and wrap the result in a new tensor."""
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply binary operator `f` in place on this tensor's data and return `self`.

        `other` may be a `Tensor` (its wrapped `xr.DataArray` is used) or anything
        xarray accepts directly (DataArray, numpy array, scalar).
        """
        # Unwrap Tensors the same way `_binary_op` does. The former
        # `getattr(other, "data")` had no default and raised AttributeError for
        # plain Python scalars (e.g. `tensor += 1`).
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            other._data if isinstance(other, Tensor) else other,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Apply unary operator `f` via xarray and wrap the result in a new tensor."""
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Self:
        """create a `Tensor` (or calling subclass instance) from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails, or if the
                array's shape cannot be mapped to `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            # a single axis description was given
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        # may transpose `array` and drop/insert singleton axes to match `dims`
        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        # `cls` (not `Tensor`) so subclasses get their own type, like `from_xarray`
        return cls(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        """The wrapped `xr.DataArray`."""
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        """Data type of the tensor as one of the `DTypeStr` literals."""
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's axis lengths.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        """Return the position of the maximum value as an index per axis id."""
        # `argmax(...)` (Ellipsis) makes xarray return a dict over all dimensions
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`

        Axes not in `sizes` or already of the requested length are kept; requests
        to grow an axis only log a warning.

        Raises:
            ValueError: if an axis needs cropping but `crop_where` has no entry for it.
        """
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        """Insert new axes; a mapping gives the length of each new axis."""
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Mean over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        """Pad each axis by its `pad_width` entry using `mode`."""
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`

        Axes not in `sizes` or already of the requested length are kept; requests
        to shrink an axis only log a warning.

        Raises:
            ValueError: if an axis needs padding but `pad_where` has no entry for it.
        """
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        """Compute the quantile(s) `q` (each within [0, 1]) over `dim`."""
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        # NOTE(review): presumably guards against a clash with the "quantile"
        # dimension xarray adds for sequence-valued `q` — confirm
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`

        Axes in `sizes` that this tensor lacks are appended via `expand_dims`.
        """
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Standard deviation over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes; axes missing from this tensor are
                inserted as new axes before transposing.
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Variance over `dim` (all axes if None)."""
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        """Guess axes from `array.ndim` (and, for 3d input, from its shape).

        2d: y, x; 3d: channel, y, x if any axis has length <= 3, else z, y, x;
        4d: channel, z, y, x; 5d: batch, channel, z, y, x.

        Raises:
            ValueError: for any other number of dimensions.
        """
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            # NOTE(review): every axis is checked for a small length, but the
            # channel axis is then always assumed to be the first one
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.
@classmethod
def from_xarray(cls, data_array: xr.DataArray) -> Self:
    """Build a `Tensor` from an `xarray.DataArray`.

    Each dimension name of `data_array` is wrapped in an `AxisId`, and the
    array's `data` is handed to the constructor. Internal note: passing a
    `Tensor`'s `data` property back through this factory is round-trip safe.
    """
    axis_ids = tuple(map(AxisId, data_array.dims))
    return cls(array=data_array.data, dims=axis_ids)
@classmethod
def from_numpy(
    cls,
    array: NDArray[Any],
    *,
    dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
) -> Self:
    """Create a `Tensor` (or an instance of the calling subclass) from a numpy array.

    Args:
        array: the nd numpy array
        dims: A description of the array's axes,
            if None axes are guessed (which might fail and raise a ValueError.)

    Raises:
        ValueError: if `dims` is None and axes guessing fails, or if the
            array's shape cannot be mapped to `dims`.
    """
    if dims is None:
        return cls._interprete_array_wo_known_axes(array)
    elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
        # a single axis description was given
        dims = [dims]

    axis_infos = [AxisInfo.create(a) for a in dims]
    original_shape = tuple(array.shape)

    # may transpose `array` and drop/insert singleton axes to match `dims`
    successful_view = _get_array_view(array, axis_infos)
    if successful_view is None:
        raise ValueError(
            f"Array shape {original_shape} does not map to axes {dims}"
        )

    # use `cls` (not `Tensor`) so subclasses get instances of their own type,
    # consistent with `from_xarray` and `_interprete_array_wo_known_axes`
    return cls(successful_view, dims=tuple(a.id for a in axis_infos))
Create a `Tensor` from a numpy array.
Arguments:
- array: the nd numpy array
- dims: A description of the array's axes; if None, the axes are guessed (which might fail and raise a ValueError).
Raises:
- ValueError: if `dims` is None and axes guessing fails.
@property
def dims(self):  # TODO: rename to `axes`?
    """Axis ids of this tensor, in order, as a tuple."""
    axis_ids = self._data.dims
    return cast(Tuple[AxisId, ...], axis_ids)
Tuple of dimension names associated with this tensor.
@property
def shape(self):
    """Lengths of the tensor's axes, as a tuple in axis order."""
    underlying = self._data
    return underlying.shape
Tuple of tensor axes lengths
@property
def shape_tuple(self):
    """Lengths of the tensor's axes, as a tuple in axis order (same as `shape`)."""
    underlying = self._data
    return underlying.shape
Tuple of tensor axes lengths
@property
def size(self):
    """Total number of elements in the tensor.

    This is `math.prod(tensor.shape)`, i.e. the product of all axis lengths.
    """
    underlying = self._data
    return underlying.size
Number of elements in the tensor.
Equal to `math.prod(tensor.shape)`, i.e., the product of the tensor's axis lengths.
@property
def sizes(self):
    """Read-only, ordered mapping of each axis id to that axis' length."""
    axis_lengths = self.data.sizes
    return cast(Mapping[AxisId, int], axis_lengths)
Ordered, immutable mapping from axis ids to axis lengths.
@property
def tagged_shape(self):
    """Same as `sizes`: ordered, read-only mapping from axis id to axis length."""
    axis_lengths = self.sizes
    return axis_lengths
(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.
def astype(self, dtype: DTypeStr, *, copy: bool = False):
    """Return this tensor converted to `dtype`.

    If the data already has `dtype`, it is only copied when `copy` is True.
    """
    converted = self._data.astype(dtype, copy=copy)
    return self.__class__.from_xarray(converted)
Return tensor cast to `dtype`.
Note: if the data already has `dtype`, it is only copied if `copy` is True.
def clip(self, min: Optional[float] = None, max: Optional[float] = None):
    """Return a tensor whose values are limited to the interval [min, max].

    At least one of `min` or `max` must be given.
    """
    clipped = self._data.clip(min, max)
    return self.__class__.from_xarray(clipped)
Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.
def crop_to(
    self,
    sizes: PerAxis[int],
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
) -> Self:
    """Crop the tensor so that its axes match `sizes`.

    Axes not listed in `sizes`, or already of the requested length, are left
    untouched. An axis that is shorter than requested is not cropped; a
    warning is logged instead.

    Args:
        sizes: target length per axis id.
        crop_where: where to remove elements ("left", "right" or
            "left_and_right"), either for all axes or per axis id.

    Raises:
        ValueError: if an axis must be cropped but a per-axis `crop_where`
            mapping has no entry for it.
    """
    crop_axis_where: PerAxis[CropWhere] = (
        {a: crop_where for a in self.dims}
        if isinstance(crop_where, str)
        else crop_where
    )

    slices: Dict[AxisId, SliceInfo] = {}
    for axis, current in self.sizes.items():
        if axis not in sizes:
            continue
        target = sizes[axis]
        if target == current:
            continue
        if target > current:
            logger.warning(
                "Cannot crop axis {} of size {} to larger size {}",
                axis,
                current,
                target,
            )
            continue
        if axis not in crop_axis_where:
            raise ValueError(
                f"Don't know where to crop axis {axis}, `crop_where`={crop_where}"
            )

        where = crop_axis_where[axis]
        excess = current - target
        if where == "left":
            # drop the first `excess` elements
            slices[axis] = SliceInfo(excess, current)
        elif where == "right":
            # keep the first `target` elements
            slices[axis] = SliceInfo(0, target)
        elif where == "left_and_right":
            # split the excess; for an odd excess the right side loses one more
            start = excess // 2
            slices[axis] = SliceInfo(start, target + start)
        else:
            assert_never(where)

    return self[slices]
crop to match sizes
def item(
    self,
    key: Union[
        None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
    ] = None,
):
    """Copy a single tensor element to a standard Python scalar and return it.

    Args:
        key: optional index applied before extraction; without it the whole
            tensor is used. The selected data must contain exactly one element.
    """
    source = self._data if key is None else self[key]._data
    value = source.item()
    assert isinstance(value, (bool, float, int))
    return value
Copy a tensor element to a standard Python scalar and return it.
def pad_to(
    self,
    sizes: PerAxis[int],
    pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
    mode: PadMode = "symmetric",
) -> Self:
    """Pad the tensor so that its axes match `sizes`.

    Axes not listed in `sizes`, or already of the requested length, get zero
    padding. An axis that is longer than requested is not padded; a warning
    is logged instead.

    Args:
        sizes: target length per axis id.
        pad_where: where to add elements ("left", "right" or
            "left_and_right"), either for all axes or per axis id.
        mode: padding mode, forwarded to `pad`.

    Raises:
        ValueError: if an axis must be padded but a per-axis `pad_where`
            mapping has no entry for it.
    """
    pad_axis_where: PerAxis[PadWhere] = (
        {a: pad_where for a in self.dims} if isinstance(pad_where, str) else pad_where
    )

    pad_width: Dict[AxisId, PadWidth] = {}
    for axis, current in self.sizes.items():
        target = sizes[axis] if axis in sizes else current
        if target == current:
            pad_width[axis] = PadWidth(0, 0)
            continue
        if target < current:
            pad_width[axis] = PadWidth(0, 0)
            logger.warning(
                "Cannot pad axis {} of size {} to smaller size {}",
                axis,
                current,
                target,
            )
            continue
        if axis not in pad_axis_where:
            raise ValueError(
                f"Don't know where to pad axis {axis}, `pad_where`={pad_where}"
            )

        missing = target - current
        where = pad_axis_where[axis]
        if where == "left":
            pad_width[axis] = PadWidth(missing, 0)
        elif where == "right":
            pad_width[axis] = PadWidth(0, missing)
        elif where == "left_and_right":
            # for an odd amount the right side receives one more element
            left = missing // 2
            pad_width[axis] = PadWidth(left, missing - left)
        else:
            assert_never(where)

    return self.pad(pad_width, mode)
Pad the tensor to match `sizes`.
def quantile(
    self,
    q: Union[float, Sequence[float]],
    dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
) -> Self:
    """Compute the `q`-th quantile(s) of the tensor data along `dim`.

    Args:
        q: quantile(s) to compute; every value must lie in [0, 1].
        dim: axis id(s) to reduce over; None reduces over all axes.

    Returns:
        A tensor holding the result of `xarray.DataArray.quantile` (which
        adds a "quantile" axis when `q` is a sequence).

    Raises:
        ValueError: if any value of `q` lies outside [0, 1] (or is NaN), or
            if `dim` includes the axis id "quantile", which is reserved for
            the result.
    """
    # np.asarray treats python scalars, numpy scalars (e.g. np.float32, which
    # is not a `float` subclass) and sequences uniformly; explicit raises
    # (instead of `assert`) keep the check active under `python -O`.
    q_values = np.asarray(q, dtype=float)
    # phrased as "all within" so NaN values are rejected as well
    if not np.all((q_values >= 0.0) & (q_values <= 1.0)):
        raise ValueError(f"Quantiles must be in the range [0, 1], got {q}")

    if dim is not None:
        quantile_dim = AxisId("quantile")
        # treat a plain string as a single axis id instead of iterating
        # over its characters
        requested = [dim] if isinstance(dim, str) else list(dim)
        if quantile_dim in requested:
            raise ValueError(
                f"Cannot compute quantiles along the reserved axis '{quantile_dim}'"
            )

    return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to(
    self,
    sizes: PerAxis[int],
    *,
    pad_where: Union[
        PadWhere,
        PerAxis[PadWhere],
    ] = "left_and_right",
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
    pad_mode: PadMode = "symmetric",
):
    """Return the tensor cropped and/or padded to `sizes`.

    Axes longer than requested are cropped with `crop_to`; axes shorter than
    requested are padded with `pad_to`. Axes listed in `sizes` but absent from
    this tensor are added with `expand_dims`.
    """
    to_crop: Dict[AxisId, int] = {}
    to_pad: Dict[AxisId, int] = {}
    missing_axes = dict(sizes)
    for raw_axis, current in self.sizes.items():
        axis = AxisId(str(raw_axis))
        # whatever remains in `missing_axes` after the loop is new to this tensor
        missing_axes.pop(axis, None)
        if axis not in sizes:
            continue
        target = sizes[axis]
        if target < current:
            to_crop[axis] = target
        elif target > current:
            to_pad[axis] = target

    result = self
    if to_crop:
        result = result.crop_to(to_crop, crop_where=crop_where)
    if to_pad:
        result = result.pad_to(to_pad, pad_where=pad_where, mode=pad_mode)
    if missing_axes:
        result = result.expand_dims(missing_axes)
    return result
return cropped/padded tensor with sizes
def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
    """Sum the tensor's values along the given axis id(s).

    Args:
        dim: axis id(s) to reduce; None sums over all axes.
    """
    summed = self._data.sum(dim=dim)
    return self.__class__.from_xarray(summed)
Reduce this Tensor's data by applying sum along some dimension(s).
def transpose(
    self,
    axes: Sequence[AxisId],
) -> Self:
    """Return a tensor with its axes reordered to `axes`.

    Entries of `axes` that this tensor does not have are first added as new
    axes (via `xarray.DataArray.expand_dims`, i.e. of length 1); the data is
    then transposed into the requested order.

    Args:
        axes: the desired tensor axes, in order
    """
    data = self._data
    existing = self.dims
    new_axes = tuple(axis for axis in axes if axis not in existing)
    if new_axes:
        data = data.expand_dims(new_axes)
    return self.__class__.from_xarray(data.transpose(*axes))
return a transposed tensor
Arguments:
- axes: the desired tensor axes