bioimageio.core.tensor

  1from __future__ import annotations
  2
  3import collections.abc
  4from itertools import permutations
  5from typing import (
  6    TYPE_CHECKING,
  7    Any,
  8    Callable,
  9    Dict,
 10    Iterator,
 11    Mapping,
 12    Optional,
 13    Sequence,
 14    Tuple,
 15    Union,
 16    cast,
 17    get_args,
 18)
 19
 20import numpy as np
 21import xarray as xr
 22from loguru import logger
 23from numpy.typing import DTypeLike, NDArray
 24from typing_extensions import Self, assert_never
 25
 26from bioimageio.spec.model import v0_5
 27
 28from ._magic_tensor_ops import MagicTensorOpsMixin
 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 30from .common import (
 31    CropWhere,
 32    DTypeStr,
 33    PadMode,
 34    PadWhere,
 35    PadWidth,
 36    PadWidthLike,
 37    SliceInfo,
 38)
 39
 40if TYPE_CHECKING:
 41    from numpy.typing import ArrayLike, NDArray
 42
 43
# Non-`Tensor`, non-`xr.DataArray` operands (scalars and array-likes); used to build
# `Tensor._Compatible`, the operand type of `Tensor`'s binary operator helpers.
_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"
 45
 46
 47# TODO: complete docstrings
 48class Tensor(MagicTensorOpsMixin):
 49    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 50    and improved type annotations."""
 51
 52    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 53
 54    def __init__(
 55        self,
 56        array: NDArray[Any],
 57        dims: Sequence[Union[AxisId, AxisLike]],
 58    ) -> None:
 59        super().__init__()
 60        axes = tuple(
 61            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 62        )
 63        self._data = xr.DataArray(array, dims=axes)
 64
 65    def __array__(self, dtype: DTypeLike = None):
 66        return np.asarray(self._data, dtype=dtype)
 67
 68    def __getitem__(
 69        self,
 70        key: Union[
 71            SliceInfo,
 72            slice,
 73            int,
 74            PerAxis[Union[SliceInfo, slice, int]],
 75            Tensor,
 76            xr.DataArray,
 77        ],
 78    ) -> Self:
 79        if isinstance(key, SliceInfo):
 80            key = slice(*key)
 81        elif isinstance(key, collections.abc.Mapping):
 82            key = {
 83                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 84                for a, s in key.items()
 85            }
 86        elif isinstance(key, Tensor):
 87            key = key._data
 88
 89        return self.__class__.from_xarray(self._data[key])
 90
 91    def __setitem__(
 92        self,
 93        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
 94        value: Union[Tensor, xr.DataArray, float, int],
 95    ) -> None:
 96        if isinstance(key, Tensor):
 97            key = key._data
 98        elif isinstance(key, xr.DataArray):
 99            pass
100        else:
101            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
102
103        if isinstance(value, Tensor):
104            value = value._data
105
106        self._data[key] = value
107
108    def __len__(self) -> int:
109        return len(self.data)
110
111    def _iter(self: Any) -> Iterator[Any]:
112        for n in range(len(self)):
113            yield self[n]
114
115    def __iter__(self: Any) -> Iterator[Any]:
116        if self.ndim == 0:
117            raise TypeError("iteration over a 0-d array")
118        return self._iter()
119
120    def _binary_op(
121        self,
122        other: _Compatible,
123        f: Callable[[Any, Any], Any],
124        reflexive: bool = False,
125    ) -> Self:
126        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
127            (other._data if isinstance(other, Tensor) else other),
128            f,
129            reflexive,
130        )
131        return self.__class__.from_xarray(data)
132
133    def _inplace_binary_op(
134        self,
135        other: _Compatible,
136        f: Callable[[Any, Any], Any],
137    ) -> Self:
138        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
139            (
140                other_d
141                if (other_d := getattr(other, "data")) is not None
142                and isinstance(
143                    other_d,
144                    xr.DataArray,
145                )
146                else other
147            ),
148            f,
149        )
150        return self
151
152    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
153        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
154            f, *args, **kwargs
155        )
156        return self.__class__.from_xarray(data)
157
158    @classmethod
159    def from_xarray(cls, data_array: xr.DataArray) -> Self:
160        """create a `Tensor` from an xarray data array
161
162        note for internal use: this factory method is round-trip save
163            for any `Tensor`'s  `data` property (an xarray.DataArray).
164        """
165        return cls(
166            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
167        )
168
169    @classmethod
170    def from_numpy(
171        cls,
172        array: NDArray[Any],
173        *,
174        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
175    ) -> Tensor:
176        """create a `Tensor` from a numpy array
177
178        Args:
179            array: the nd numpy array
180            axes: A description of the array's axes,
181                if None axes are guessed (which might fail and raise a ValueError.)
182
183        Raises:
184            ValueError: if `axes` is None and axes guessing fails.
185        """
186
187        if dims is None:
188            return cls._interprete_array_wo_known_axes(array)
189        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
190            dims = [dims]
191
192        axis_infos = [AxisInfo.create(a) for a in dims]
193        original_shape = tuple(array.shape)
194
195        successful_view = _get_array_view(array, axis_infos)
196        if successful_view is None:
197            raise ValueError(
198                f"Array shape {original_shape} does not map to axes {dims}"
199            )
200
201        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
202
203    @property
204    def data(self):
205        return self._data
206
207    @property
208    def dims(self):  # TODO: rename to `axes`?
209        """Tuple of dimension names associated with this tensor."""
210        return cast(Tuple[AxisId, ...], self._data.dims)
211
212    @property
213    def dtype(self) -> DTypeStr:
214        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
215        assert dt in get_args(DTypeStr)
216        return dt  # pyright: ignore[reportReturnType]
217
218    @property
219    def ndim(self):
220        """Number of tensor dimensions."""
221        return self._data.ndim
222
223    @property
224    def shape(self):
225        """Tuple of tensor axes lengths"""
226        return self._data.shape
227
228    @property
229    def shape_tuple(self):
230        """Tuple of tensor axes lengths"""
231        return self._data.shape
232
233    @property
234    def size(self):
235        """Number of elements in the tensor.
236
237        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
238        """
239        return self._data.size
240
241    @property
242    def sizes(self):
243        """Ordered, immutable mapping from axis ids to axis lengths."""
244        return cast(Mapping[AxisId, int], self.data.sizes)
245
246    @property
247    def tagged_shape(self):
248        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
249        return self.sizes
250
251    def argmax(self) -> Mapping[AxisId, int]:
252        ret = self._data.argmax(...)
253        assert isinstance(ret, dict)
254        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
255
256    def astype(self, dtype: DTypeStr, *, copy: bool = False):
257        """Return tensor cast to `dtype`
258
259        note: if dtype is already satisfied copy if `copy`"""
260        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
261
262    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
263        """Return a tensor whose values are limited to [min, max].
264        At least one of max or min must be given."""
265        return self.__class__.from_xarray(self._data.clip(min, max))
266
267    def crop_to(
268        self,
269        sizes: PerAxis[int],
270        crop_where: Union[
271            CropWhere,
272            PerAxis[CropWhere],
273        ] = "left_and_right",
274    ) -> Self:
275        """crop to match `sizes`"""
276        if isinstance(crop_where, str):
277            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
278        else:
279            crop_axis_where = crop_where
280
281        slices: Dict[AxisId, SliceInfo] = {}
282
283        for a, s_is in self.sizes.items():
284            if a not in sizes or sizes[a] == s_is:
285                pass
286            elif sizes[a] > s_is:
287                logger.warning(
288                    "Cannot crop axis {} of size {} to larger size {}",
289                    a,
290                    s_is,
291                    sizes[a],
292                )
293            elif a not in crop_axis_where:
294                raise ValueError(
295                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
296                )
297            else:
298                crop_this_axis_where = crop_axis_where[a]
299                if crop_this_axis_where == "left":
300                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
301                elif crop_this_axis_where == "right":
302                    slices[a] = SliceInfo(0, sizes[a])
303                elif crop_this_axis_where == "left_and_right":
304                    slices[a] = SliceInfo(
305                        start := (s_is - sizes[a]) // 2, sizes[a] + start
306                    )
307                else:
308                    assert_never(crop_this_axis_where)
309
310        return self[slices]
311
312    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
313        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
314
315    def item(
316        self,
317        key: Union[
318            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
319        ] = None,
320    ):
321        """Copy a tensor element to a standard Python scalar and return it."""
322        if key is None:
323            ret = self._data.item()
324        else:
325            ret = self[key]._data.item()
326
327        assert isinstance(ret, (bool, float, int))
328        return ret
329
330    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
331        return self.__class__.from_xarray(self._data.mean(dim=dim))
332
333    def pad(
334        self,
335        pad_width: PerAxis[PadWidthLike],
336        mode: PadMode = "symmetric",
337    ) -> Self:
338        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
339        return self.__class__.from_xarray(
340            self._data.pad(pad_width=pad_width, mode=mode)
341        )
342
343    def pad_to(
344        self,
345        sizes: PerAxis[int],
346        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
347        mode: PadMode = "symmetric",
348    ) -> Self:
349        """pad `tensor` to match `sizes`"""
350        if isinstance(pad_where, str):
351            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
352        else:
353            pad_axis_where = pad_where
354
355        pad_width: Dict[AxisId, PadWidth] = {}
356        for a, s_is in self.sizes.items():
357            if a not in sizes or sizes[a] == s_is:
358                pad_width[a] = PadWidth(0, 0)
359            elif s_is > sizes[a]:
360                pad_width[a] = PadWidth(0, 0)
361                logger.warning(
362                    "Cannot pad axis {} of size {} to smaller size {}",
363                    a,
364                    s_is,
365                    sizes[a],
366                )
367            elif a not in pad_axis_where:
368                raise ValueError(
369                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
370                )
371            else:
372                pad_this_axis_where = pad_axis_where[a]
373                d = sizes[a] - s_is
374                if pad_this_axis_where == "left":
375                    pad_width[a] = PadWidth(d, 0)
376                elif pad_this_axis_where == "right":
377                    pad_width[a] = PadWidth(0, d)
378                elif pad_this_axis_where == "left_and_right":
379                    pad_width[a] = PadWidth(left := d // 2, d - left)
380                else:
381                    assert_never(pad_this_axis_where)
382
383        return self.pad(pad_width, mode)
384
385    def quantile(
386        self,
387        q: Union[float, Sequence[float]],
388        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
389    ) -> Self:
390        assert (
391            isinstance(q, (float, int))
392            and q >= 0.0
393            or not isinstance(q, (float, int))
394            and all(qq >= 0.0 for qq in q)
395        )
396        assert (
397            isinstance(q, (float, int))
398            and q <= 1.0
399            or not isinstance(q, (float, int))
400            and all(qq <= 1.0 for qq in q)
401        )
402        assert dim is None or (
403            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
404        )
405        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
406
407    def resize_to(
408        self,
409        sizes: PerAxis[int],
410        *,
411        pad_where: Union[
412            PadWhere,
413            PerAxis[PadWhere],
414        ] = "left_and_right",
415        crop_where: Union[
416            CropWhere,
417            PerAxis[CropWhere],
418        ] = "left_and_right",
419        pad_mode: PadMode = "symmetric",
420    ):
421        """return cropped/padded tensor with `sizes`"""
422        crop_to_sizes: Dict[AxisId, int] = {}
423        pad_to_sizes: Dict[AxisId, int] = {}
424        new_axes = dict(sizes)
425        for a, s_is in self.sizes.items():
426            a = AxisId(str(a))
427            _ = new_axes.pop(a, None)
428            if a not in sizes or sizes[a] == s_is:
429                pass
430            elif s_is > sizes[a]:
431                crop_to_sizes[a] = sizes[a]
432            else:
433                pad_to_sizes[a] = sizes[a]
434
435        tensor = self
436        if crop_to_sizes:
437            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
438
439        if pad_to_sizes:
440            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
441
442        if new_axes:
443            tensor = tensor.expand_dims(new_axes)
444
445        return tensor
446
447    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
448        return self.__class__.from_xarray(self._data.std(dim=dim))
449
450    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
451        """Reduce this Tensor's data by applying sum along some dimension(s)."""
452        return self.__class__.from_xarray(self._data.sum(dim=dim))
453
454    def transpose(
455        self,
456        axes: Sequence[AxisId],
457    ) -> Self:
458        """return a transposed tensor
459
460        Args:
461            axes: the desired tensor axes
462        """
463        # expand missing tensor axes
464        missing_axes = tuple(a for a in axes if a not in self.dims)
465        array = self._data
466        if missing_axes:
467            array = array.expand_dims(missing_axes)
468
469        # transpose to the correct axis order
470        return self.__class__.from_xarray(array.transpose(*axes))
471
472    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
473        return self.__class__.from_xarray(self._data.var(dim=dim))
474
475    @classmethod
476    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
477        ndim = array.ndim
478        if ndim == 2:
479            current_axes = (
480                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
481                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
482            )
483        elif ndim == 3 and any(s <= 3 for s in array.shape):
484            current_axes = (
485                v0_5.ChannelAxis(
486                    channel_names=[
487                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
488                    ]
489                ),
490                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
491                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
492            )
493        elif ndim == 3:
494            current_axes = (
495                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
496                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
497                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
498            )
499        elif ndim == 4:
500            current_axes = (
501                v0_5.ChannelAxis(
502                    channel_names=[
503                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
504                    ]
505                ),
506                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
507                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
508                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
509            )
510        elif ndim == 5:
511            current_axes = (
512                v0_5.BatchAxis(),
513                v0_5.ChannelAxis(
514                    channel_names=[
515                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
516                    ]
517                ),
518                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
519                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
520                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
521            )
522        else:
523            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
524
525        return cls(array, dims=tuple(a.id for a in current_axes))
526
527
528def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
529    if len(arr.shape) > len(axis_infos):
530        # remove singletons
531        for i, s in enumerate(arr.shape):
532            if s == 1:
533                arr = np.take(arr, 0, axis=i)
534                if len(arr.shape) == len(axis_infos):
535                    break
536
537    # add singletons if nececsary
538    for i, a in enumerate(axis_infos):
539        if len(arr.shape) >= len(axis_infos):
540            break
541
542        if a.maybe_singleton:
543            arr = np.expand_dims(arr, i)
544
545    return arr
546
547
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Search for an axis order of `original_array` that fits `axis_infos`.

    Axis orders are tried in `itertools.permutations` order, except that the
    fully reversed order (the transpose) is moved to second place. Each
    candidate is adapted with `_add_singletons`. A candidate is accepted if its
    dimensionality equals `len(axis_infos)` and every size-1 axis corresponds
    to an axis info that is `maybe_singleton`.

    Returns:
        the first accepted candidate, or None. Gives up (returns None) as soon
        as one candidate has the wrong number of dimensions.
    """
    orders = list(permutations(range(len(original_array.shape))))
    # move the last permutation (reversed axis order) to the second position
    orders.insert(1, orders.pop())

    wanted_ndim = len(axis_infos)
    for order in orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != wanted_ndim:
            return None

        fits = all(
            size != 1 or info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if fits:
            return candidate

    return None
class Tensor(bioimageio.core._magic_tensor_ops.MagicTensorOpsMixin):
 49class Tensor(MagicTensorOpsMixin):
 50    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 51    and improved type annotations."""
 52
 53    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 54
 55    def __init__(
 56        self,
 57        array: NDArray[Any],
 58        dims: Sequence[Union[AxisId, AxisLike]],
 59    ) -> None:
 60        super().__init__()
 61        axes = tuple(
 62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 63        )
 64        self._data = xr.DataArray(array, dims=axes)
 65
 66    def __array__(self, dtype: DTypeLike = None):
 67        return np.asarray(self._data, dtype=dtype)
 68
 69    def __getitem__(
 70        self,
 71        key: Union[
 72            SliceInfo,
 73            slice,
 74            int,
 75            PerAxis[Union[SliceInfo, slice, int]],
 76            Tensor,
 77            xr.DataArray,
 78        ],
 79    ) -> Self:
 80        if isinstance(key, SliceInfo):
 81            key = slice(*key)
 82        elif isinstance(key, collections.abc.Mapping):
 83            key = {
 84                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 85                for a, s in key.items()
 86            }
 87        elif isinstance(key, Tensor):
 88            key = key._data
 89
 90        return self.__class__.from_xarray(self._data[key])
 91
 92    def __setitem__(
 93        self,
 94        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
 95        value: Union[Tensor, xr.DataArray, float, int],
 96    ) -> None:
 97        if isinstance(key, Tensor):
 98            key = key._data
 99        elif isinstance(key, xr.DataArray):
100            pass
101        else:
102            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
103
104        if isinstance(value, Tensor):
105            value = value._data
106
107        self._data[key] = value
108
109    def __len__(self) -> int:
110        return len(self.data)
111
112    def _iter(self: Any) -> Iterator[Any]:
113        for n in range(len(self)):
114            yield self[n]
115
116    def __iter__(self: Any) -> Iterator[Any]:
117        if self.ndim == 0:
118            raise TypeError("iteration over a 0-d array")
119        return self._iter()
120
121    def _binary_op(
122        self,
123        other: _Compatible,
124        f: Callable[[Any, Any], Any],
125        reflexive: bool = False,
126    ) -> Self:
127        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
128            (other._data if isinstance(other, Tensor) else other),
129            f,
130            reflexive,
131        )
132        return self.__class__.from_xarray(data)
133
134    def _inplace_binary_op(
135        self,
136        other: _Compatible,
137        f: Callable[[Any, Any], Any],
138    ) -> Self:
139        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
140            (
141                other_d
142                if (other_d := getattr(other, "data")) is not None
143                and isinstance(
144                    other_d,
145                    xr.DataArray,
146                )
147                else other
148            ),
149            f,
150        )
151        return self
152
153    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
154        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
155            f, *args, **kwargs
156        )
157        return self.__class__.from_xarray(data)
158
159    @classmethod
160    def from_xarray(cls, data_array: xr.DataArray) -> Self:
161        """create a `Tensor` from an xarray data array
162
163        note for internal use: this factory method is round-trip save
164            for any `Tensor`'s  `data` property (an xarray.DataArray).
165        """
166        return cls(
167            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
168        )
169
170    @classmethod
171    def from_numpy(
172        cls,
173        array: NDArray[Any],
174        *,
175        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
176    ) -> Tensor:
177        """create a `Tensor` from a numpy array
178
179        Args:
180            array: the nd numpy array
181            axes: A description of the array's axes,
182                if None axes are guessed (which might fail and raise a ValueError.)
183
184        Raises:
185            ValueError: if `axes` is None and axes guessing fails.
186        """
187
188        if dims is None:
189            return cls._interprete_array_wo_known_axes(array)
190        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
191            dims = [dims]
192
193        axis_infos = [AxisInfo.create(a) for a in dims]
194        original_shape = tuple(array.shape)
195
196        successful_view = _get_array_view(array, axis_infos)
197        if successful_view is None:
198            raise ValueError(
199                f"Array shape {original_shape} does not map to axes {dims}"
200            )
201
202        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
203
204    @property
205    def data(self):
206        return self._data
207
208    @property
209    def dims(self):  # TODO: rename to `axes`?
210        """Tuple of dimension names associated with this tensor."""
211        return cast(Tuple[AxisId, ...], self._data.dims)
212
213    @property
214    def dtype(self) -> DTypeStr:
215        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
216        assert dt in get_args(DTypeStr)
217        return dt  # pyright: ignore[reportReturnType]
218
219    @property
220    def ndim(self):
221        """Number of tensor dimensions."""
222        return self._data.ndim
223
224    @property
225    def shape(self):
226        """Tuple of tensor axes lengths"""
227        return self._data.shape
228
229    @property
230    def shape_tuple(self):
231        """Tuple of tensor axes lengths"""
232        return self._data.shape
233
234    @property
235    def size(self):
236        """Number of elements in the tensor.
237
238        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
239        """
240        return self._data.size
241
242    @property
243    def sizes(self):
244        """Ordered, immutable mapping from axis ids to axis lengths."""
245        return cast(Mapping[AxisId, int], self.data.sizes)
246
247    @property
248    def tagged_shape(self):
249        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
250        return self.sizes
251
252    def argmax(self) -> Mapping[AxisId, int]:
253        ret = self._data.argmax(...)
254        assert isinstance(ret, dict)
255        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
256
257    def astype(self, dtype: DTypeStr, *, copy: bool = False):
258        """Return tensor cast to `dtype`
259
260        note: if dtype is already satisfied copy if `copy`"""
261        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
262
263    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
264        """Return a tensor whose values are limited to [min, max].
265        At least one of max or min must be given."""
266        return self.__class__.from_xarray(self._data.clip(min, max))
267
268    def crop_to(
269        self,
270        sizes: PerAxis[int],
271        crop_where: Union[
272            CropWhere,
273            PerAxis[CropWhere],
274        ] = "left_and_right",
275    ) -> Self:
276        """crop to match `sizes`"""
277        if isinstance(crop_where, str):
278            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
279        else:
280            crop_axis_where = crop_where
281
282        slices: Dict[AxisId, SliceInfo] = {}
283
284        for a, s_is in self.sizes.items():
285            if a not in sizes or sizes[a] == s_is:
286                pass
287            elif sizes[a] > s_is:
288                logger.warning(
289                    "Cannot crop axis {} of size {} to larger size {}",
290                    a,
291                    s_is,
292                    sizes[a],
293                )
294            elif a not in crop_axis_where:
295                raise ValueError(
296                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
297                )
298            else:
299                crop_this_axis_where = crop_axis_where[a]
300                if crop_this_axis_where == "left":
301                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
302                elif crop_this_axis_where == "right":
303                    slices[a] = SliceInfo(0, sizes[a])
304                elif crop_this_axis_where == "left_and_right":
305                    slices[a] = SliceInfo(
306                        start := (s_is - sizes[a]) // 2, sizes[a] + start
307                    )
308                else:
309                    assert_never(crop_this_axis_where)
310
311        return self[slices]
312
313    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
314        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
315
316    def item(
317        self,
318        key: Union[
319            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
320        ] = None,
321    ):
322        """Copy a tensor element to a standard Python scalar and return it."""
323        if key is None:
324            ret = self._data.item()
325        else:
326            ret = self[key]._data.item()
327
328        assert isinstance(ret, (bool, float, int))
329        return ret
330
331    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
332        return self.__class__.from_xarray(self._data.mean(dim=dim))
333
334    def pad(
335        self,
336        pad_width: PerAxis[PadWidthLike],
337        mode: PadMode = "symmetric",
338    ) -> Self:
339        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
340        return self.__class__.from_xarray(
341            self._data.pad(pad_width=pad_width, mode=mode)
342        )
343
344    def pad_to(
345        self,
346        sizes: PerAxis[int],
347        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
348        mode: PadMode = "symmetric",
349    ) -> Self:
350        """pad `tensor` to match `sizes`"""
351        if isinstance(pad_where, str):
352            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
353        else:
354            pad_axis_where = pad_where
355
356        pad_width: Dict[AxisId, PadWidth] = {}
357        for a, s_is in self.sizes.items():
358            if a not in sizes or sizes[a] == s_is:
359                pad_width[a] = PadWidth(0, 0)
360            elif s_is > sizes[a]:
361                pad_width[a] = PadWidth(0, 0)
362                logger.warning(
363                    "Cannot pad axis {} of size {} to smaller size {}",
364                    a,
365                    s_is,
366                    sizes[a],
367                )
368            elif a not in pad_axis_where:
369                raise ValueError(
370                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
371                )
372            else:
373                pad_this_axis_where = pad_axis_where[a]
374                d = sizes[a] - s_is
375                if pad_this_axis_where == "left":
376                    pad_width[a] = PadWidth(d, 0)
377                elif pad_this_axis_where == "right":
378                    pad_width[a] = PadWidth(0, d)
379                elif pad_this_axis_where == "left_and_right":
380                    pad_width[a] = PadWidth(left := d // 2, d - left)
381                else:
382                    assert_never(pad_this_axis_where)
383
384        return self.pad(pad_width, mode)
385
386    def quantile(
387        self,
388        q: Union[float, Sequence[float]],
389        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
390    ) -> Self:
391        assert (
392            isinstance(q, (float, int))
393            and q >= 0.0
394            or not isinstance(q, (float, int))
395            and all(qq >= 0.0 for qq in q)
396        )
397        assert (
398            isinstance(q, (float, int))
399            and q <= 1.0
400            or not isinstance(q, (float, int))
401            and all(qq <= 1.0 for qq in q)
402        )
403        assert dim is None or (
404            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
405        )
406        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
407
408    def resize_to(
409        self,
410        sizes: PerAxis[int],
411        *,
412        pad_where: Union[
413            PadWhere,
414            PerAxis[PadWhere],
415        ] = "left_and_right",
416        crop_where: Union[
417            CropWhere,
418            PerAxis[CropWhere],
419        ] = "left_and_right",
420        pad_mode: PadMode = "symmetric",
421    ):
422        """return cropped/padded tensor with `sizes`"""
423        crop_to_sizes: Dict[AxisId, int] = {}
424        pad_to_sizes: Dict[AxisId, int] = {}
425        new_axes = dict(sizes)
426        for a, s_is in self.sizes.items():
427            a = AxisId(str(a))
428            _ = new_axes.pop(a, None)
429            if a not in sizes or sizes[a] == s_is:
430                pass
431            elif s_is > sizes[a]:
432                crop_to_sizes[a] = sizes[a]
433            else:
434                pad_to_sizes[a] = sizes[a]
435
436        tensor = self
437        if crop_to_sizes:
438            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
439
440        if pad_to_sizes:
441            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
442
443        if new_axes:
444            tensor = tensor.expand_dims(new_axes)
445
446        return tensor
447
448    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
449        return self.__class__.from_xarray(self._data.std(dim=dim))
450
451    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
452        """Reduce this Tensor's data by applying sum along some dimension(s)."""
453        return self.__class__.from_xarray(self._data.sum(dim=dim))
454
455    def transpose(
456        self,
457        axes: Sequence[AxisId],
458    ) -> Self:
459        """return a transposed tensor
460
461        Args:
462            axes: the desired tensor axes
463        """
464        # expand missing tensor axes
465        missing_axes = tuple(a for a in axes if a not in self.dims)
466        array = self._data
467        if missing_axes:
468            array = array.expand_dims(missing_axes)
469
470        # transpose to the correct axis order
471        return self.__class__.from_xarray(array.transpose(*axes))
472
473    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
474        return self.__class__.from_xarray(self._data.var(dim=dim))
475
476    @classmethod
477    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
478        ndim = array.ndim
479        if ndim == 2:
480            current_axes = (
481                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
482                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
483            )
484        elif ndim == 3 and any(s <= 3 for s in array.shape):
485            current_axes = (
486                v0_5.ChannelAxis(
487                    channel_names=[
488                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
489                    ]
490                ),
491                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
492                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
493            )
494        elif ndim == 3:
495            current_axes = (
496                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
497                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
498                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
499            )
500        elif ndim == 4:
501            current_axes = (
502                v0_5.ChannelAxis(
503                    channel_names=[
504                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
505                    ]
506                ),
507                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
508                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
509                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
510            )
511        elif ndim == 5:
512            current_axes = (
513                v0_5.BatchAxis(),
514                v0_5.ChannelAxis(
515                    channel_names=[
516                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
517                    ]
518                ),
519                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
520                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
521                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
522            )
523        else:
524            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
525
526        return cls(array, dims=tuple(a.id for a in current_axes))

A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.

Tensor( array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], dims: Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]])
55    def __init__(
56        self,
57        array: NDArray[Any],
58        dims: Sequence[Union[AxisId, AxisLike]],
59    ) -> None:
60        super().__init__()
61        axes = tuple(
62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
63        )
64        self._data = xr.DataArray(array, dims=axes)
@classmethod
def from_xarray(cls, data_array: xarray.core.dataarray.DataArray) -> Self:
159    @classmethod
160    def from_xarray(cls, data_array: xr.DataArray) -> Self:
161        """create a `Tensor` from an xarray data array
162
163        note for internal use: this factory method is round-trip save
164            for any `Tensor`'s  `data` property (an xarray.DataArray).
165        """
166        return cls(
167            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
168        )

create a Tensor from an xarray data array

Note for internal use: this factory method is round-trip safe for any Tensor's data property (an xarray.DataArray).

@classmethod
def from_numpy( cls, array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], *, dims: Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, 
bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]], NoneType]) -> Tensor:
170    @classmethod
171    def from_numpy(
172        cls,
173        array: NDArray[Any],
174        *,
175        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
176    ) -> Tensor:
177        """create a `Tensor` from a numpy array
178
179        Args:
180            array: the nd numpy array
181            axes: A description of the array's axes,
182                if None axes are guessed (which might fail and raise a ValueError.)
183
184        Raises:
185            ValueError: if `axes` is None and axes guessing fails.
186        """
187
188        if dims is None:
189            return cls._interprete_array_wo_known_axes(array)
190        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
191            dims = [dims]
192
193        axis_infos = [AxisInfo.create(a) for a in dims]
194        original_shape = tuple(array.shape)
195
196        successful_view = _get_array_view(array, axis_infos)
197        if successful_view is None:
198            raise ValueError(
199                f"Array shape {original_shape} does not map to axes {dims}"
200            )
201
202        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

create a Tensor from a numpy array

Arguments:
  • array: the nd numpy array
  • dims: A description of the array's axes; if None, the axes are guessed (which might fail and raise a ValueError).
Raises:
  • ValueError: if dims is None and axes guessing fails, or if the array shape does not map to dims.
data
204    @property
205    def data(self):
206        return self._data
dims
208    @property
209    def dims(self):  # TODO: rename to `axes`?
210        """Tuple of dimension names associated with this tensor."""
211        return cast(Tuple[AxisId, ...], self._data.dims)

Tuple of dimension names associated with this tensor.

dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
213    @property
214    def dtype(self) -> DTypeStr:
215        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
216        assert dt in get_args(DTypeStr)
217        return dt  # pyright: ignore[reportReturnType]
ndim
219    @property
220    def ndim(self):
221        """Number of tensor dimensions."""
222        return self._data.ndim

Number of tensor dimensions.

shape
224    @property
225    def shape(self):
226        """Tuple of tensor axes lengths"""
227        return self._data.shape

Tuple of tensor axis lengths.

shape_tuple
229    @property
230    def shape_tuple(self):
231        """Tuple of tensor axes lengths"""
232        return self._data.shape

Tuple of tensor axis lengths (same value as shape).

size
234    @property
235    def size(self):
236        """Number of elements in the tensor.
237
238        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
239        """
240        return self._data.size

Number of elements in the tensor.

Equal to math.prod(tensor.shape), i.e., the product of the tensor's axis lengths.

sizes
242    @property
243    def sizes(self):
244        """Ordered, immutable mapping from axis ids to axis lengths."""
245        return cast(Mapping[AxisId, int], self.data.sizes)

Ordered, immutable mapping from axis ids to axis lengths.

tagged_shape
247    @property
248    def tagged_shape(self):
249        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
250        return self.sizes

(alias for sizes) Ordered, immutable mapping from axis ids to lengths.

def argmax(self) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
252    def argmax(self) -> Mapping[AxisId, int]:
253        ret = self._data.argmax(...)
254        assert isinstance(ret, dict)
255        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
def astype( self, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'], *, copy: bool = False):
257    def astype(self, dtype: DTypeStr, *, copy: bool = False):
258        """Return tensor cast to `dtype`
259
260        note: if dtype is already satisfied copy if `copy`"""
261        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

Return tensor cast to dtype

Note: if the tensor already has the requested dtype, it is only copied if copy is True.

def clip(self, min: Optional[float] = None, max: Optional[float] = None):
263    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
264        """Return a tensor whose values are limited to [min, max].
265        At least one of max or min must be given."""
266        return self.__class__.from_xarray(self._data.clip(min, max))

Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.

def crop_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right') -> Self:
268    def crop_to(
269        self,
270        sizes: PerAxis[int],
271        crop_where: Union[
272            CropWhere,
273            PerAxis[CropWhere],
274        ] = "left_and_right",
275    ) -> Self:
276        """crop to match `sizes`"""
277        if isinstance(crop_where, str):
278            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
279        else:
280            crop_axis_where = crop_where
281
282        slices: Dict[AxisId, SliceInfo] = {}
283
284        for a, s_is in self.sizes.items():
285            if a not in sizes or sizes[a] == s_is:
286                pass
287            elif sizes[a] > s_is:
288                logger.warning(
289                    "Cannot crop axis {} of size {} to larger size {}",
290                    a,
291                    s_is,
292                    sizes[a],
293                )
294            elif a not in crop_axis_where:
295                raise ValueError(
296                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
297                )
298            else:
299                crop_this_axis_where = crop_axis_where[a]
300                if crop_this_axis_where == "left":
301                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
302                elif crop_this_axis_where == "right":
303                    slices[a] = SliceInfo(0, sizes[a])
304                elif crop_this_axis_where == "left_and_right":
305                    slices[a] = SliceInfo(
306                        start := (s_is - sizes[a]) // 2, sizes[a] + start
307                    )
308                else:
309                    assert_never(crop_this_axis_where)
310
311        return self[slices]

crop to match sizes

def expand_dims( self, dims: Union[Sequence[bioimageio.spec.model.v0_5.AxisId], Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Self:
313    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
314        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
def item( self, key: Union[NoneType, bioimageio.core.common.SliceInfo, slice, int, Mapping[bioimageio.spec.model.v0_5.AxisId, Union[bioimageio.core.common.SliceInfo, slice, int]]] = None):
316    def item(
317        self,
318        key: Union[
319            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
320        ] = None,
321    ):
322        """Copy a tensor element to a standard Python scalar and return it."""
323        if key is None:
324            ret = self._data.item()
325        else:
326            ret = self[key]._data.item()
327
328        assert isinstance(ret, (bool, float, int))
329        return ret

Copy a tensor element to a standard Python scalar and return it.

def mean( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
331    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
332        return self.__class__.from_xarray(self._data.mean(dim=dim))
def pad( self, pad_width: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.PadWidth]], mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
334    def pad(
335        self,
336        pad_width: PerAxis[PadWidthLike],
337        mode: PadMode = "symmetric",
338    ) -> Self:
339        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
340        return self.__class__.from_xarray(
341            self._data.pad(pad_width=pad_width, mode=mode)
342        )
def pad_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
344    def pad_to(
345        self,
346        sizes: PerAxis[int],
347        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
348        mode: PadMode = "symmetric",
349    ) -> Self:
350        """pad `tensor` to match `sizes`"""
351        if isinstance(pad_where, str):
352            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
353        else:
354            pad_axis_where = pad_where
355
356        pad_width: Dict[AxisId, PadWidth] = {}
357        for a, s_is in self.sizes.items():
358            if a not in sizes or sizes[a] == s_is:
359                pad_width[a] = PadWidth(0, 0)
360            elif s_is > sizes[a]:
361                pad_width[a] = PadWidth(0, 0)
362                logger.warning(
363                    "Cannot pad axis {} of size {} to smaller size {}",
364                    a,
365                    s_is,
366                    sizes[a],
367                )
368            elif a not in pad_axis_where:
369                raise ValueError(
370                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
371                )
372            else:
373                pad_this_axis_where = pad_axis_where[a]
374                d = sizes[a] - s_is
375                if pad_this_axis_where == "left":
376                    pad_width[a] = PadWidth(d, 0)
377                elif pad_this_axis_where == "right":
378                    pad_width[a] = PadWidth(0, d)
379                elif pad_this_axis_where == "left_and_right":
380                    pad_width[a] = PadWidth(left := d // 2, d - left)
381                else:
382                    assert_never(pad_this_axis_where)
383
384        return self.pad(pad_width, mode)

pad tensor to match sizes

def quantile( self, q: Union[float, Sequence[float]], dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
386    def quantile(
387        self,
388        q: Union[float, Sequence[float]],
389        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
390    ) -> Self:
391        assert (
392            isinstance(q, (float, int))
393            and q >= 0.0
394            or not isinstance(q, (float, int))
395            and all(qq >= 0.0 for qq in q)
396        )
397        assert (
398            isinstance(q, (float, int))
399            and q <= 1.0
400            or not isinstance(q, (float, int))
401            and all(qq <= 1.0 for qq in q)
402        )
403        assert dim is None or (
404            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
405        )
406        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], *, pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', pad_mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric'):
408    def resize_to(
409        self,
410        sizes: PerAxis[int],
411        *,
412        pad_where: Union[
413            PadWhere,
414            PerAxis[PadWhere],
415        ] = "left_and_right",
416        crop_where: Union[
417            CropWhere,
418            PerAxis[CropWhere],
419        ] = "left_and_right",
420        pad_mode: PadMode = "symmetric",
421    ):
422        """return cropped/padded tensor with `sizes`"""
423        crop_to_sizes: Dict[AxisId, int] = {}
424        pad_to_sizes: Dict[AxisId, int] = {}
425        new_axes = dict(sizes)
426        for a, s_is in self.sizes.items():
427            a = AxisId(str(a))
428            _ = new_axes.pop(a, None)
429            if a not in sizes or sizes[a] == s_is:
430                pass
431            elif s_is > sizes[a]:
432                crop_to_sizes[a] = sizes[a]
433            else:
434                pad_to_sizes[a] = sizes[a]
435
436        tensor = self
437        if crop_to_sizes:
438            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
439
440        if pad_to_sizes:
441            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
442
443        if new_axes:
444            tensor = tensor.expand_dims(new_axes)
445
446        return tensor

Return a tensor cropped and/or padded so that its axes match sizes.

def std( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
448    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
449        return self.__class__.from_xarray(self._data.std(dim=dim))
def sum( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
451    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
452        """Reduce this Tensor's data by applying sum along some dimension(s)."""
453        return self.__class__.from_xarray(self._data.sum(dim=dim))

Reduce this Tensor's data by applying sum along some dimension(s).

def transpose(self, axes: Sequence[bioimageio.spec.model.v0_5.AxisId]) -> Self:
455    def transpose(
456        self,
457        axes: Sequence[AxisId],
458    ) -> Self:
459        """return a transposed tensor
460
461        Args:
462            axes: the desired tensor axes
463        """
464        # expand missing tensor axes
465        missing_axes = tuple(a for a in axes if a not in self.dims)
466        array = self._data
467        if missing_axes:
468            array = array.expand_dims(missing_axes)
469
470        # transpose to the correct axis order
471        return self.__class__.from_xarray(array.transpose(*axes))

return a transposed tensor

Arguments:
  • axes: the desired tensor axes
def var( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
473    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
474        return self.__class__.from_xarray(self._data.var(dim=dim))