bioimageio.core.tensor

  1from __future__ import annotations
  2
  3import collections.abc
  4from itertools import permutations
  5from typing import (
  6    TYPE_CHECKING,
  7    Any,
  8    Callable,
  9    Dict,
 10    Iterator,
 11    Mapping,
 12    Optional,
 13    Sequence,
 14    Tuple,
 15    Union,
 16    cast,
 17    get_args,
 18)
 19
 20import numpy as np
 21import xarray as xr
 22from loguru import logger
 23from numpy.typing import DTypeLike, NDArray
 24from typing_extensions import Self, assert_never
 25
 26from bioimageio.spec.model import v0_5
 27
 28from ._magic_tensor_ops import MagicTensorOpsMixin
 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 30from .common import (
 31    CropWhere,
 32    DTypeStr,
 33    PadMode,
 34    PadWhere,
 35    PadWidth,
 36    PadWidthLike,
 37    SliceInfo,
 38)
 39
 40if TYPE_CHECKING:
 41    from numpy.typing import ArrayLike, NDArray
 42
 43
 44_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"
 45
 46
 47# TODO: complete docstrings
 48class Tensor(MagicTensorOpsMixin):
 49    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 50    and improved type annotations."""
 51
 52    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 53
 54    def __init__(
 55        self,
 56        array: NDArray[Any],
 57        dims: Sequence[Union[AxisId, AxisLike]],
 58    ) -> None:
 59        super().__init__()
 60        axes = tuple(
 61            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 62        )
 63        self._data = xr.DataArray(array, dims=axes)
 64
 65    def __array__(self, dtype: DTypeLike = None):
 66        return np.asarray(self._data, dtype=dtype)
 67
 68    def __getitem__(
 69        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
 70    ) -> Self:
 71        if isinstance(key, SliceInfo):
 72            key = slice(*key)
 73        elif isinstance(key, collections.abc.Mapping):
 74            key = {
 75                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 76                for a, s in key.items()
 77            }
 78        return self.__class__.from_xarray(self._data[key])
 79
 80    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
 81        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
 82        self._data[key] = value._data
 83
 84    def __len__(self) -> int:
 85        return len(self.data)
 86
 87    def _iter(self: Any) -> Iterator[Any]:
 88        for n in range(len(self)):
 89            yield self[n]
 90
 91    def __iter__(self: Any) -> Iterator[Any]:
 92        if self.ndim == 0:
 93            raise TypeError("iteration over a 0-d array")
 94        return self._iter()
 95
 96    def _binary_op(
 97        self,
 98        other: _Compatible,
 99        f: Callable[[Any, Any], Any],
100        reflexive: bool = False,
101    ) -> Self:
102        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
103            (other._data if isinstance(other, Tensor) else other),
104            f,
105            reflexive,
106        )
107        return self.__class__.from_xarray(data)
108
109    def _inplace_binary_op(
110        self,
111        other: _Compatible,
112        f: Callable[[Any, Any], Any],
113    ) -> Self:
114        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
115            (
116                other_d
117                if (other_d := getattr(other, "data")) is not None
118                and isinstance(
119                    other_d,
120                    xr.DataArray,
121                )
122                else other
123            ),
124            f,
125        )
126        return self
127
128    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
129        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
130            f, *args, **kwargs
131        )
132        return self.__class__.from_xarray(data)
133
134    @classmethod
135    def from_xarray(cls, data_array: xr.DataArray) -> Self:
136        """create a `Tensor` from an xarray data array
137
138        note for internal use: this factory method is round-trip save
139            for any `Tensor`'s  `data` property (an xarray.DataArray).
140        """
141        return cls(
142            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
143        )
144
145    @classmethod
146    def from_numpy(
147        cls,
148        array: NDArray[Any],
149        *,
150        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
151    ) -> Tensor:
152        """create a `Tensor` from a numpy array
153
154        Args:
155            array: the nd numpy array
156            axes: A description of the array's axes,
157                if None axes are guessed (which might fail and raise a ValueError.)
158
159        Raises:
160            ValueError: if `axes` is None and axes guessing fails.
161        """
162
163        if dims is None:
164            return cls._interprete_array_wo_known_axes(array)
165        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
166            dims = [dims]
167
168        axis_infos = [AxisInfo.create(a) for a in dims]
169        original_shape = tuple(array.shape)
170
171        successful_view = _get_array_view(array, axis_infos)
172        if successful_view is None:
173            raise ValueError(
174                f"Array shape {original_shape} does not map to axes {dims}"
175            )
176
177        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
178
179    @property
180    def data(self):
181        return self._data
182
183    @property
184    def dims(self):  # TODO: rename to `axes`?
185        """Tuple of dimension names associated with this tensor."""
186        return cast(Tuple[AxisId, ...], self._data.dims)
187
188    @property
189    def dtype(self) -> DTypeStr:
190        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
191        assert dt in get_args(DTypeStr)
192        return dt  # pyright: ignore[reportReturnType]
193
194    @property
195    def ndim(self):
196        """Number of tensor dimensions."""
197        return self._data.ndim
198
199    @property
200    def shape(self):
201        """Tuple of tensor axes lengths"""
202        return self._data.shape
203
204    @property
205    def shape_tuple(self):
206        """Tuple of tensor axes lengths"""
207        return self._data.shape
208
209    @property
210    def size(self):
211        """Number of elements in the tensor.
212
213        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
214        """
215        return self._data.size
216
217    @property
218    def sizes(self):
219        """Ordered, immutable mapping from axis ids to axis lengths."""
220        return cast(Mapping[AxisId, int], self.data.sizes)
221
222    @property
223    def tagged_shape(self):
224        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
225        return self.sizes
226
227    def argmax(self) -> Mapping[AxisId, int]:
228        ret = self._data.argmax(...)
229        assert isinstance(ret, dict)
230        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
231
232    def astype(self, dtype: DTypeStr, *, copy: bool = False):
233        """Return tensor cast to `dtype`
234
235        note: if dtype is already satisfied copy if `copy`"""
236        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
237
238    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
239        """Return a tensor whose values are limited to [min, max].
240        At least one of max or min must be given."""
241        return self.__class__.from_xarray(self._data.clip(min, max))
242
243    def crop_to(
244        self,
245        sizes: PerAxis[int],
246        crop_where: Union[
247            CropWhere,
248            PerAxis[CropWhere],
249        ] = "left_and_right",
250    ) -> Self:
251        """crop to match `sizes`"""
252        if isinstance(crop_where, str):
253            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
254        else:
255            crop_axis_where = crop_where
256
257        slices: Dict[AxisId, SliceInfo] = {}
258
259        for a, s_is in self.sizes.items():
260            if a not in sizes or sizes[a] == s_is:
261                pass
262            elif sizes[a] > s_is:
263                logger.warning(
264                    "Cannot crop axis {} of size {} to larger size {}",
265                    a,
266                    s_is,
267                    sizes[a],
268                )
269            elif a not in crop_axis_where:
270                raise ValueError(
271                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
272                )
273            else:
274                crop_this_axis_where = crop_axis_where[a]
275                if crop_this_axis_where == "left":
276                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
277                elif crop_this_axis_where == "right":
278                    slices[a] = SliceInfo(0, sizes[a])
279                elif crop_this_axis_where == "left_and_right":
280                    slices[a] = SliceInfo(
281                        start := (s_is - sizes[a]) // 2, sizes[a] + start
282                    )
283                else:
284                    assert_never(crop_this_axis_where)
285
286        return self[slices]
287
288    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
289        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
290
291    def item(
292        self,
293        key: Union[
294            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
295        ] = None,
296    ):
297        """Copy a tensor element to a standard Python scalar and return it."""
298        if key is None:
299            ret = self._data.item()
300        else:
301            ret = self[key]._data.item()
302
303        assert isinstance(ret, (bool, float, int))
304        return ret
305
306    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
307        return self.__class__.from_xarray(self._data.mean(dim=dim))
308
309    def pad(
310        self,
311        pad_width: PerAxis[PadWidthLike],
312        mode: PadMode = "symmetric",
313    ) -> Self:
314        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
315        return self.__class__.from_xarray(
316            self._data.pad(pad_width=pad_width, mode=mode)
317        )
318
319    def pad_to(
320        self,
321        sizes: PerAxis[int],
322        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
323        mode: PadMode = "symmetric",
324    ) -> Self:
325        """pad `tensor` to match `sizes`"""
326        if isinstance(pad_where, str):
327            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
328        else:
329            pad_axis_where = pad_where
330
331        pad_width: Dict[AxisId, PadWidth] = {}
332        for a, s_is in self.sizes.items():
333            if a not in sizes or sizes[a] == s_is:
334                pad_width[a] = PadWidth(0, 0)
335            elif s_is > sizes[a]:
336                pad_width[a] = PadWidth(0, 0)
337                logger.warning(
338                    "Cannot pad axis {} of size {} to smaller size {}",
339                    a,
340                    s_is,
341                    sizes[a],
342                )
343            elif a not in pad_axis_where:
344                raise ValueError(
345                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
346                )
347            else:
348                pad_this_axis_where = pad_axis_where[a]
349                d = sizes[a] - s_is
350                if pad_this_axis_where == "left":
351                    pad_width[a] = PadWidth(d, 0)
352                elif pad_this_axis_where == "right":
353                    pad_width[a] = PadWidth(0, d)
354                elif pad_this_axis_where == "left_and_right":
355                    pad_width[a] = PadWidth(left := d // 2, d - left)
356                else:
357                    assert_never(pad_this_axis_where)
358
359        return self.pad(pad_width, mode)
360
361    def quantile(
362        self,
363        q: Union[float, Sequence[float]],
364        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
365    ) -> Self:
366        assert (
367            isinstance(q, (float, int))
368            and q >= 0.0
369            or not isinstance(q, (float, int))
370            and all(qq >= 0.0 for qq in q)
371        )
372        assert (
373            isinstance(q, (float, int))
374            and q <= 1.0
375            or not isinstance(q, (float, int))
376            and all(qq <= 1.0 for qq in q)
377        )
378        assert dim is None or (
379            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
380        )
381        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
382
383    def resize_to(
384        self,
385        sizes: PerAxis[int],
386        *,
387        pad_where: Union[
388            PadWhere,
389            PerAxis[PadWhere],
390        ] = "left_and_right",
391        crop_where: Union[
392            CropWhere,
393            PerAxis[CropWhere],
394        ] = "left_and_right",
395        pad_mode: PadMode = "symmetric",
396    ):
397        """return cropped/padded tensor with `sizes`"""
398        crop_to_sizes: Dict[AxisId, int] = {}
399        pad_to_sizes: Dict[AxisId, int] = {}
400        new_axes = dict(sizes)
401        for a, s_is in self.sizes.items():
402            a = AxisId(str(a))
403            _ = new_axes.pop(a, None)
404            if a not in sizes or sizes[a] == s_is:
405                pass
406            elif s_is > sizes[a]:
407                crop_to_sizes[a] = sizes[a]
408            else:
409                pad_to_sizes[a] = sizes[a]
410
411        tensor = self
412        if crop_to_sizes:
413            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
414
415        if pad_to_sizes:
416            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
417
418        if new_axes:
419            tensor = tensor.expand_dims(new_axes)
420
421        return tensor
422
423    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
424        return self.__class__.from_xarray(self._data.std(dim=dim))
425
426    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
427        """Reduce this Tensor's data by applying sum along some dimension(s)."""
428        return self.__class__.from_xarray(self._data.sum(dim=dim))
429
430    def transpose(
431        self,
432        axes: Sequence[AxisId],
433    ) -> Self:
434        """return a transposed tensor
435
436        Args:
437            axes: the desired tensor axes
438        """
439        # expand missing tensor axes
440        missing_axes = tuple(a for a in axes if a not in self.dims)
441        array = self._data
442        if missing_axes:
443            array = array.expand_dims(missing_axes)
444
445        # transpose to the correct axis order
446        return self.__class__.from_xarray(array.transpose(*axes))
447
448    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
449        return self.__class__.from_xarray(self._data.var(dim=dim))
450
451    @classmethod
452    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
453        ndim = array.ndim
454        if ndim == 2:
455            current_axes = (
456                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
457                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
458            )
459        elif ndim == 3 and any(s <= 3 for s in array.shape):
460            current_axes = (
461                v0_5.ChannelAxis(
462                    channel_names=[
463                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
464                    ]
465                ),
466                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
467                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
468            )
469        elif ndim == 3:
470            current_axes = (
471                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
472                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
473                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
474            )
475        elif ndim == 4:
476            current_axes = (
477                v0_5.ChannelAxis(
478                    channel_names=[
479                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
480                    ]
481                ),
482                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
483                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
484                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
485            )
486        elif ndim == 5:
487            current_axes = (
488                v0_5.BatchAxis(),
489                v0_5.ChannelAxis(
490                    channel_names=[
491                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
492                    ]
493                ),
494                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
495                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
496                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
497            )
498        else:
499            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
500
501        return cls(array, dims=tuple(a.id for a in current_axes))
502
503
def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Squeeze or expand singleton dimensions of `arr` towards `len(axis_infos)` dims.

    If `arr` has more dimensions than `axis_infos`, size-1 dimensions are removed
    from left to right until the dimension count matches (or no size-1
    dimensions are left). If it has fewer, a singleton dimension is inserted at
    position `i` for each `axis_infos[i]` whose `maybe_singleton` is set, until
    the count matches. The result may still differ in dimension count from
    `axis_infos`; callers need to check.

    Args:
        arr: the array to adjust.
        axis_infos: the target axes; only their count and `maybe_singleton` are used.

    Returns:
        the adjusted array
    """
    n_target = len(axis_infos)
    if len(arr.shape) > n_target:
        # remove singletons
        # `enumerate` runs over the *original* shape; every removed dimension
        # shifts the remaining ones one position to the left.
        n_removed = 0
        for i, s in enumerate(arr.shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i - n_removed)
                n_removed += 1
                if len(arr.shape) == n_target:
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= n_target:
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr
522
523
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find a transposed (and singleton-adjusted) version of `original_array`
    whose dimensions match `axis_infos`.

    Axis orders are tried as produced by `itertools.permutations`, except that
    the fully reversed order (the plain transpose) is tried right after the
    identity. A candidate is accepted if it has one dimension per axis info and
    every size-1 dimension belongs to an axis whose `maybe_singleton` is set.

    Returns:
        the first matching array, or None if the dimension count cannot be
        matched or no axis order fits.
    """
    candidate_orders = list(permutations(range(len(original_array.shape))))
    # try A and A.T first: move the reversed order to second place
    candidate_orders.insert(1, candidate_orders.pop())

    n_axes = len(axis_infos)
    for order in candidate_orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != n_axes:
            return None

        singletons_ok = all(
            size != 1 or info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if singletons_ok:
            return candidate

    return None
class Tensor(bioimageio.core._magic_tensor_ops.MagicTensorOpsMixin):
 49class Tensor(MagicTensorOpsMixin):
 50    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 51    and improved type annotations."""
 52
 53    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 54
 55    def __init__(
 56        self,
 57        array: NDArray[Any],
 58        dims: Sequence[Union[AxisId, AxisLike]],
 59    ) -> None:
 60        super().__init__()
 61        axes = tuple(
 62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 63        )
 64        self._data = xr.DataArray(array, dims=axes)
 65
 66    def __array__(self, dtype: DTypeLike = None):
 67        return np.asarray(self._data, dtype=dtype)
 68
 69    def __getitem__(
 70        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
 71    ) -> Self:
 72        if isinstance(key, SliceInfo):
 73            key = slice(*key)
 74        elif isinstance(key, collections.abc.Mapping):
 75            key = {
 76                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 77                for a, s in key.items()
 78            }
 79        return self.__class__.from_xarray(self._data[key])
 80
 81    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
 82        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
 83        self._data[key] = value._data
 84
 85    def __len__(self) -> int:
 86        return len(self.data)
 87
 88    def _iter(self: Any) -> Iterator[Any]:
 89        for n in range(len(self)):
 90            yield self[n]
 91
 92    def __iter__(self: Any) -> Iterator[Any]:
 93        if self.ndim == 0:
 94            raise TypeError("iteration over a 0-d array")
 95        return self._iter()
 96
 97    def _binary_op(
 98        self,
 99        other: _Compatible,
100        f: Callable[[Any, Any], Any],
101        reflexive: bool = False,
102    ) -> Self:
103        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
104            (other._data if isinstance(other, Tensor) else other),
105            f,
106            reflexive,
107        )
108        return self.__class__.from_xarray(data)
109
110    def _inplace_binary_op(
111        self,
112        other: _Compatible,
113        f: Callable[[Any, Any], Any],
114    ) -> Self:
115        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
116            (
117                other_d
118                if (other_d := getattr(other, "data")) is not None
119                and isinstance(
120                    other_d,
121                    xr.DataArray,
122                )
123                else other
124            ),
125            f,
126        )
127        return self
128
129    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
130        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
131            f, *args, **kwargs
132        )
133        return self.__class__.from_xarray(data)
134
135    @classmethod
136    def from_xarray(cls, data_array: xr.DataArray) -> Self:
137        """create a `Tensor` from an xarray data array
138
139        note for internal use: this factory method is round-trip save
140            for any `Tensor`'s  `data` property (an xarray.DataArray).
141        """
142        return cls(
143            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
144        )
145
146    @classmethod
147    def from_numpy(
148        cls,
149        array: NDArray[Any],
150        *,
151        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
152    ) -> Tensor:
153        """create a `Tensor` from a numpy array
154
155        Args:
156            array: the nd numpy array
157            axes: A description of the array's axes,
158                if None axes are guessed (which might fail and raise a ValueError.)
159
160        Raises:
161            ValueError: if `axes` is None and axes guessing fails.
162        """
163
164        if dims is None:
165            return cls._interprete_array_wo_known_axes(array)
166        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
167            dims = [dims]
168
169        axis_infos = [AxisInfo.create(a) for a in dims]
170        original_shape = tuple(array.shape)
171
172        successful_view = _get_array_view(array, axis_infos)
173        if successful_view is None:
174            raise ValueError(
175                f"Array shape {original_shape} does not map to axes {dims}"
176            )
177
178        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
179
180    @property
181    def data(self):
182        return self._data
183
184    @property
185    def dims(self):  # TODO: rename to `axes`?
186        """Tuple of dimension names associated with this tensor."""
187        return cast(Tuple[AxisId, ...], self._data.dims)
188
189    @property
190    def dtype(self) -> DTypeStr:
191        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
192        assert dt in get_args(DTypeStr)
193        return dt  # pyright: ignore[reportReturnType]
194
195    @property
196    def ndim(self):
197        """Number of tensor dimensions."""
198        return self._data.ndim
199
200    @property
201    def shape(self):
202        """Tuple of tensor axes lengths"""
203        return self._data.shape
204
205    @property
206    def shape_tuple(self):
207        """Tuple of tensor axes lengths"""
208        return self._data.shape
209
210    @property
211    def size(self):
212        """Number of elements in the tensor.
213
214        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
215        """
216        return self._data.size
217
218    @property
219    def sizes(self):
220        """Ordered, immutable mapping from axis ids to axis lengths."""
221        return cast(Mapping[AxisId, int], self.data.sizes)
222
223    @property
224    def tagged_shape(self):
225        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
226        return self.sizes
227
228    def argmax(self) -> Mapping[AxisId, int]:
229        ret = self._data.argmax(...)
230        assert isinstance(ret, dict)
231        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
232
233    def astype(self, dtype: DTypeStr, *, copy: bool = False):
234        """Return tensor cast to `dtype`
235
236        note: if dtype is already satisfied copy if `copy`"""
237        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
238
239    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
240        """Return a tensor whose values are limited to [min, max].
241        At least one of max or min must be given."""
242        return self.__class__.from_xarray(self._data.clip(min, max))
243
244    def crop_to(
245        self,
246        sizes: PerAxis[int],
247        crop_where: Union[
248            CropWhere,
249            PerAxis[CropWhere],
250        ] = "left_and_right",
251    ) -> Self:
252        """crop to match `sizes`"""
253        if isinstance(crop_where, str):
254            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
255        else:
256            crop_axis_where = crop_where
257
258        slices: Dict[AxisId, SliceInfo] = {}
259
260        for a, s_is in self.sizes.items():
261            if a not in sizes or sizes[a] == s_is:
262                pass
263            elif sizes[a] > s_is:
264                logger.warning(
265                    "Cannot crop axis {} of size {} to larger size {}",
266                    a,
267                    s_is,
268                    sizes[a],
269                )
270            elif a not in crop_axis_where:
271                raise ValueError(
272                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
273                )
274            else:
275                crop_this_axis_where = crop_axis_where[a]
276                if crop_this_axis_where == "left":
277                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
278                elif crop_this_axis_where == "right":
279                    slices[a] = SliceInfo(0, sizes[a])
280                elif crop_this_axis_where == "left_and_right":
281                    slices[a] = SliceInfo(
282                        start := (s_is - sizes[a]) // 2, sizes[a] + start
283                    )
284                else:
285                    assert_never(crop_this_axis_where)
286
287        return self[slices]
288
289    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
290        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
291
292    def item(
293        self,
294        key: Union[
295            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
296        ] = None,
297    ):
298        """Copy a tensor element to a standard Python scalar and return it."""
299        if key is None:
300            ret = self._data.item()
301        else:
302            ret = self[key]._data.item()
303
304        assert isinstance(ret, (bool, float, int))
305        return ret
306
307    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
308        return self.__class__.from_xarray(self._data.mean(dim=dim))
309
310    def pad(
311        self,
312        pad_width: PerAxis[PadWidthLike],
313        mode: PadMode = "symmetric",
314    ) -> Self:
315        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
316        return self.__class__.from_xarray(
317            self._data.pad(pad_width=pad_width, mode=mode)
318        )
319
320    def pad_to(
321        self,
322        sizes: PerAxis[int],
323        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
324        mode: PadMode = "symmetric",
325    ) -> Self:
326        """pad `tensor` to match `sizes`"""
327        if isinstance(pad_where, str):
328            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
329        else:
330            pad_axis_where = pad_where
331
332        pad_width: Dict[AxisId, PadWidth] = {}
333        for a, s_is in self.sizes.items():
334            if a not in sizes or sizes[a] == s_is:
335                pad_width[a] = PadWidth(0, 0)
336            elif s_is > sizes[a]:
337                pad_width[a] = PadWidth(0, 0)
338                logger.warning(
339                    "Cannot pad axis {} of size {} to smaller size {}",
340                    a,
341                    s_is,
342                    sizes[a],
343                )
344            elif a not in pad_axis_where:
345                raise ValueError(
346                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
347                )
348            else:
349                pad_this_axis_where = pad_axis_where[a]
350                d = sizes[a] - s_is
351                if pad_this_axis_where == "left":
352                    pad_width[a] = PadWidth(d, 0)
353                elif pad_this_axis_where == "right":
354                    pad_width[a] = PadWidth(0, d)
355                elif pad_this_axis_where == "left_and_right":
356                    pad_width[a] = PadWidth(left := d // 2, d - left)
357                else:
358                    assert_never(pad_this_axis_where)
359
360        return self.pad(pad_width, mode)
361
362    def quantile(
363        self,
364        q: Union[float, Sequence[float]],
365        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
366    ) -> Self:
367        assert (
368            isinstance(q, (float, int))
369            and q >= 0.0
370            or not isinstance(q, (float, int))
371            and all(qq >= 0.0 for qq in q)
372        )
373        assert (
374            isinstance(q, (float, int))
375            and q <= 1.0
376            or not isinstance(q, (float, int))
377            and all(qq <= 1.0 for qq in q)
378        )
379        assert dim is None or (
380            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
381        )
382        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
383
384    def resize_to(
385        self,
386        sizes: PerAxis[int],
387        *,
388        pad_where: Union[
389            PadWhere,
390            PerAxis[PadWhere],
391        ] = "left_and_right",
392        crop_where: Union[
393            CropWhere,
394            PerAxis[CropWhere],
395        ] = "left_and_right",
396        pad_mode: PadMode = "symmetric",
397    ):
398        """return cropped/padded tensor with `sizes`"""
399        crop_to_sizes: Dict[AxisId, int] = {}
400        pad_to_sizes: Dict[AxisId, int] = {}
401        new_axes = dict(sizes)
402        for a, s_is in self.sizes.items():
403            a = AxisId(str(a))
404            _ = new_axes.pop(a, None)
405            if a not in sizes or sizes[a] == s_is:
406                pass
407            elif s_is > sizes[a]:
408                crop_to_sizes[a] = sizes[a]
409            else:
410                pad_to_sizes[a] = sizes[a]
411
412        tensor = self
413        if crop_to_sizes:
414            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
415
416        if pad_to_sizes:
417            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
418
419        if new_axes:
420            tensor = tensor.expand_dims(new_axes)
421
422        return tensor
423
424    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
425        return self.__class__.from_xarray(self._data.std(dim=dim))
426
427    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
428        """Reduce this Tensor's data by applying sum along some dimension(s)."""
429        return self.__class__.from_xarray(self._data.sum(dim=dim))
430
431    def transpose(
432        self,
433        axes: Sequence[AxisId],
434    ) -> Self:
435        """return a transposed tensor
436
437        Args:
438            axes: the desired tensor axes
439        """
440        # expand missing tensor axes
441        missing_axes = tuple(a for a in axes if a not in self.dims)
442        array = self._data
443        if missing_axes:
444            array = array.expand_dims(missing_axes)
445
446        # transpose to the correct axis order
447        return self.__class__.from_xarray(array.transpose(*axes))
448
449    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
450        return self.__class__.from_xarray(self._data.var(dim=dim))
451
452    @classmethod
453    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
454        ndim = array.ndim
455        if ndim == 2:
456            current_axes = (
457                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
458                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
459            )
460        elif ndim == 3 and any(s <= 3 for s in array.shape):
461            current_axes = (
462                v0_5.ChannelAxis(
463                    channel_names=[
464                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
465                    ]
466                ),
467                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
468                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
469            )
470        elif ndim == 3:
471            current_axes = (
472                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
473                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
474                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
475            )
476        elif ndim == 4:
477            current_axes = (
478                v0_5.ChannelAxis(
479                    channel_names=[
480                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
481                    ]
482                ),
483                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
484                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
485                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
486            )
487        elif ndim == 5:
488            current_axes = (
489                v0_5.BatchAxis(),
490                v0_5.ChannelAxis(
491                    channel_names=[
492                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
493                    ]
494                ),
495                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
496                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
497                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
498            )
499        else:
500            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
501
502        return cls(array, dims=tuple(a.id for a in current_axes))

A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.

Tensor( array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], dims: Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]])
55    def __init__(
56        self,
57        array: NDArray[Any],
58        dims: Sequence[Union[AxisId, AxisLike]],
59    ) -> None:
60        super().__init__()
61        axes = tuple(
62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
63        )
64        self._data = xr.DataArray(array, dims=axes)
@classmethod
def from_xarray(cls, data_array: xarray.core.dataarray.DataArray) -> Self:
135    @classmethod
136    def from_xarray(cls, data_array: xr.DataArray) -> Self:
137        """create a `Tensor` from an xarray data array
138
139        note for internal use: this factory method is round-trip save
140            for any `Tensor`'s  `data` property (an xarray.DataArray).
141        """
142        return cls(
143            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
144        )

create a Tensor from an xarray data array

Note for internal use: this factory method is round-trip safe for any Tensor's data property (an xarray.DataArray).

@classmethod
def from_numpy( cls, array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], *, dims: Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, 
bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]], NoneType]) -> Tensor:
146    @classmethod
147    def from_numpy(
148        cls,
149        array: NDArray[Any],
150        *,
151        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
152    ) -> Tensor:
153        """create a `Tensor` from a numpy array
154
155        Args:
156            array: the nd numpy array
157            axes: A description of the array's axes,
158                if None axes are guessed (which might fail and raise a ValueError.)
159
160        Raises:
161            ValueError: if `axes` is None and axes guessing fails.
162        """
163
164        if dims is None:
165            return cls._interprete_array_wo_known_axes(array)
166        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
167            dims = [dims]
168
169        axis_infos = [AxisInfo.create(a) for a in dims]
170        original_shape = tuple(array.shape)
171
172        successful_view = _get_array_view(array, axis_infos)
173        if successful_view is None:
174            raise ValueError(
175                f"Array shape {original_shape} does not map to axes {dims}"
176            )
177
178        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

create a Tensor from a numpy array

Arguments:
  • array: the nd numpy array
  • dims: A description of the array's axes; if None, axes are guessed (which might fail and raise a ValueError).
Raises:
  • ValueError: if dims is None and axes guessing fails, or if the array shape does not map to dims.
data
180    @property
181    def data(self):
182        return self._data
dims
184    @property
185    def dims(self):  # TODO: rename to `axes`?
186        """Tuple of dimension names associated with this tensor."""
187        return cast(Tuple[AxisId, ...], self._data.dims)

Tuple of dimension names associated with this tensor.

dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
189    @property
190    def dtype(self) -> DTypeStr:
191        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
192        assert dt in get_args(DTypeStr)
193        return dt  # pyright: ignore[reportReturnType]
ndim
195    @property
196    def ndim(self):
197        """Number of tensor dimensions."""
198        return self._data.ndim

Number of tensor dimensions.

shape
200    @property
201    def shape(self):
202        """Tuple of tensor axes lengths"""
203        return self._data.shape

Tuple of tensor axes lengths

shape_tuple
205    @property
206    def shape_tuple(self):
207        """Tuple of tensor axes lengths"""
208        return self._data.shape

Tuple of tensor axes lengths

size
210    @property
211    def size(self):
212        """Number of elements in the tensor.
213
214        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
215        """
216        return self._data.size

Number of elements in the tensor.

Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.

sizes
218    @property
219    def sizes(self):
220        """Ordered, immutable mapping from axis ids to axis lengths."""
221        return cast(Mapping[AxisId, int], self.data.sizes)

Ordered, immutable mapping from axis ids to axis lengths.

tagged_shape
223    @property
224    def tagged_shape(self):
225        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
226        return self.sizes

(alias for sizes) Ordered, immutable mapping from axis ids to lengths.

def argmax(self) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
228    def argmax(self) -> Mapping[AxisId, int]:
229        ret = self._data.argmax(...)
230        assert isinstance(ret, dict)
231        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
def astype( self, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'], *, copy: bool = False):
233    def astype(self, dtype: DTypeStr, *, copy: bool = False):
234        """Return tensor cast to `dtype`
235
236        note: if dtype is already satisfied copy if `copy`"""
237        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

Return tensor cast to dtype

Note: if the data already has the requested dtype, it is copied only if copy is True.

def clip(self, min: Optional[float] = None, max: Optional[float] = None):
239    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
240        """Return a tensor whose values are limited to [min, max].
241        At least one of max or min must be given."""
242        return self.__class__.from_xarray(self._data.clip(min, max))

Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.

def crop_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right') -> Self:
244    def crop_to(
245        self,
246        sizes: PerAxis[int],
247        crop_where: Union[
248            CropWhere,
249            PerAxis[CropWhere],
250        ] = "left_and_right",
251    ) -> Self:
252        """crop to match `sizes`"""
253        if isinstance(crop_where, str):
254            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
255        else:
256            crop_axis_where = crop_where
257
258        slices: Dict[AxisId, SliceInfo] = {}
259
260        for a, s_is in self.sizes.items():
261            if a not in sizes or sizes[a] == s_is:
262                pass
263            elif sizes[a] > s_is:
264                logger.warning(
265                    "Cannot crop axis {} of size {} to larger size {}",
266                    a,
267                    s_is,
268                    sizes[a],
269                )
270            elif a not in crop_axis_where:
271                raise ValueError(
272                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
273                )
274            else:
275                crop_this_axis_where = crop_axis_where[a]
276                if crop_this_axis_where == "left":
277                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
278                elif crop_this_axis_where == "right":
279                    slices[a] = SliceInfo(0, sizes[a])
280                elif crop_this_axis_where == "left_and_right":
281                    slices[a] = SliceInfo(
282                        start := (s_is - sizes[a]) // 2, sizes[a] + start
283                    )
284                else:
285                    assert_never(crop_this_axis_where)
286
287        return self[slices]

crop to match sizes

def expand_dims( self, dims: Union[Sequence[bioimageio.spec.model.v0_5.AxisId], Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Self:
289    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
290        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
def item( self, key: Union[NoneType, bioimageio.core.common.SliceInfo, slice, int, Mapping[bioimageio.spec.model.v0_5.AxisId, Union[bioimageio.core.common.SliceInfo, slice, int]]] = None):
292    def item(
293        self,
294        key: Union[
295            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
296        ] = None,
297    ):
298        """Copy a tensor element to a standard Python scalar and return it."""
299        if key is None:
300            ret = self._data.item()
301        else:
302            ret = self[key]._data.item()
303
304        assert isinstance(ret, (bool, float, int))
305        return ret

Copy a tensor element to a standard Python scalar and return it.

def mean( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
307    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
308        return self.__class__.from_xarray(self._data.mean(dim=dim))
def pad( self, pad_width: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.PadWidth]], mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
310    def pad(
311        self,
312        pad_width: PerAxis[PadWidthLike],
313        mode: PadMode = "symmetric",
314    ) -> Self:
315        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
316        return self.__class__.from_xarray(
317            self._data.pad(pad_width=pad_width, mode=mode)
318        )
def pad_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
320    def pad_to(
321        self,
322        sizes: PerAxis[int],
323        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
324        mode: PadMode = "symmetric",
325    ) -> Self:
326        """pad `tensor` to match `sizes`"""
327        if isinstance(pad_where, str):
328            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
329        else:
330            pad_axis_where = pad_where
331
332        pad_width: Dict[AxisId, PadWidth] = {}
333        for a, s_is in self.sizes.items():
334            if a not in sizes or sizes[a] == s_is:
335                pad_width[a] = PadWidth(0, 0)
336            elif s_is > sizes[a]:
337                pad_width[a] = PadWidth(0, 0)
338                logger.warning(
339                    "Cannot pad axis {} of size {} to smaller size {}",
340                    a,
341                    s_is,
342                    sizes[a],
343                )
344            elif a not in pad_axis_where:
345                raise ValueError(
346                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
347                )
348            else:
349                pad_this_axis_where = pad_axis_where[a]
350                d = sizes[a] - s_is
351                if pad_this_axis_where == "left":
352                    pad_width[a] = PadWidth(d, 0)
353                elif pad_this_axis_where == "right":
354                    pad_width[a] = PadWidth(0, d)
355                elif pad_this_axis_where == "left_and_right":
356                    pad_width[a] = PadWidth(left := d // 2, d - left)
357                else:
358                    assert_never(pad_this_axis_where)
359
360        return self.pad(pad_width, mode)

pad tensor to match sizes

def quantile( self, q: Union[float, Sequence[float]], dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
362    def quantile(
363        self,
364        q: Union[float, Sequence[float]],
365        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
366    ) -> Self:
367        assert (
368            isinstance(q, (float, int))
369            and q >= 0.0
370            or not isinstance(q, (float, int))
371            and all(qq >= 0.0 for qq in q)
372        )
373        assert (
374            isinstance(q, (float, int))
375            and q <= 1.0
376            or not isinstance(q, (float, int))
377            and all(qq <= 1.0 for qq in q)
378        )
379        assert dim is None or (
380            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
381        )
382        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], *, pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', pad_mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric'):
384    def resize_to(
385        self,
386        sizes: PerAxis[int],
387        *,
388        pad_where: Union[
389            PadWhere,
390            PerAxis[PadWhere],
391        ] = "left_and_right",
392        crop_where: Union[
393            CropWhere,
394            PerAxis[CropWhere],
395        ] = "left_and_right",
396        pad_mode: PadMode = "symmetric",
397    ):
398        """return cropped/padded tensor with `sizes`"""
399        crop_to_sizes: Dict[AxisId, int] = {}
400        pad_to_sizes: Dict[AxisId, int] = {}
401        new_axes = dict(sizes)
402        for a, s_is in self.sizes.items():
403            a = AxisId(str(a))
404            _ = new_axes.pop(a, None)
405            if a not in sizes or sizes[a] == s_is:
406                pass
407            elif s_is > sizes[a]:
408                crop_to_sizes[a] = sizes[a]
409            else:
410                pad_to_sizes[a] = sizes[a]
411
412        tensor = self
413        if crop_to_sizes:
414            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
415
416        if pad_to_sizes:
417            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
418
419        if new_axes:
420            tensor = tensor.expand_dims(new_axes)
421
422        return tensor

return cropped/padded tensor with sizes

def std( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
424    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
425        return self.__class__.from_xarray(self._data.std(dim=dim))
def sum( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
427    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
428        """Reduce this Tensor's data by applying sum along some dimension(s)."""
429        return self.__class__.from_xarray(self._data.sum(dim=dim))

Reduce this Tensor's data by applying sum along some dimension(s).

def transpose(self, axes: Sequence[bioimageio.spec.model.v0_5.AxisId]) -> Self:
431    def transpose(
432        self,
433        axes: Sequence[AxisId],
434    ) -> Self:
435        """return a transposed tensor
436
437        Args:
438            axes: the desired tensor axes
439        """
440        # expand missing tensor axes
441        missing_axes = tuple(a for a in axes if a not in self.dims)
442        array = self._data
443        if missing_axes:
444            array = array.expand_dims(missing_axes)
445
446        # transpose to the correct axis order
447        return self.__class__.from_xarray(array.transpose(*axes))

return a transposed tensor

Arguments:
  • axes: the desired tensor axes
def var( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
449    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
450        return self.__class__.from_xarray(self._data.var(dim=dim))