bioimageio.core.tensor

  1from __future__ import annotations
  2
  3import collections.abc
  4from itertools import permutations
  5from typing import (
  6    TYPE_CHECKING,
  7    Any,
  8    Callable,
  9    Dict,
 10    Iterator,
 11    Mapping,
 12    Optional,
 13    Sequence,
 14    Tuple,
 15    Union,
 16    cast,
 17    get_args,
 18)
 19
 20import numpy as np
 21import xarray as xr
 22from bioimageio.spec.model import v0_5
 23from loguru import logger
 24from numpy.typing import DTypeLike, NDArray
 25from typing_extensions import Self, assert_never
 26
 27from ._magic_tensor_ops import MagicTensorOpsMixin
 28from .axis import AxisId, AxisInfo, AxisLike, PerAxis
 29from .common import (
 30    CropWhere,
 31    DTypeStr,
 32    PadMode,
 33    PadWhere,
 34    PadWidth,
 35    PadWidthLike,
 36    SliceInfo,
 37)
 38
 39if TYPE_CHECKING:
 40    from numpy.typing import ArrayLike, NDArray
 41
 42
# Non-`Tensor` operand types accepted by `Tensor`'s operator methods (used in
# `Tensor._Compatible`). "ArrayLike" must be a string (forward reference) because
# it is only imported under `TYPE_CHECKING`.
_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"
 44
 45
 46# TODO: complete docstrings
 47# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
 48class Tensor(MagicTensorOpsMixin):
 49    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 50    and improved type annotations."""
 51
 52    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 53
 54    def __init__(
 55        self,
 56        array: NDArray[Any],
 57        dims: Sequence[Union[AxisId, AxisLike]],
 58    ) -> None:
 59        super().__init__()
 60        axes = tuple(
 61            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 62        )
 63        self._data = xr.DataArray(array, dims=axes)
 64
 65    def __array__(self, dtype: DTypeLike = None):
 66        return np.asarray(self._data, dtype=dtype)
 67
 68    def __getitem__(
 69        self,
 70        key: Union[
 71            SliceInfo,
 72            slice,
 73            int,
 74            PerAxis[Union[SliceInfo, slice, int]],
 75            Tensor,
 76            xr.DataArray,
 77        ],
 78    ) -> Self:
 79        if isinstance(key, SliceInfo):
 80            key = slice(*key)
 81        elif isinstance(key, collections.abc.Mapping):
 82            key = {
 83                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 84                for a, s in key.items()
 85            }
 86        elif isinstance(key, Tensor):
 87            key = key._data
 88
 89        return self.__class__.from_xarray(self._data[key])
 90
 91    def __setitem__(
 92        self,
 93        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
 94        value: Union[Tensor, xr.DataArray, float, int],
 95    ) -> None:
 96        if isinstance(key, Tensor):
 97            key = key._data
 98        elif isinstance(key, xr.DataArray):
 99            pass
100        else:
101            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
102
103        if isinstance(value, Tensor):
104            value = value._data
105
106        self._data[key] = value
107
108    def __len__(self) -> int:
109        return len(self.data)
110
111    def _iter(self: Any) -> Iterator[Any]:
112        for n in range(len(self)):
113            yield self[n]
114
115    def __iter__(self: Any) -> Iterator[Any]:
116        if self.ndim == 0:
117            raise TypeError("iteration over a 0-d array")
118        return self._iter()
119
120    def _binary_op(
121        self,
122        other: _Compatible,
123        f: Callable[[Any, Any], Any],
124        reflexive: bool = False,
125    ) -> Self:
126        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
127            (other._data if isinstance(other, Tensor) else other),
128            f,
129            reflexive,
130        )
131        return self.__class__.from_xarray(data)
132
133    def _inplace_binary_op(
134        self,
135        other: _Compatible,
136        f: Callable[[Any, Any], Any],
137    ) -> Self:
138        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
139            (
140                other_d
141                if (other_d := getattr(other, "data")) is not None
142                and isinstance(
143                    other_d,
144                    xr.DataArray,
145                )
146                else other
147            ),
148            f,
149        )
150        return self
151
152    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
153        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
154            f, *args, **kwargs
155        )
156        return self.__class__.from_xarray(data)
157
158    @classmethod
159    def from_xarray(cls, data_array: xr.DataArray) -> Self:
160        """create a `Tensor` from an xarray data array
161
162        note for internal use: this factory method is round-trip save
163            for any `Tensor`'s  `data` property (an xarray.DataArray).
164        """
165        return cls(
166            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
167        )
168
169    @classmethod
170    def from_numpy(
171        cls,
172        array: NDArray[Any],
173        *,
174        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
175    ) -> Tensor:
176        """create a `Tensor` from a numpy array
177
178        Args:
179            array: the nd numpy array
180            axes: A description of the array's axes,
181                if None axes are guessed (which might fail and raise a ValueError.)
182
183        Raises:
184            ValueError: if `axes` is None and axes guessing fails.
185        """
186
187        if dims is None:
188            return cls._interprete_array_wo_known_axes(array)
189        elif isinstance(dims, collections.abc.Sequence):
190            dim_seq = list(dims)
191        else:
192            dim_seq = [dims]
193
194        axis_infos = [AxisInfo.create(a) for a in dim_seq]
195        original_shape = tuple(array.shape)
196
197        successful_view = _get_array_view(array, axis_infos)
198        if successful_view is None:
199            raise ValueError(
200                f"Array shape {original_shape} does not map to axes {dims}"
201            )
202
203        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
204
205    @property
206    def data(self):
207        return self._data
208
209    @property
210    def dims(self):  # TODO: rename to `axes`?
211        """Tuple of dimension names associated with this tensor."""
212        return cast(Tuple[AxisId, ...], self._data.dims)
213
214    @property
215    def dtype(self) -> DTypeStr:
216        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
217        assert dt in get_args(DTypeStr)
218        return dt  # pyright: ignore[reportReturnType]
219
220    @property
221    def ndim(self):
222        """Number of tensor dimensions."""
223        return self._data.ndim
224
225    @property
226    def shape(self):
227        """Tuple of tensor axes lengths"""
228        return self._data.shape
229
230    @property
231    def shape_tuple(self):
232        """Tuple of tensor axes lengths"""
233        return self._data.shape
234
235    @property
236    def size(self):
237        """Number of elements in the tensor.
238
239        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
240        """
241        return self._data.size
242
243    @property
244    def sizes(self):
245        """Ordered, immutable mapping from axis ids to axis lengths."""
246        return cast(Mapping[AxisId, int], self.data.sizes)
247
248    @property
249    def tagged_shape(self):
250        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
251        return self.sizes
252
253    def argmax(self) -> Mapping[AxisId, int]:
254        ret = self._data.argmax(...)
255        assert isinstance(ret, dict)
256        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
257
258    def astype(self, dtype: DTypeStr, *, copy: bool = False):
259        """Return tensor cast to `dtype`
260
261        note: if dtype is already satisfied copy if `copy`"""
262        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
263
264    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
265        """Return a tensor whose values are limited to [min, max].
266        At least one of max or min must be given."""
267        return self.__class__.from_xarray(self._data.clip(min, max))
268
269    def crop_to(
270        self,
271        sizes: PerAxis[int],
272        crop_where: Union[
273            CropWhere,
274            PerAxis[CropWhere],
275        ] = "left_and_right",
276    ) -> Self:
277        """crop to match `sizes`"""
278        if isinstance(crop_where, str):
279            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
280        else:
281            crop_axis_where = crop_where
282
283        slices: Dict[AxisId, SliceInfo] = {}
284
285        for a, s_is in self.sizes.items():
286            if a not in sizes or sizes[a] == s_is:
287                pass
288            elif sizes[a] > s_is:
289                logger.warning(
290                    "Cannot crop axis {} of size {} to larger size {}",
291                    a,
292                    s_is,
293                    sizes[a],
294                )
295            elif a not in crop_axis_where:
296                raise ValueError(
297                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
298                )
299            else:
300                crop_this_axis_where = crop_axis_where[a]
301                if crop_this_axis_where == "left":
302                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
303                elif crop_this_axis_where == "right":
304                    slices[a] = SliceInfo(0, sizes[a])
305                elif crop_this_axis_where == "left_and_right":
306                    slices[a] = SliceInfo(
307                        start := (s_is - sizes[a]) // 2, sizes[a] + start
308                    )
309                else:
310                    assert_never(crop_this_axis_where)
311
312        return self[slices]
313
314    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
315        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
316
317    def item(
318        self,
319        key: Union[
320            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
321        ] = None,
322    ):
323        """Copy a tensor element to a standard Python scalar and return it."""
324        if key is None:
325            ret = self._data.item()
326        else:
327            ret = self[key]._data.item()
328
329        assert isinstance(ret, (bool, float, int))
330        return ret
331
332    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
333        return self.__class__.from_xarray(self._data.mean(dim=dim))
334
335    def pad(
336        self,
337        pad_width: PerAxis[PadWidthLike],
338        mode: PadMode = "symmetric",
339    ) -> Self:
340        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
341        return self.__class__.from_xarray(
342            self._data.pad(pad_width=pad_width, mode=mode)
343        )
344
345    def pad_to(
346        self,
347        sizes: PerAxis[int],
348        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
349        mode: PadMode = "symmetric",
350    ) -> Self:
351        """pad `tensor` to match `sizes`"""
352        if isinstance(pad_where, str):
353            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
354        else:
355            pad_axis_where = pad_where
356
357        pad_width: Dict[AxisId, PadWidth] = {}
358        for a, s_is in self.sizes.items():
359            if a not in sizes or sizes[a] == s_is:
360                pad_width[a] = PadWidth(0, 0)
361            elif s_is > sizes[a]:
362                pad_width[a] = PadWidth(0, 0)
363                logger.warning(
364                    "Cannot pad axis {} of size {} to smaller size {}",
365                    a,
366                    s_is,
367                    sizes[a],
368                )
369            elif a not in pad_axis_where:
370                raise ValueError(
371                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
372                )
373            else:
374                pad_this_axis_where = pad_axis_where[a]
375                d = sizes[a] - s_is
376                if pad_this_axis_where == "left":
377                    pad_width[a] = PadWidth(d, 0)
378                elif pad_this_axis_where == "right":
379                    pad_width[a] = PadWidth(0, d)
380                elif pad_this_axis_where == "left_and_right":
381                    pad_width[a] = PadWidth(left := d // 2, d - left)
382                else:
383                    assert_never(pad_this_axis_where)
384
385        return self.pad(pad_width, mode)
386
387    def quantile(
388        self,
389        q: Union[float, Sequence[float]],
390        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
391    ) -> Self:
392        assert (
393            isinstance(q, (float, int))
394            and q >= 0.0
395            or not isinstance(q, (float, int))
396            and all(qq >= 0.0 for qq in q)
397        )
398        assert (
399            isinstance(q, (float, int))
400            and q <= 1.0
401            or not isinstance(q, (float, int))
402            and all(qq <= 1.0 for qq in q)
403        )
404        assert dim is None or (
405            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
406        )
407        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
408
409    def resize_to(
410        self,
411        sizes: PerAxis[int],
412        *,
413        pad_where: Union[
414            PadWhere,
415            PerAxis[PadWhere],
416        ] = "left_and_right",
417        crop_where: Union[
418            CropWhere,
419            PerAxis[CropWhere],
420        ] = "left_and_right",
421        pad_mode: PadMode = "symmetric",
422    ):
423        """return cropped/padded tensor with `sizes`"""
424        crop_to_sizes: Dict[AxisId, int] = {}
425        pad_to_sizes: Dict[AxisId, int] = {}
426        new_axes = dict(sizes)
427        for a, s_is in self.sizes.items():
428            a = AxisId(str(a))
429            _ = new_axes.pop(a, None)
430            if a not in sizes or sizes[a] == s_is:
431                pass
432            elif s_is > sizes[a]:
433                crop_to_sizes[a] = sizes[a]
434            else:
435                pad_to_sizes[a] = sizes[a]
436
437        tensor = self
438        if crop_to_sizes:
439            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
440
441        if pad_to_sizes:
442            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
443
444        if new_axes:
445            tensor = tensor.expand_dims(new_axes)
446
447        return tensor
448
449    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
450        return self.__class__.from_xarray(self._data.std(dim=dim))
451
452    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
453        """Reduce this Tensor's data by applying sum along some dimension(s)."""
454        return self.__class__.from_xarray(self._data.sum(dim=dim))
455
456    def transpose(
457        self,
458        axes: Sequence[AxisId],
459    ) -> Self:
460        """return a transposed tensor
461
462        Args:
463            axes: the desired tensor axes
464        """
465        # expand missing tensor axes
466        missing_axes = tuple(a for a in axes if a not in self.dims)
467        array = self._data
468        if missing_axes:
469            array = array.expand_dims(missing_axes)
470
471        # transpose to the correct axis order
472        return self.__class__.from_xarray(array.transpose(*axes))
473
474    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
475        return self.__class__.from_xarray(self._data.var(dim=dim))
476
477    @classmethod
478    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
479        ndim = array.ndim
480        if ndim == 2:
481            current_axes = (
482                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
483                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
484            )
485        elif ndim == 3 and any(s <= 3 for s in array.shape):
486            current_axes = (
487                v0_5.ChannelAxis(
488                    channel_names=[
489                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
490                    ]
491                ),
492                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
493                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
494            )
495        elif ndim == 3:
496            current_axes = (
497                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
498                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
499                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
500            )
501        elif ndim == 4:
502            current_axes = (
503                v0_5.ChannelAxis(
504                    channel_names=[
505                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
506                    ]
507                ),
508                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
509                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
510                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
511            )
512        elif ndim == 5:
513            current_axes = (
514                v0_5.BatchAxis(),
515                v0_5.ChannelAxis(
516                    channel_names=[
517                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
518                    ]
519                ),
520                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
521                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
522                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
523            )
524        else:
525            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
526
527        return cls(array, dims=tuple(a.id for a in current_axes))
528
529
530def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
531    if len(arr.shape) > len(axis_infos):
532        # remove singletons
533        for i, s in enumerate(arr.shape):
534            if s == 1:
535                arr = np.take(arr, 0, axis=i)
536                if len(arr.shape) == len(axis_infos):
537                    break
538
539    # add singletons if nececsary
540    for i, a in enumerate(axis_infos):
541        if len(arr.shape) >= len(axis_infos):
542            break
543
544        if a.maybe_singleton:
545            arr = np.expand_dims(arr, i)
546
547    return arr
548
549
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Search for a transposed, singleton-adjusted version of `original_array`
    whose dimensions line up with `axis_infos`.

    Axis orders are tried starting with the identity and the full reversal (`A.T`).
    A candidate fits when none of its size-1 dimensions sits at an axis that is
    not `maybe_singleton`.

    Returns:
        the first fitting candidate, or None if the number of dimensions cannot
        be matched or no axis order fits.
    """
    axis_orders = list(permutations(range(len(original_array.shape))))
    # the last permutation is the full reversal; move it right after the identity
    axis_orders.insert(1, axis_orders.pop())

    for order in axis_orders:
        candidate = _add_singletons(original_array.transpose(order), axis_infos)
        if len(candidate.shape) != len(axis_infos):
            return None

        fits = all(
            size != 1 or info.maybe_singleton
            for size, info in zip(candidate.shape, axis_infos)
        )
        if fits:
            return candidate

    return None
class Tensor(bioimageio.core._magic_tensor_ops.MagicTensorOpsMixin):
 49class Tensor(MagicTensorOpsMixin):
 50    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 51    and improved type annotations."""
 52
 53    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 54
 55    def __init__(
 56        self,
 57        array: NDArray[Any],
 58        dims: Sequence[Union[AxisId, AxisLike]],
 59    ) -> None:
 60        super().__init__()
 61        axes = tuple(
 62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 63        )
 64        self._data = xr.DataArray(array, dims=axes)
 65
 66    def __array__(self, dtype: DTypeLike = None):
 67        return np.asarray(self._data, dtype=dtype)
 68
 69    def __getitem__(
 70        self,
 71        key: Union[
 72            SliceInfo,
 73            slice,
 74            int,
 75            PerAxis[Union[SliceInfo, slice, int]],
 76            Tensor,
 77            xr.DataArray,
 78        ],
 79    ) -> Self:
 80        if isinstance(key, SliceInfo):
 81            key = slice(*key)
 82        elif isinstance(key, collections.abc.Mapping):
 83            key = {
 84                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 85                for a, s in key.items()
 86            }
 87        elif isinstance(key, Tensor):
 88            key = key._data
 89
 90        return self.__class__.from_xarray(self._data[key])
 91
 92    def __setitem__(
 93        self,
 94        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
 95        value: Union[Tensor, xr.DataArray, float, int],
 96    ) -> None:
 97        if isinstance(key, Tensor):
 98            key = key._data
 99        elif isinstance(key, xr.DataArray):
100            pass
101        else:
102            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
103
104        if isinstance(value, Tensor):
105            value = value._data
106
107        self._data[key] = value
108
109    def __len__(self) -> int:
110        return len(self.data)
111
112    def _iter(self: Any) -> Iterator[Any]:
113        for n in range(len(self)):
114            yield self[n]
115
116    def __iter__(self: Any) -> Iterator[Any]:
117        if self.ndim == 0:
118            raise TypeError("iteration over a 0-d array")
119        return self._iter()
120
121    def _binary_op(
122        self,
123        other: _Compatible,
124        f: Callable[[Any, Any], Any],
125        reflexive: bool = False,
126    ) -> Self:
127        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
128            (other._data if isinstance(other, Tensor) else other),
129            f,
130            reflexive,
131        )
132        return self.__class__.from_xarray(data)
133
134    def _inplace_binary_op(
135        self,
136        other: _Compatible,
137        f: Callable[[Any, Any], Any],
138    ) -> Self:
139        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
140            (
141                other_d
142                if (other_d := getattr(other, "data")) is not None
143                and isinstance(
144                    other_d,
145                    xr.DataArray,
146                )
147                else other
148            ),
149            f,
150        )
151        return self
152
153    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
154        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
155            f, *args, **kwargs
156        )
157        return self.__class__.from_xarray(data)
158
159    @classmethod
160    def from_xarray(cls, data_array: xr.DataArray) -> Self:
161        """create a `Tensor` from an xarray data array
162
163        note for internal use: this factory method is round-trip save
164            for any `Tensor`'s  `data` property (an xarray.DataArray).
165        """
166        return cls(
167            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
168        )
169
170    @classmethod
171    def from_numpy(
172        cls,
173        array: NDArray[Any],
174        *,
175        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
176    ) -> Tensor:
177        """create a `Tensor` from a numpy array
178
179        Args:
180            array: the nd numpy array
181            axes: A description of the array's axes,
182                if None axes are guessed (which might fail and raise a ValueError.)
183
184        Raises:
185            ValueError: if `axes` is None and axes guessing fails.
186        """
187
188        if dims is None:
189            return cls._interprete_array_wo_known_axes(array)
190        elif isinstance(dims, collections.abc.Sequence):
191            dim_seq = list(dims)
192        else:
193            dim_seq = [dims]
194
195        axis_infos = [AxisInfo.create(a) for a in dim_seq]
196        original_shape = tuple(array.shape)
197
198        successful_view = _get_array_view(array, axis_infos)
199        if successful_view is None:
200            raise ValueError(
201                f"Array shape {original_shape} does not map to axes {dims}"
202            )
203
204        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
205
206    @property
207    def data(self):
208        return self._data
209
210    @property
211    def dims(self):  # TODO: rename to `axes`?
212        """Tuple of dimension names associated with this tensor."""
213        return cast(Tuple[AxisId, ...], self._data.dims)
214
215    @property
216    def dtype(self) -> DTypeStr:
217        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
218        assert dt in get_args(DTypeStr)
219        return dt  # pyright: ignore[reportReturnType]
220
221    @property
222    def ndim(self):
223        """Number of tensor dimensions."""
224        return self._data.ndim
225
226    @property
227    def shape(self):
228        """Tuple of tensor axes lengths"""
229        return self._data.shape
230
231    @property
232    def shape_tuple(self):
233        """Tuple of tensor axes lengths"""
234        return self._data.shape
235
236    @property
237    def size(self):
238        """Number of elements in the tensor.
239
240        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
241        """
242        return self._data.size
243
244    @property
245    def sizes(self):
246        """Ordered, immutable mapping from axis ids to axis lengths."""
247        return cast(Mapping[AxisId, int], self.data.sizes)
248
249    @property
250    def tagged_shape(self):
251        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
252        return self.sizes
253
254    def argmax(self) -> Mapping[AxisId, int]:
255        ret = self._data.argmax(...)
256        assert isinstance(ret, dict)
257        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
258
259    def astype(self, dtype: DTypeStr, *, copy: bool = False):
260        """Return tensor cast to `dtype`
261
262        note: if dtype is already satisfied copy if `copy`"""
263        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
264
265    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
266        """Return a tensor whose values are limited to [min, max].
267        At least one of max or min must be given."""
268        return self.__class__.from_xarray(self._data.clip(min, max))
269
270    def crop_to(
271        self,
272        sizes: PerAxis[int],
273        crop_where: Union[
274            CropWhere,
275            PerAxis[CropWhere],
276        ] = "left_and_right",
277    ) -> Self:
278        """crop to match `sizes`"""
279        if isinstance(crop_where, str):
280            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
281        else:
282            crop_axis_where = crop_where
283
284        slices: Dict[AxisId, SliceInfo] = {}
285
286        for a, s_is in self.sizes.items():
287            if a not in sizes or sizes[a] == s_is:
288                pass
289            elif sizes[a] > s_is:
290                logger.warning(
291                    "Cannot crop axis {} of size {} to larger size {}",
292                    a,
293                    s_is,
294                    sizes[a],
295                )
296            elif a not in crop_axis_where:
297                raise ValueError(
298                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
299                )
300            else:
301                crop_this_axis_where = crop_axis_where[a]
302                if crop_this_axis_where == "left":
303                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
304                elif crop_this_axis_where == "right":
305                    slices[a] = SliceInfo(0, sizes[a])
306                elif crop_this_axis_where == "left_and_right":
307                    slices[a] = SliceInfo(
308                        start := (s_is - sizes[a]) // 2, sizes[a] + start
309                    )
310                else:
311                    assert_never(crop_this_axis_where)
312
313        return self[slices]
314
315    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
316        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
317
318    def item(
319        self,
320        key: Union[
321            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
322        ] = None,
323    ):
324        """Copy a tensor element to a standard Python scalar and return it."""
325        if key is None:
326            ret = self._data.item()
327        else:
328            ret = self[key]._data.item()
329
330        assert isinstance(ret, (bool, float, int))
331        return ret
332
333    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
334        return self.__class__.from_xarray(self._data.mean(dim=dim))
335
336    def pad(
337        self,
338        pad_width: PerAxis[PadWidthLike],
339        mode: PadMode = "symmetric",
340    ) -> Self:
341        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
342        return self.__class__.from_xarray(
343            self._data.pad(pad_width=pad_width, mode=mode)
344        )
345
346    def pad_to(
347        self,
348        sizes: PerAxis[int],
349        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
350        mode: PadMode = "symmetric",
351    ) -> Self:
352        """pad `tensor` to match `sizes`"""
353        if isinstance(pad_where, str):
354            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
355        else:
356            pad_axis_where = pad_where
357
358        pad_width: Dict[AxisId, PadWidth] = {}
359        for a, s_is in self.sizes.items():
360            if a not in sizes or sizes[a] == s_is:
361                pad_width[a] = PadWidth(0, 0)
362            elif s_is > sizes[a]:
363                pad_width[a] = PadWidth(0, 0)
364                logger.warning(
365                    "Cannot pad axis {} of size {} to smaller size {}",
366                    a,
367                    s_is,
368                    sizes[a],
369                )
370            elif a not in pad_axis_where:
371                raise ValueError(
372                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
373                )
374            else:
375                pad_this_axis_where = pad_axis_where[a]
376                d = sizes[a] - s_is
377                if pad_this_axis_where == "left":
378                    pad_width[a] = PadWidth(d, 0)
379                elif pad_this_axis_where == "right":
380                    pad_width[a] = PadWidth(0, d)
381                elif pad_this_axis_where == "left_and_right":
382                    pad_width[a] = PadWidth(left := d // 2, d - left)
383                else:
384                    assert_never(pad_this_axis_where)
385
386        return self.pad(pad_width, mode)
387
388    def quantile(
389        self,
390        q: Union[float, Sequence[float]],
391        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
392    ) -> Self:
393        assert (
394            isinstance(q, (float, int))
395            and q >= 0.0
396            or not isinstance(q, (float, int))
397            and all(qq >= 0.0 for qq in q)
398        )
399        assert (
400            isinstance(q, (float, int))
401            and q <= 1.0
402            or not isinstance(q, (float, int))
403            and all(qq <= 1.0 for qq in q)
404        )
405        assert dim is None or (
406            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
407        )
408        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
409
410    def resize_to(
411        self,
412        sizes: PerAxis[int],
413        *,
414        pad_where: Union[
415            PadWhere,
416            PerAxis[PadWhere],
417        ] = "left_and_right",
418        crop_where: Union[
419            CropWhere,
420            PerAxis[CropWhere],
421        ] = "left_and_right",
422        pad_mode: PadMode = "symmetric",
423    ):
424        """return cropped/padded tensor with `sizes`"""
425        crop_to_sizes: Dict[AxisId, int] = {}
426        pad_to_sizes: Dict[AxisId, int] = {}
427        new_axes = dict(sizes)
428        for a, s_is in self.sizes.items():
429            a = AxisId(str(a))
430            _ = new_axes.pop(a, None)
431            if a not in sizes or sizes[a] == s_is:
432                pass
433            elif s_is > sizes[a]:
434                crop_to_sizes[a] = sizes[a]
435            else:
436                pad_to_sizes[a] = sizes[a]
437
438        tensor = self
439        if crop_to_sizes:
440            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
441
442        if pad_to_sizes:
443            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
444
445        if new_axes:
446            tensor = tensor.expand_dims(new_axes)
447
448        return tensor
449
450    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
451        return self.__class__.from_xarray(self._data.std(dim=dim))
452
453    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
454        """Reduce this Tensor's data by applying sum along some dimension(s)."""
455        return self.__class__.from_xarray(self._data.sum(dim=dim))
456
457    def transpose(
458        self,
459        axes: Sequence[AxisId],
460    ) -> Self:
461        """return a transposed tensor
462
463        Args:
464            axes: the desired tensor axes
465        """
466        # expand missing tensor axes
467        missing_axes = tuple(a for a in axes if a not in self.dims)
468        array = self._data
469        if missing_axes:
470            array = array.expand_dims(missing_axes)
471
472        # transpose to the correct axis order
473        return self.__class__.from_xarray(array.transpose(*axes))
474
475    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
476        return self.__class__.from_xarray(self._data.var(dim=dim))
477
478    @classmethod
479    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
480        ndim = array.ndim
481        if ndim == 2:
482            current_axes = (
483                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
484                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
485            )
486        elif ndim == 3 and any(s <= 3 for s in array.shape):
487            current_axes = (
488                v0_5.ChannelAxis(
489                    channel_names=[
490                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
491                    ]
492                ),
493                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
494                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
495            )
496        elif ndim == 3:
497            current_axes = (
498                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
499                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
500                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
501            )
502        elif ndim == 4:
503            current_axes = (
504                v0_5.ChannelAxis(
505                    channel_names=[
506                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
507                    ]
508                ),
509                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
510                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
511                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
512            )
513        elif ndim == 5:
514            current_axes = (
515                v0_5.BatchAxis(),
516                v0_5.ChannelAxis(
517                    channel_names=[
518                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
519                    ]
520                ),
521                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
522                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
523                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
524            )
525        else:
526            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
527
528        return cls(array, dims=tuple(a.id for a in current_axes))

A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.

Tensor( array: numpy.ndarray[tuple[int, ...], numpy.dtype[typing.Any]], dims: Sequence[Union[str, bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], bioimageio.core.axis.AxisDescrLike, Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]])
55    def __init__(
56        self,
57        array: NDArray[Any],
58        dims: Sequence[Union[AxisId, AxisLike]],
59    ) -> None:
60        super().__init__()
61        axes = tuple(
62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
63        )
64        self._data = xr.DataArray(array, dims=axes)
@classmethod
def from_xarray(cls, data_array: xarray.core.dataarray.DataArray) -> Self:
159    @classmethod
160    def from_xarray(cls, data_array: xr.DataArray) -> Self:
161        """create a `Tensor` from an xarray data array
162
163        note for internal use: this factory method is round-trip save
164            for any `Tensor`'s  `data` property (an xarray.DataArray).
165        """
166        return cls(
167            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
168        )

create a Tensor from an xarray data array

Note for internal use: this factory method is round-trip safe for any Tensor's data property (an xarray.DataArray).

@classmethod
def from_numpy( cls, array: numpy.ndarray[tuple[int, ...], numpy.dtype[typing.Any]], *, dims: Union[str, bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], bioimageio.core.axis.AxisDescrLike, Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis, Sequence[Union[str, bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], bioimageio.core.axis.AxisDescrLike, Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], 
Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]], NoneType]) -> Tensor:
170    @classmethod
171    def from_numpy(
172        cls,
173        array: NDArray[Any],
174        *,
175        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
176    ) -> Tensor:
177        """create a `Tensor` from a numpy array
178
179        Args:
180            array: the nd numpy array
181            axes: A description of the array's axes,
182                if None axes are guessed (which might fail and raise a ValueError.)
183
184        Raises:
185            ValueError: if `axes` is None and axes guessing fails.
186        """
187
188        if dims is None:
189            return cls._interprete_array_wo_known_axes(array)
190        elif isinstance(dims, collections.abc.Sequence):
191            dim_seq = list(dims)
192        else:
193            dim_seq = [dims]
194
195        axis_infos = [AxisInfo.create(a) for a in dim_seq]
196        original_shape = tuple(array.shape)
197
198        successful_view = _get_array_view(array, axis_infos)
199        if successful_view is None:
200            raise ValueError(
201                f"Array shape {original_shape} does not map to axes {dims}"
202            )
203
204        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

create a Tensor from a numpy array

Arguments:
  • array: the nd numpy array
  • axes: A description of the array's axes, if None axes are guessed (which might fail and raise a ValueError.)
Raises:
  • ValueError: if dims is None and axis guessing fails.
data
206    @property
207    def data(self):
208        return self._data
dims
210    @property
211    def dims(self):  # TODO: rename to `axes`?
212        """Tuple of dimension names associated with this tensor."""
213        return cast(Tuple[AxisId, ...], self._data.dims)

Tuple of dimension names associated with this tensor.

dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
215    @property
216    def dtype(self) -> DTypeStr:
217        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
218        assert dt in get_args(DTypeStr)
219        return dt  # pyright: ignore[reportReturnType]
ndim
221    @property
222    def ndim(self):
223        """Number of tensor dimensions."""
224        return self._data.ndim

Number of tensor dimensions.

shape
226    @property
227    def shape(self):
228        """Tuple of tensor axes lengths"""
229        return self._data.shape

Tuple of tensor axes lengths

shape_tuple
231    @property
232    def shape_tuple(self):
233        """Tuple of tensor axes lengths"""
234        return self._data.shape

Tuple of tensor axes lengths

size
236    @property
237    def size(self):
238        """Number of elements in the tensor.
239
240        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
241        """
242        return self._data.size

Number of elements in the tensor.

Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimension sizes.

sizes
244    @property
245    def sizes(self):
246        """Ordered, immutable mapping from axis ids to axis lengths."""
247        return cast(Mapping[AxisId, int], self.data.sizes)

Ordered, immutable mapping from axis ids to axis lengths.

tagged_shape
249    @property
250    def tagged_shape(self):
251        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
252        return self.sizes

(alias for sizes) Ordered, immutable mapping from axis ids to lengths.

def argmax(self) -> Mapping[bioimageio.spec.model.v0_5.AxisId, int]:
254    def argmax(self) -> Mapping[AxisId, int]:
255        ret = self._data.argmax(...)
256        assert isinstance(ret, dict)
257        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}
def astype( self, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'], *, copy: bool = False):
259    def astype(self, dtype: DTypeStr, *, copy: bool = False):
260        """Return tensor cast to `dtype`
261
262        note: if dtype is already satisfied copy if `copy`"""
263        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

Return tensor cast to dtype

Note: if the data already has dtype, it is copied only if copy is True.

def clip(self, min: Optional[float] = None, max: Optional[float] = None):
265    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
266        """Return a tensor whose values are limited to [min, max].
267        At least one of max or min must be given."""
268        return self.__class__.from_xarray(self._data.clip(min, max))

Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.

def crop_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right') -> Self:
270    def crop_to(
271        self,
272        sizes: PerAxis[int],
273        crop_where: Union[
274            CropWhere,
275            PerAxis[CropWhere],
276        ] = "left_and_right",
277    ) -> Self:
278        """crop to match `sizes`"""
279        if isinstance(crop_where, str):
280            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
281        else:
282            crop_axis_where = crop_where
283
284        slices: Dict[AxisId, SliceInfo] = {}
285
286        for a, s_is in self.sizes.items():
287            if a not in sizes or sizes[a] == s_is:
288                pass
289            elif sizes[a] > s_is:
290                logger.warning(
291                    "Cannot crop axis {} of size {} to larger size {}",
292                    a,
293                    s_is,
294                    sizes[a],
295                )
296            elif a not in crop_axis_where:
297                raise ValueError(
298                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
299                )
300            else:
301                crop_this_axis_where = crop_axis_where[a]
302                if crop_this_axis_where == "left":
303                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
304                elif crop_this_axis_where == "right":
305                    slices[a] = SliceInfo(0, sizes[a])
306                elif crop_this_axis_where == "left_and_right":
307                    slices[a] = SliceInfo(
308                        start := (s_is - sizes[a]) // 2, sizes[a] + start
309                    )
310                else:
311                    assert_never(crop_this_axis_where)
312
313        return self[slices]

crop to match sizes

def expand_dims( self, dims: Union[Sequence[bioimageio.spec.model.v0_5.AxisId], Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Self:
315    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
316        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
def item( self, key: Union[NoneType, bioimageio.core.common.SliceInfo, slice, int, Mapping[bioimageio.spec.model.v0_5.AxisId, Union[bioimageio.core.common.SliceInfo, slice, int]]] = None):
318    def item(
319        self,
320        key: Union[
321            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
322        ] = None,
323    ):
324        """Copy a tensor element to a standard Python scalar and return it."""
325        if key is None:
326            ret = self._data.item()
327        else:
328            ret = self[key]._data.item()
329
330        assert isinstance(ret, (bool, float, int))
331        return ret

Copy a tensor element to a standard Python scalar and return it.

def mean( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
333    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
334        return self.__class__.from_xarray(self._data.mean(dim=dim))
def pad( self, pad_width: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.PadWidth]], mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
336    def pad(
337        self,
338        pad_width: PerAxis[PadWidthLike],
339        mode: PadMode = "symmetric",
340    ) -> Self:
341        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
342        return self.__class__.from_xarray(
343            self._data.pad(pad_width=pad_width, mode=mode)
344        )
def pad_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
346    def pad_to(
347        self,
348        sizes: PerAxis[int],
349        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
350        mode: PadMode = "symmetric",
351    ) -> Self:
352        """pad `tensor` to match `sizes`"""
353        if isinstance(pad_where, str):
354            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
355        else:
356            pad_axis_where = pad_where
357
358        pad_width: Dict[AxisId, PadWidth] = {}
359        for a, s_is in self.sizes.items():
360            if a not in sizes or sizes[a] == s_is:
361                pad_width[a] = PadWidth(0, 0)
362            elif s_is > sizes[a]:
363                pad_width[a] = PadWidth(0, 0)
364                logger.warning(
365                    "Cannot pad axis {} of size {} to smaller size {}",
366                    a,
367                    s_is,
368                    sizes[a],
369                )
370            elif a not in pad_axis_where:
371                raise ValueError(
372                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
373                )
374            else:
375                pad_this_axis_where = pad_axis_where[a]
376                d = sizes[a] - s_is
377                if pad_this_axis_where == "left":
378                    pad_width[a] = PadWidth(d, 0)
379                elif pad_this_axis_where == "right":
380                    pad_width[a] = PadWidth(0, d)
381                elif pad_this_axis_where == "left_and_right":
382                    pad_width[a] = PadWidth(left := d // 2, d - left)
383                else:
384                    assert_never(pad_this_axis_where)
385
386        return self.pad(pad_width, mode)

pad tensor to match sizes

def quantile( self, q: Union[float, Sequence[float]], dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
388    def quantile(
389        self,
390        q: Union[float, Sequence[float]],
391        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
392    ) -> Self:
393        assert (
394            isinstance(q, (float, int))
395            and q >= 0.0
396            or not isinstance(q, (float, int))
397            and all(qq >= 0.0 for qq in q)
398        )
399        assert (
400            isinstance(q, (float, int))
401            and q <= 1.0
402            or not isinstance(q, (float, int))
403            and all(qq <= 1.0 for qq in q)
404        )
405        assert dim is None or (
406            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
407        )
408        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], *, pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', pad_mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric'):
410    def resize_to(
411        self,
412        sizes: PerAxis[int],
413        *,
414        pad_where: Union[
415            PadWhere,
416            PerAxis[PadWhere],
417        ] = "left_and_right",
418        crop_where: Union[
419            CropWhere,
420            PerAxis[CropWhere],
421        ] = "left_and_right",
422        pad_mode: PadMode = "symmetric",
423    ):
424        """return cropped/padded tensor with `sizes`"""
425        crop_to_sizes: Dict[AxisId, int] = {}
426        pad_to_sizes: Dict[AxisId, int] = {}
427        new_axes = dict(sizes)
428        for a, s_is in self.sizes.items():
429            a = AxisId(str(a))
430            _ = new_axes.pop(a, None)
431            if a not in sizes or sizes[a] == s_is:
432                pass
433            elif s_is > sizes[a]:
434                crop_to_sizes[a] = sizes[a]
435            else:
436                pad_to_sizes[a] = sizes[a]
437
438        tensor = self
439        if crop_to_sizes:
440            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
441
442        if pad_to_sizes:
443            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
444
445        if new_axes:
446            tensor = tensor.expand_dims(new_axes)
447
448        return tensor

return cropped/padded tensor with sizes

def std( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
450    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
451        return self.__class__.from_xarray(self._data.std(dim=dim))
def sum( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
453    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
454        """Reduce this Tensor's data by applying sum along some dimension(s)."""
455        return self.__class__.from_xarray(self._data.sum(dim=dim))

Reduce this Tensor's data by applying sum along some dimension(s).

def transpose(self, axes: Sequence[bioimageio.spec.model.v0_5.AxisId]) -> Self:
457    def transpose(
458        self,
459        axes: Sequence[AxisId],
460    ) -> Self:
461        """return a transposed tensor
462
463        Args:
464            axes: the desired tensor axes
465        """
466        # expand missing tensor axes
467        missing_axes = tuple(a for a in axes if a not in self.dims)
468        array = self._data
469        if missing_axes:
470            array = array.expand_dims(missing_axes)
471
472        # transpose to the correct axis order
473        return self.__class__.from_xarray(array.transpose(*axes))

return a transposed tensor

Arguments:
  • axes: the desired tensor axes
def var( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
475    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
476        return self.__class__.from_xarray(self._data.var(dim=dim))