bioimageio.core.tensor

  1from __future__ import annotations
  2
  3import collections.abc
  4from itertools import permutations
  5from typing import (
  6    TYPE_CHECKING,
  7    Any,
  8    Callable,
  9    Dict,
 10    Iterator,
 11    Mapping,
 12    Optional,
 13    Sequence,
 14    Tuple,
 15    Union,
 16    cast,
 17    get_args,
 18)
 19
 20import numpy as np
 21import xarray as xr
 22from loguru import logger
 23from numpy.typing import DTypeLike, NDArray
 24from typing_extensions import Self, assert_never
 25
 26from bioimageio.spec.model import v0_5
 27
 28from ._magic_tensor_ops import MagicTensorOpsMixin
 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 30from .common import (
 31    CropWhere,
 32    DTypeStr,
 33    PadMode,
 34    PadWhere,
 35    PadWidth,
 36    PadWidthLike,
 37    SliceInfo,
 38)
 39
 40if TYPE_CHECKING:
 41    from numpy.typing import ArrayLike, NDArray
 42
 43
 44_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"
 45
 46
 47# TODO: complete docstrings
 48# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
 49class Tensor(MagicTensorOpsMixin):
 50    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 51    and improved type annotations."""
 52
 53    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 54
 55    def __init__(
 56        self,
 57        array: NDArray[Any],
 58        dims: Sequence[Union[AxisId, AxisLike]],
 59    ) -> None:
 60        super().__init__()
 61        axes = tuple(
 62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 63        )
 64        self._data = xr.DataArray(array, dims=axes)
 65
 66    def __array__(self, dtype: DTypeLike = None):
 67        return np.asarray(self._data, dtype=dtype)
 68
 69    def __getitem__(
 70        self,
 71        key: Union[
 72            SliceInfo,
 73            slice,
 74            int,
 75            PerAxis[Union[SliceInfo, slice, int]],
 76            Tensor,
 77            xr.DataArray,
 78        ],
 79    ) -> Self:
 80        if isinstance(key, SliceInfo):
 81            key = slice(*key)
 82        elif isinstance(key, collections.abc.Mapping):
 83            key = {
 84                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 85                for a, s in key.items()
 86            }
 87        elif isinstance(key, Tensor):
 88            key = key._data
 89
 90        return self.__class__.from_xarray(self._data[key])
 91
 92    def __setitem__(
 93        self,
 94        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
 95        value: Union[Tensor, xr.DataArray, float, int],
 96    ) -> None:
 97        if isinstance(key, Tensor):
 98            key = key._data
 99        elif isinstance(key, xr.DataArray):
100            pass
101        else:
102            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
103
104        if isinstance(value, Tensor):
105            value = value._data
106
107        self._data[key] = value
108
109    def __len__(self) -> int:
110        return len(self.data)
111
112    def _iter(self: Any) -> Iterator[Any]:
113        for n in range(len(self)):
114            yield self[n]
115
116    def __iter__(self: Any) -> Iterator[Any]:
117        if self.ndim == 0:
118            raise TypeError("iteration over a 0-d array")
119        return self._iter()
120
121    def _binary_op(
122        self,
123        other: _Compatible,
124        f: Callable[[Any, Any], Any],
125        reflexive: bool = False,
126    ) -> Self:
127        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
128            (other._data if isinstance(other, Tensor) else other),
129            f,
130            reflexive,
131        )
132        return self.__class__.from_xarray(data)
133
134    def _inplace_binary_op(
135        self,
136        other: _Compatible,
137        f: Callable[[Any, Any], Any],
138    ) -> Self:
139        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
140            (
141                other_d
142                if (other_d := getattr(other, "data")) is not None
143                and isinstance(
144                    other_d,
145                    xr.DataArray,
146                )
147                else other
148            ),
149            f,
150        )
151        return self
152
153    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
154        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
155            f, *args, **kwargs
156        )
157        return self.__class__.from_xarray(data)
158
159    @classmethod
160    def from_xarray(cls, data_array: xr.DataArray) -> Self:
161        """create a `Tensor` from an xarray data array
162
163        note for internal use: this factory method is round-trip save
164            for any `Tensor`'s  `data` property (an xarray.DataArray).
165        """
166        return cls(
167            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
168        )
169
170    @classmethod
171    def from_numpy(
172        cls,
173        array: NDArray[Any],
174        *,
175        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
176    ) -> Tensor:
177        """create a `Tensor` from a numpy array
178
179        Args:
180            array: the nd numpy array
181            axes: A description of the array's axes,
182                if None axes are guessed (which might fail and raise a ValueError.)
183
184        Raises:
185            ValueError: if `axes` is None and axes guessing fails.
186        """
187
188        if dims is None:
189            return cls._interprete_array_wo_known_axes(array)
190        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
191            dims = [dims]
192
193        axis_infos = [AxisInfo.create(a) for a in dims]
194        original_shape = tuple(array.shape)
195
196        successful_view = _get_array_view(array, axis_infos)
197        if successful_view is None:
198            raise ValueError(
199                f"Array shape {original_shape} does not map to axes {dims}"
200            )
201
202        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
203
204    @property
205    def data(self):
206        return self._data
207
208    @property
209    def dims(self):  # TODO: rename to `axes`?
210        """Tuple of dimension names associated with this tensor."""
211        return cast(Tuple[AxisId, ...], self._data.dims)
212
213    @property
214    def dtype(self) -> DTypeStr:
215        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
216        assert dt in get_args(DTypeStr)
217        return dt  # pyright: ignore[reportReturnType]
218
219    @property
220    def ndim(self):
221        """Number of tensor dimensions."""
222        return self._data.ndim
223
224    @property
225    def shape(self):
226        """Tuple of tensor axes lengths"""
227        return self._data.shape
228
229    @property
230    def shape_tuple(self):
231        """Tuple of tensor axes lengths"""
232        return self._data.shape
233
234    @property
235    def size(self):
236        """Number of elements in the tensor.
237
238        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
239        """
240        return self._data.size
241
242    @property
243    def sizes(self):
244        """Ordered, immutable mapping from axis ids to axis lengths."""
245        return cast(Mapping[AxisId, int], self.data.sizes)
246
247    @property
248    def tagged_shape(self):
249        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
250        return self.sizes
251
252    def argmax(self) -> Mapping[AxisId, int]:
253        ret = self._data.argmax(...)
254        assert isinstance(ret, dict)
255        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
256
257    def astype(self, dtype: DTypeStr, *, copy: bool = False):
258        """Return tensor cast to `dtype`
259
260        note: if dtype is already satisfied copy if `copy`"""
261        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
262
263    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
264        """Return a tensor whose values are limited to [min, max].
265        At least one of max or min must be given."""
266        return self.__class__.from_xarray(self._data.clip(min, max))
267
268    def crop_to(
269        self,
270        sizes: PerAxis[int],
271        crop_where: Union[
272            CropWhere,
273            PerAxis[CropWhere],
274        ] = "left_and_right",
275    ) -> Self:
276        """crop to match `sizes`"""
277        if isinstance(crop_where, str):
278            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
279        else:
280            crop_axis_where = crop_where
281
282        slices: Dict[AxisId, SliceInfo] = {}
283
284        for a, s_is in self.sizes.items():
285            if a not in sizes or sizes[a] == s_is:
286                pass
287            elif sizes[a] > s_is:
288                logger.warning(
289                    "Cannot crop axis {} of size {} to larger size {}",
290                    a,
291                    s_is,
292                    sizes[a],
293                )
294            elif a not in crop_axis_where:
295                raise ValueError(
296                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
297                )
298            else:
299                crop_this_axis_where = crop_axis_where[a]
300                if crop_this_axis_where == "left":
301                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
302                elif crop_this_axis_where == "right":
303                    slices[a] = SliceInfo(0, sizes[a])
304                elif crop_this_axis_where == "left_and_right":
305                    slices[a] = SliceInfo(
306                        start := (s_is - sizes[a]) // 2, sizes[a] + start
307                    )
308                else:
309                    assert_never(crop_this_axis_where)
310
311        return self[slices]
312
313    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
314        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
315
316    def item(
317        self,
318        key: Union[
319            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
320        ] = None,
321    ):
322        """Copy a tensor element to a standard Python scalar and return it."""
323        if key is None:
324            ret = self._data.item()
325        else:
326            ret = self[key]._data.item()
327
328        assert isinstance(ret, (bool, float, int))
329        return ret
330
331    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
332        return self.__class__.from_xarray(self._data.mean(dim=dim))
333
334    def pad(
335        self,
336        pad_width: PerAxis[PadWidthLike],
337        mode: PadMode = "symmetric",
338    ) -> Self:
339        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
340        return self.__class__.from_xarray(
341            self._data.pad(pad_width=pad_width, mode=mode)
342        )
343
344    def pad_to(
345        self,
346        sizes: PerAxis[int],
347        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
348        mode: PadMode = "symmetric",
349    ) -> Self:
350        """pad `tensor` to match `sizes`"""
351        if isinstance(pad_where, str):
352            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
353        else:
354            pad_axis_where = pad_where
355
356        pad_width: Dict[AxisId, PadWidth] = {}
357        for a, s_is in self.sizes.items():
358            if a not in sizes or sizes[a] == s_is:
359                pad_width[a] = PadWidth(0, 0)
360            elif s_is > sizes[a]:
361                pad_width[a] = PadWidth(0, 0)
362                logger.warning(
363                    "Cannot pad axis {} of size {} to smaller size {}",
364                    a,
365                    s_is,
366                    sizes[a],
367                )
368            elif a not in pad_axis_where:
369                raise ValueError(
370                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
371                )
372            else:
373                pad_this_axis_where = pad_axis_where[a]
374                d = sizes[a] - s_is
375                if pad_this_axis_where == "left":
376                    pad_width[a] = PadWidth(d, 0)
377                elif pad_this_axis_where == "right":
378                    pad_width[a] = PadWidth(0, d)
379                elif pad_this_axis_where == "left_and_right":
380                    pad_width[a] = PadWidth(left := d // 2, d - left)
381                else:
382                    assert_never(pad_this_axis_where)
383
384        return self.pad(pad_width, mode)
385
386    def quantile(
387        self,
388        q: Union[float, Sequence[float]],
389        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
390    ) -> Self:
391        assert (
392            isinstance(q, (float, int))
393            and q >= 0.0
394            or not isinstance(q, (float, int))
395            and all(qq >= 0.0 for qq in q)
396        )
397        assert (
398            isinstance(q, (float, int))
399            and q <= 1.0
400            or not isinstance(q, (float, int))
401            and all(qq <= 1.0 for qq in q)
402        )
403        assert dim is None or (
404            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
405        )
406        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
407
408    def resize_to(
409        self,
410        sizes: PerAxis[int],
411        *,
412        pad_where: Union[
413            PadWhere,
414            PerAxis[PadWhere],
415        ] = "left_and_right",
416        crop_where: Union[
417            CropWhere,
418            PerAxis[CropWhere],
419        ] = "left_and_right",
420        pad_mode: PadMode = "symmetric",
421    ):
422        """return cropped/padded tensor with `sizes`"""
423        crop_to_sizes: Dict[AxisId, int] = {}
424        pad_to_sizes: Dict[AxisId, int] = {}
425        new_axes = dict(sizes)
426        for a, s_is in self.sizes.items():
427            a = AxisId(str(a))
428            _ = new_axes.pop(a, None)
429            if a not in sizes or sizes[a] == s_is:
430                pass
431            elif s_is > sizes[a]:
432                crop_to_sizes[a] = sizes[a]
433            else:
434                pad_to_sizes[a] = sizes[a]
435
436        tensor = self
437        if crop_to_sizes:
438            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
439
440        if pad_to_sizes:
441            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
442
443        if new_axes:
444            tensor = tensor.expand_dims(new_axes)
445
446        return tensor
447
448    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
449        return self.__class__.from_xarray(self._data.std(dim=dim))
450
451    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
452        """Reduce this Tensor's data by applying sum along some dimension(s)."""
453        return self.__class__.from_xarray(self._data.sum(dim=dim))
454
455    def transpose(
456        self,
457        axes: Sequence[AxisId],
458    ) -> Self:
459        """return a transposed tensor
460
461        Args:
462            axes: the desired tensor axes
463        """
464        # expand missing tensor axes
465        missing_axes = tuple(a for a in axes if a not in self.dims)
466        array = self._data
467        if missing_axes:
468            array = array.expand_dims(missing_axes)
469
470        # transpose to the correct axis order
471        return self.__class__.from_xarray(array.transpose(*axes))
472
473    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
474        return self.__class__.from_xarray(self._data.var(dim=dim))
475
476    @classmethod
477    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
478        ndim = array.ndim
479        if ndim == 2:
480            current_axes = (
481                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
482                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
483            )
484        elif ndim == 3 and any(s <= 3 for s in array.shape):
485            current_axes = (
486                v0_5.ChannelAxis(
487                    channel_names=[
488                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
489                    ]
490                ),
491                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
492                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
493            )
494        elif ndim == 3:
495            current_axes = (
496                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
497                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
498                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
499            )
500        elif ndim == 4:
501            current_axes = (
502                v0_5.ChannelAxis(
503                    channel_names=[
504                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
505                    ]
506                ),
507                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
508                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
509                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
510            )
511        elif ndim == 5:
512            current_axes = (
513                v0_5.BatchAxis(),
514                v0_5.ChannelAxis(
515                    channel_names=[
516                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
517                    ]
518                ),
519                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
520                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
521                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
522            )
523        else:
524            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
525
526        return cls(array, dims=tuple(a.id for a in current_axes))
527
528
def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Adjust the number of dimensions of `arr` towards `len(axis_infos)`.

    If `arr` has more dimensions than `axis_infos` has entries, singleton
    (length 1) dimensions are removed, first ones first, until the numbers
    match or no singleton dimension is left.
    If `arr` has fewer dimensions, singleton dimensions are inserted at the
    positions of axes with `maybe_singleton` set until the numbers match or
    no such axis is left.

    Args:
        arr: the array to adjust
        axis_infos: the axes `arr` should be mapped to

    Returns:
        The adjusted array; it may still have a different number of
        dimensions than `axis_infos` if the request cannot be satisfied.
    """
    n_axes = len(axis_infos)
    excess = len(arr.shape) - n_axes
    if excess > 0:
        # Choose the axes to drop up front from the *original* shape; removing
        # them one by one would shift the positions of the remaining axes.
        drop = tuple(i for i, s in enumerate(arr.shape) if s == 1)[:excess]
        if drop:
            arr = np.squeeze(arr, axis=drop)

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if len(arr.shape) >= n_axes:
            break

        if a.maybe_singleton:
            # clamp the position so an unsatisfiable request cannot raise an
            # axis error; the caller detects the remaining mismatch
            arr = np.expand_dims(arr, min(i, len(arr.shape)))

    return arr
547
548
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find an axis order of `original_array` that is compatible with `axis_infos`.

    Every permutation of the array's dimensions is tried (the identity first,
    the fully reversed order second). After adjusting the number of dimensions
    with `_add_singletons`, a candidate is accepted if no dimension of length
    1 is matched with an axis that has `maybe_singleton` unset.

    Returns:
        The first accepted candidate, or None if the number of dimensions
        cannot be matched or no permutation is accepted.
    """
    n_axes = len(axis_infos)
    candidates = list(permutations(range(len(original_array.shape))))
    candidates.insert(1, candidates.pop())  # try A and A.T first

    for order in candidates:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != n_axes:
            return None

        conflict = any(
            size == 1 and not info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if not conflict:
            return candidate

    return None
class Tensor(bioimageio.core._magic_tensor_ops.MagicTensorOpsMixin):
 50class Tensor(MagicTensorOpsMixin):
 51    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 52    and improved type annotations."""
 53
 54    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 55
 56    def __init__(
 57        self,
 58        array: NDArray[Any],
 59        dims: Sequence[Union[AxisId, AxisLike]],
 60    ) -> None:
 61        super().__init__()
 62        axes = tuple(
 63            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 64        )
 65        self._data = xr.DataArray(array, dims=axes)
 66
 67    def __array__(self, dtype: DTypeLike = None):
 68        return np.asarray(self._data, dtype=dtype)
 69
 70    def __getitem__(
 71        self,
 72        key: Union[
 73            SliceInfo,
 74            slice,
 75            int,
 76            PerAxis[Union[SliceInfo, slice, int]],
 77            Tensor,
 78            xr.DataArray,
 79        ],
 80    ) -> Self:
 81        if isinstance(key, SliceInfo):
 82            key = slice(*key)
 83        elif isinstance(key, collections.abc.Mapping):
 84            key = {
 85                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 86                for a, s in key.items()
 87            }
 88        elif isinstance(key, Tensor):
 89            key = key._data
 90
 91        return self.__class__.from_xarray(self._data[key])
 92
 93    def __setitem__(
 94        self,
 95        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
 96        value: Union[Tensor, xr.DataArray, float, int],
 97    ) -> None:
 98        if isinstance(key, Tensor):
 99            key = key._data
100        elif isinstance(key, xr.DataArray):
101            pass
102        else:
103            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
104
105        if isinstance(value, Tensor):
106            value = value._data
107
108        self._data[key] = value
109
110    def __len__(self) -> int:
111        return len(self.data)
112
113    def _iter(self: Any) -> Iterator[Any]:
114        for n in range(len(self)):
115            yield self[n]
116
117    def __iter__(self: Any) -> Iterator[Any]:
118        if self.ndim == 0:
119            raise TypeError("iteration over a 0-d array")
120        return self._iter()
121
122    def _binary_op(
123        self,
124        other: _Compatible,
125        f: Callable[[Any, Any], Any],
126        reflexive: bool = False,
127    ) -> Self:
128        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
129            (other._data if isinstance(other, Tensor) else other),
130            f,
131            reflexive,
132        )
133        return self.__class__.from_xarray(data)
134
135    def _inplace_binary_op(
136        self,
137        other: _Compatible,
138        f: Callable[[Any, Any], Any],
139    ) -> Self:
140        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
141            (
142                other_d
143                if (other_d := getattr(other, "data")) is not None
144                and isinstance(
145                    other_d,
146                    xr.DataArray,
147                )
148                else other
149            ),
150            f,
151        )
152        return self
153
154    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
155        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
156            f, *args, **kwargs
157        )
158        return self.__class__.from_xarray(data)
159
160    @classmethod
161    def from_xarray(cls, data_array: xr.DataArray) -> Self:
162        """create a `Tensor` from an xarray data array
163
164        note for internal use: this factory method is round-trip save
165            for any `Tensor`'s  `data` property (an xarray.DataArray).
166        """
167        return cls(
168            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
169        )
170
171    @classmethod
172    def from_numpy(
173        cls,
174        array: NDArray[Any],
175        *,
176        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
177    ) -> Tensor:
178        """create a `Tensor` from a numpy array
179
180        Args:
181            array: the nd numpy array
182            axes: A description of the array's axes,
183                if None axes are guessed (which might fail and raise a ValueError.)
184
185        Raises:
186            ValueError: if `axes` is None and axes guessing fails.
187        """
188
189        if dims is None:
190            return cls._interprete_array_wo_known_axes(array)
191        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
192            dims = [dims]
193
194        axis_infos = [AxisInfo.create(a) for a in dims]
195        original_shape = tuple(array.shape)
196
197        successful_view = _get_array_view(array, axis_infos)
198        if successful_view is None:
199            raise ValueError(
200                f"Array shape {original_shape} does not map to axes {dims}"
201            )
202
203        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
204
205    @property
206    def data(self):
207        return self._data
208
209    @property
210    def dims(self):  # TODO: rename to `axes`?
211        """Tuple of dimension names associated with this tensor."""
212        return cast(Tuple[AxisId, ...], self._data.dims)
213
214    @property
215    def dtype(self) -> DTypeStr:
216        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
217        assert dt in get_args(DTypeStr)
218        return dt  # pyright: ignore[reportReturnType]
219
220    @property
221    def ndim(self):
222        """Number of tensor dimensions."""
223        return self._data.ndim
224
225    @property
226    def shape(self):
227        """Tuple of tensor axes lengths"""
228        return self._data.shape
229
230    @property
231    def shape_tuple(self):
232        """Tuple of tensor axes lengths"""
233        return self._data.shape
234
235    @property
236    def size(self):
237        """Number of elements in the tensor.
238
239        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
240        """
241        return self._data.size
242
243    @property
244    def sizes(self):
245        """Ordered, immutable mapping from axis ids to axis lengths."""
246        return cast(Mapping[AxisId, int], self.data.sizes)
247
248    @property
249    def tagged_shape(self):
250        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
251        return self.sizes
252
253    def argmax(self) -> Mapping[AxisId, int]:
254        ret = self._data.argmax(...)
255        assert isinstance(ret, dict)
256        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
257
258    def astype(self, dtype: DTypeStr, *, copy: bool = False):
259        """Return tensor cast to `dtype`
260
261        note: if dtype is already satisfied copy if `copy`"""
262        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
263
264    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
265        """Return a tensor whose values are limited to [min, max].
266        At least one of max or min must be given."""
267        return self.__class__.from_xarray(self._data.clip(min, max))
268
269    def crop_to(
270        self,
271        sizes: PerAxis[int],
272        crop_where: Union[
273            CropWhere,
274            PerAxis[CropWhere],
275        ] = "left_and_right",
276    ) -> Self:
277        """crop to match `sizes`"""
278        if isinstance(crop_where, str):
279            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
280        else:
281            crop_axis_where = crop_where
282
283        slices: Dict[AxisId, SliceInfo] = {}
284
285        for a, s_is in self.sizes.items():
286            if a not in sizes or sizes[a] == s_is:
287                pass
288            elif sizes[a] > s_is:
289                logger.warning(
290                    "Cannot crop axis {} of size {} to larger size {}",
291                    a,
292                    s_is,
293                    sizes[a],
294                )
295            elif a not in crop_axis_where:
296                raise ValueError(
297                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
298                )
299            else:
300                crop_this_axis_where = crop_axis_where[a]
301                if crop_this_axis_where == "left":
302                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
303                elif crop_this_axis_where == "right":
304                    slices[a] = SliceInfo(0, sizes[a])
305                elif crop_this_axis_where == "left_and_right":
306                    slices[a] = SliceInfo(
307                        start := (s_is - sizes[a]) // 2, sizes[a] + start
308                    )
309                else:
310                    assert_never(crop_this_axis_where)
311
312        return self[slices]
313
314    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
315        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
316
317    def item(
318        self,
319        key: Union[
320            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
321        ] = None,
322    ):
323        """Copy a tensor element to a standard Python scalar and return it."""
324        if key is None:
325            ret = self._data.item()
326        else:
327            ret = self[key]._data.item()
328
329        assert isinstance(ret, (bool, float, int))
330        return ret
331
332    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
333        return self.__class__.from_xarray(self._data.mean(dim=dim))
334
335    def pad(
336        self,
337        pad_width: PerAxis[PadWidthLike],
338        mode: PadMode = "symmetric",
339    ) -> Self:
340        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
341        return self.__class__.from_xarray(
342            self._data.pad(pad_width=pad_width, mode=mode)
343        )
344
345    def pad_to(
346        self,
347        sizes: PerAxis[int],
348        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
349        mode: PadMode = "symmetric",
350    ) -> Self:
351        """pad `tensor` to match `sizes`"""
352        if isinstance(pad_where, str):
353            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
354        else:
355            pad_axis_where = pad_where
356
357        pad_width: Dict[AxisId, PadWidth] = {}
358        for a, s_is in self.sizes.items():
359            if a not in sizes or sizes[a] == s_is:
360                pad_width[a] = PadWidth(0, 0)
361            elif s_is > sizes[a]:
362                pad_width[a] = PadWidth(0, 0)
363                logger.warning(
364                    "Cannot pad axis {} of size {} to smaller size {}",
365                    a,
366                    s_is,
367                    sizes[a],
368                )
369            elif a not in pad_axis_where:
370                raise ValueError(
371                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
372                )
373            else:
374                pad_this_axis_where = pad_axis_where[a]
375                d = sizes[a] - s_is
376                if pad_this_axis_where == "left":
377                    pad_width[a] = PadWidth(d, 0)
378                elif pad_this_axis_where == "right":
379                    pad_width[a] = PadWidth(0, d)
380                elif pad_this_axis_where == "left_and_right":
381                    pad_width[a] = PadWidth(left := d // 2, d - left)
382                else:
383                    assert_never(pad_this_axis_where)
384
385        return self.pad(pad_width, mode)
386
387    def quantile(
388        self,
389        q: Union[float, Sequence[float]],
390        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
391    ) -> Self:
392        assert (
393            isinstance(q, (float, int))
394            and q >= 0.0
395            or not isinstance(q, (float, int))
396            and all(qq >= 0.0 for qq in q)
397        )
398        assert (
399            isinstance(q, (float, int))
400            and q <= 1.0
401            or not isinstance(q, (float, int))
402            and all(qq <= 1.0 for qq in q)
403        )
404        assert dim is None or (
405            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
406        )
407        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
408
409    def resize_to(
410        self,
411        sizes: PerAxis[int],
412        *,
413        pad_where: Union[
414            PadWhere,
415            PerAxis[PadWhere],
416        ] = "left_and_right",
417        crop_where: Union[
418            CropWhere,
419            PerAxis[CropWhere],
420        ] = "left_and_right",
421        pad_mode: PadMode = "symmetric",
422    ):
423        """return cropped/padded tensor with `sizes`"""
424        crop_to_sizes: Dict[AxisId, int] = {}
425        pad_to_sizes: Dict[AxisId, int] = {}
426        new_axes = dict(sizes)
427        for a, s_is in self.sizes.items():
428            a = AxisId(str(a))
429            _ = new_axes.pop(a, None)
430            if a not in sizes or sizes[a] == s_is:
431                pass
432            elif s_is > sizes[a]:
433                crop_to_sizes[a] = sizes[a]
434            else:
435                pad_to_sizes[a] = sizes[a]
436
437        tensor = self
438        if crop_to_sizes:
439            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
440
441        if pad_to_sizes:
442            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
443
444        if new_axes:
445            tensor = tensor.expand_dims(new_axes)
446
447        return tensor
448
449    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
450        return self.__class__.from_xarray(self._data.std(dim=dim))
451
452    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
453        """Reduce this Tensor's data by applying sum along some dimension(s)."""
454        return self.__class__.from_xarray(self._data.sum(dim=dim))
455
456    def transpose(
457        self,
458        axes: Sequence[AxisId],
459    ) -> Self:
460        """return a transposed tensor
461
462        Args:
463            axes: the desired tensor axes
464        """
465        # expand missing tensor axes
466        missing_axes = tuple(a for a in axes if a not in self.dims)
467        array = self._data
468        if missing_axes:
469            array = array.expand_dims(missing_axes)
470
471        # transpose to the correct axis order
472        return self.__class__.from_xarray(array.transpose(*axes))
473
474    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
475        return self.__class__.from_xarray(self._data.var(dim=dim))
476
477    @classmethod
478    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
479        ndim = array.ndim
480        if ndim == 2:
481            current_axes = (
482                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
483                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
484            )
485        elif ndim == 3 and any(s <= 3 for s in array.shape):
486            current_axes = (
487                v0_5.ChannelAxis(
488                    channel_names=[
489                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
490                    ]
491                ),
492                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
493                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
494            )
495        elif ndim == 3:
496            current_axes = (
497                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
498                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
499                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
500            )
501        elif ndim == 4:
502            current_axes = (
503                v0_5.ChannelAxis(
504                    channel_names=[
505                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
506                    ]
507                ),
508                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
509                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
510                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
511            )
512        elif ndim == 5:
513            current_axes = (
514                v0_5.BatchAxis(),
515                v0_5.ChannelAxis(
516                    channel_names=[
517                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
518                    ]
519                ),
520                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
521                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
522                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
523            )
524        else:
525            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
526
527        return cls(array, dims=tuple(a.id for a in current_axes))

A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.

Tensor( array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], dims: Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]])
56    def __init__(
57        self,
58        array: NDArray[Any],
59        dims: Sequence[Union[AxisId, AxisLike]],
60    ) -> None:
61        super().__init__()
62        axes = tuple(
63            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
64        )
65        self._data = xr.DataArray(array, dims=axes)
@classmethod
def from_xarray(cls, data_array: xarray.core.dataarray.DataArray) -> Self:
160    @classmethod
161    def from_xarray(cls, data_array: xr.DataArray) -> Self:
162        """create a `Tensor` from an xarray data array
163
164        note for internal use: this factory method is round-trip save
165            for any `Tensor`'s  `data` property (an xarray.DataArray).
166        """
167        return cls(
168            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
169        )

create a Tensor from an xarray data array

Note for internal use: this factory method is round-trip safe for any Tensor's data property (an xarray.DataArray).

@classmethod
def from_numpy( cls, array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], *, dims: Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, 
bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]], NoneType]) -> Tensor:
171    @classmethod
172    def from_numpy(
173        cls,
174        array: NDArray[Any],
175        *,
176        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
177    ) -> Tensor:
178        """create a `Tensor` from a numpy array
179
180        Args:
181            array: the nd numpy array
182            axes: A description of the array's axes,
183                if None axes are guessed (which might fail and raise a ValueError.)
184
185        Raises:
186            ValueError: if `axes` is None and axes guessing fails.
187        """
188
189        if dims is None:
190            return cls._interprete_array_wo_known_axes(array)
191        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
192            dims = [dims]
193
194        axis_infos = [AxisInfo.create(a) for a in dims]
195        original_shape = tuple(array.shape)
196
197        successful_view = _get_array_view(array, axis_infos)
198        if successful_view is None:
199            raise ValueError(
200                f"Array shape {original_shape} does not map to axes {dims}"
201            )
202
203        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

create a Tensor from a numpy array

Arguments:
  • array: the nd numpy array
  • dims: A description of the array's axes; if None, axes are guessed (which might fail and raise a ValueError).
Raises:
  • ValueError: if dims is None and axes guessing fails.
data
205    @property
206    def data(self):
207        return self._data
dims
209    @property
210    def dims(self):  # TODO: rename to `axes`?
211        """Tuple of dimension names associated with this tensor."""
212        return cast(Tuple[AxisId, ...], self._data.dims)

Tuple of dimension names associated with this tensor.

dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
214    @property
215    def dtype(self) -> DTypeStr:
216        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
217        assert dt in get_args(DTypeStr)
218        return dt  # pyright: ignore[reportReturnType]
ndim
220    @property
221    def ndim(self):
222        """Number of tensor dimensions."""
223        return self._data.ndim

Number of tensor dimensions.

shape
225    @property
226    def shape(self):
227        """Tuple of tensor axes lengths"""
228        return self._data.shape

Tuple of tensor axes lengths

shape_tuple
230    @property
231    def shape_tuple(self):
232        """Tuple of tensor axes lengths"""
233        return self._data.shape

Tuple of tensor axes lengths

size
235    @property
236    def size(self):
237        """Number of elements in the tensor.
238
239        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
240        """
241        return self._data.size

Number of elements in the tensor.

Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.

sizes
243    @property
244    def sizes(self):
245        """Ordered, immutable mapping from axis ids to axis lengths."""
246        return cast(Mapping[AxisId, int], self.data.sizes)

Ordered, immutable mapping from axis ids to axis lengths.

tagged_shape
248    @property
249    def tagged_shape(self):
250        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
251        return self.sizes

(alias for sizes) Ordered, immutable mapping from axis ids to lengths.

def argmax(self) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
253    def argmax(self) -> Mapping[AxisId, int]:
254        ret = self._data.argmax(...)
255        assert isinstance(ret, dict)
256        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
def astype( self, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'], *, copy: bool = False):
258    def astype(self, dtype: DTypeStr, *, copy: bool = False):
259        """Return tensor cast to `dtype`
260
261        note: if dtype is already satisfied copy if `copy`"""
262        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

Return tensor cast to dtype

Note: if the dtype is already satisfied, the data is copied only if copy is true.

def clip(self, min: Optional[float] = None, max: Optional[float] = None):
264    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
265        """Return a tensor whose values are limited to [min, max].
266        At least one of max or min must be given."""
267        return self.__class__.from_xarray(self._data.clip(min, max))

Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.

def crop_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right') -> Self:
269    def crop_to(
270        self,
271        sizes: PerAxis[int],
272        crop_where: Union[
273            CropWhere,
274            PerAxis[CropWhere],
275        ] = "left_and_right",
276    ) -> Self:
277        """crop to match `sizes`"""
278        if isinstance(crop_where, str):
279            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
280        else:
281            crop_axis_where = crop_where
282
283        slices: Dict[AxisId, SliceInfo] = {}
284
285        for a, s_is in self.sizes.items():
286            if a not in sizes or sizes[a] == s_is:
287                pass
288            elif sizes[a] > s_is:
289                logger.warning(
290                    "Cannot crop axis {} of size {} to larger size {}",
291                    a,
292                    s_is,
293                    sizes[a],
294                )
295            elif a not in crop_axis_where:
296                raise ValueError(
297                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
298                )
299            else:
300                crop_this_axis_where = crop_axis_where[a]
301                if crop_this_axis_where == "left":
302                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
303                elif crop_this_axis_where == "right":
304                    slices[a] = SliceInfo(0, sizes[a])
305                elif crop_this_axis_where == "left_and_right":
306                    slices[a] = SliceInfo(
307                        start := (s_is - sizes[a]) // 2, sizes[a] + start
308                    )
309                else:
310                    assert_never(crop_this_axis_where)
311
312        return self[slices]

crop to match sizes

def expand_dims( self, dims: Union[Sequence[bioimageio.spec.model.v0_5.AxisId], Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Self:
314    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
315        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
def item( self, key: Union[NoneType, bioimageio.core.common.SliceInfo, slice, int, Mapping[bioimageio.spec.model.v0_5.AxisId, Union[bioimageio.core.common.SliceInfo, slice, int]]] = None):
317    def item(
318        self,
319        key: Union[
320            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
321        ] = None,
322    ):
323        """Copy a tensor element to a standard Python scalar and return it."""
324        if key is None:
325            ret = self._data.item()
326        else:
327            ret = self[key]._data.item()
328
329        assert isinstance(ret, (bool, float, int))
330        return ret

Copy a tensor element to a standard Python scalar and return it.

def mean( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
332    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
333        return self.__class__.from_xarray(self._data.mean(dim=dim))
def pad( self, pad_width: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.PadWidth]], mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
335    def pad(
336        self,
337        pad_width: PerAxis[PadWidthLike],
338        mode: PadMode = "symmetric",
339    ) -> Self:
340        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
341        return self.__class__.from_xarray(
342            self._data.pad(pad_width=pad_width, mode=mode)
343        )
def pad_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
345    def pad_to(
346        self,
347        sizes: PerAxis[int],
348        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
349        mode: PadMode = "symmetric",
350    ) -> Self:
351        """pad `tensor` to match `sizes`"""
352        if isinstance(pad_where, str):
353            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
354        else:
355            pad_axis_where = pad_where
356
357        pad_width: Dict[AxisId, PadWidth] = {}
358        for a, s_is in self.sizes.items():
359            if a not in sizes or sizes[a] == s_is:
360                pad_width[a] = PadWidth(0, 0)
361            elif s_is > sizes[a]:
362                pad_width[a] = PadWidth(0, 0)
363                logger.warning(
364                    "Cannot pad axis {} of size {} to smaller size {}",
365                    a,
366                    s_is,
367                    sizes[a],
368                )
369            elif a not in pad_axis_where:
370                raise ValueError(
371                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
372                )
373            else:
374                pad_this_axis_where = pad_axis_where[a]
375                d = sizes[a] - s_is
376                if pad_this_axis_where == "left":
377                    pad_width[a] = PadWidth(d, 0)
378                elif pad_this_axis_where == "right":
379                    pad_width[a] = PadWidth(0, d)
380                elif pad_this_axis_where == "left_and_right":
381                    pad_width[a] = PadWidth(left := d // 2, d - left)
382                else:
383                    assert_never(pad_this_axis_where)
384
385        return self.pad(pad_width, mode)

pad tensor to match sizes

def quantile( self, q: Union[float, Sequence[float]], dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
387    def quantile(
388        self,
389        q: Union[float, Sequence[float]],
390        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
391    ) -> Self:
392        assert (
393            isinstance(q, (float, int))
394            and q >= 0.0
395            or not isinstance(q, (float, int))
396            and all(qq >= 0.0 for qq in q)
397        )
398        assert (
399            isinstance(q, (float, int))
400            and q <= 1.0
401            or not isinstance(q, (float, int))
402            and all(qq <= 1.0 for qq in q)
403        )
404        assert dim is None or (
405            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
406        )
407        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], *, pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', pad_mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric'):
409    def resize_to(
410        self,
411        sizes: PerAxis[int],
412        *,
413        pad_where: Union[
414            PadWhere,
415            PerAxis[PadWhere],
416        ] = "left_and_right",
417        crop_where: Union[
418            CropWhere,
419            PerAxis[CropWhere],
420        ] = "left_and_right",
421        pad_mode: PadMode = "symmetric",
422    ):
423        """return cropped/padded tensor with `sizes`"""
424        crop_to_sizes: Dict[AxisId, int] = {}
425        pad_to_sizes: Dict[AxisId, int] = {}
426        new_axes = dict(sizes)
427        for a, s_is in self.sizes.items():
428            a = AxisId(str(a))
429            _ = new_axes.pop(a, None)
430            if a not in sizes or sizes[a] == s_is:
431                pass
432            elif s_is > sizes[a]:
433                crop_to_sizes[a] = sizes[a]
434            else:
435                pad_to_sizes[a] = sizes[a]
436
437        tensor = self
438        if crop_to_sizes:
439            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
440
441        if pad_to_sizes:
442            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
443
444        if new_axes:
445            tensor = tensor.expand_dims(new_axes)
446
447        return tensor

return cropped/padded tensor with sizes

def std( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
449    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
450        return self.__class__.from_xarray(self._data.std(dim=dim))
def sum( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
452    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
453        """Reduce this Tensor's data by applying sum along some dimension(s)."""
454        return self.__class__.from_xarray(self._data.sum(dim=dim))

Reduce this Tensor's data by applying sum along some dimension(s).

def transpose(self, axes: Sequence[bioimageio.spec.model.v0_5.AxisId]) -> Self:
456    def transpose(
457        self,
458        axes: Sequence[AxisId],
459    ) -> Self:
460        """return a transposed tensor
461
462        Args:
463            axes: the desired tensor axes
464        """
465        # expand missing tensor axes
466        missing_axes = tuple(a for a in axes if a not in self.dims)
467        array = self._data
468        if missing_axes:
469            array = array.expand_dims(missing_axes)
470
471        # transpose to the correct axis order
472        return self.__class__.from_xarray(array.transpose(*axes))

return a transposed tensor

Arguments:
  • axes: the desired tensor axes
def var( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
474    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
475        return self.__class__.from_xarray(self._data.var(dim=dim))