bioimageio.core.tensor

  1from __future__ import annotations
  2
  3import collections.abc
  4from itertools import permutations
  5from typing import (
  6    TYPE_CHECKING,
  7    Any,
  8    Callable,
  9    Dict,
 10    Iterator,
 11    Mapping,
 12    Optional,
 13    Sequence,
 14    Tuple,
 15    Union,
 16    cast,
 17    get_args,
 18)
 19
 20import numpy as np
 21import xarray as xr
 22from loguru import logger
 23from numpy.typing import DTypeLike, NDArray
 24from typing_extensions import Self, assert_never
 25
 26from bioimageio.spec.model import v0_5
 27
 28from ._magic_tensor_ops import MagicTensorOpsMixin
 29from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 30from .common import (
 31    CropWhere,
 32    DTypeStr,
 33    PadMode,
 34    PadWhere,
 35    PadWidth,
 36    PadWidthLike,
 37    SliceInfo,
 38)
 39
 40if TYPE_CHECKING:
 41    from numpy.typing import ArrayLike, NDArray
 42
 43
# Non-`Tensor`, non-`xr.DataArray` operand types that `Tensor._Compatible`
# accepts for the arithmetic helpers. The quoted names are forward references;
# `ArrayLike` is only imported under `TYPE_CHECKING`.
_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"]  # TODO: add "DaskArray"
 45
 46
 47# TODO: complete docstrings
 48class Tensor(MagicTensorOpsMixin):
 49    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 50    and improved type annotations."""
 51
 52    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 53
 54    def __init__(
 55        self,
 56        array: NDArray[Any],
 57        dims: Sequence[Union[AxisId, AxisLike]],
 58    ) -> None:
 59        super().__init__()
 60        axes = tuple(
 61            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 62        )
 63        self._data = xr.DataArray(array, dims=axes)
 64
 65    def __array__(self, dtype: DTypeLike = None):
 66        return np.asarray(self._data, dtype=dtype)
 67
 68    def __getitem__(
 69        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
 70    ) -> Self:
 71        if isinstance(key, SliceInfo):
 72            key = slice(*key)
 73        elif isinstance(key, collections.abc.Mapping):
 74            key = {
 75                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 76                for a, s in key.items()
 77            }
 78        return self.__class__.from_xarray(self._data[key])
 79
 80    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
 81        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
 82        self._data[key] = value._data
 83
 84    def __len__(self) -> int:
 85        return len(self.data)
 86
 87    def _iter(self: Any) -> Iterator[Any]:
 88        for n in range(len(self)):
 89            yield self[n]
 90
 91    def __iter__(self: Any) -> Iterator[Any]:
 92        if self.ndim == 0:
 93            raise TypeError("iteration over a 0-d array")
 94        return self._iter()
 95
 96    def _binary_op(
 97        self,
 98        other: _Compatible,
 99        f: Callable[[Any, Any], Any],
100        reflexive: bool = False,
101    ) -> Self:
102        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
103            (other._data if isinstance(other, Tensor) else other),
104            f,
105            reflexive,
106        )
107        return self.__class__.from_xarray(data)
108
109    def _inplace_binary_op(
110        self,
111        other: _Compatible,
112        f: Callable[[Any, Any], Any],
113    ) -> Self:
114        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
115            (
116                other_d
117                if (other_d := getattr(other, "data")) is not None
118                and isinstance(
119                    other_d,
120                    xr.DataArray,
121                )
122                else other
123            ),
124            f,
125        )
126        return self
127
128    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
129        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
130            f, *args, **kwargs
131        )
132        return self.__class__.from_xarray(data)
133
134    @classmethod
135    def from_xarray(cls, data_array: xr.DataArray) -> Self:
136        """create a `Tensor` from an xarray data array
137
138        note for internal use: this factory method is round-trip save
139            for any `Tensor`'s  `data` property (an xarray.DataArray).
140        """
141        return cls(
142            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
143        )
144
145    @classmethod
146    def from_numpy(
147        cls,
148        array: NDArray[Any],
149        *,
150        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
151    ) -> Tensor:
152        """create a `Tensor` from a numpy array
153
154        Args:
155            array: the nd numpy array
156            axes: A description of the array's axes,
157                if None axes are guessed (which might fail and raise a ValueError.)
158
159        Raises:
160            ValueError: if `axes` is None and axes guessing fails.
161        """
162
163        if dims is None:
164            return cls._interprete_array_wo_known_axes(array)
165        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
166            dims = [dims]
167
168        axis_infos = [AxisInfo.create(a) for a in dims]
169        original_shape = tuple(array.shape)
170
171        successful_view = _get_array_view(array, axis_infos)
172        if successful_view is None:
173            raise ValueError(
174                f"Array shape {original_shape} does not map to axes {dims}"
175            )
176
177        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
178
179    @property
180    def data(self):
181        return self._data
182
183    @property
184    def dims(self):  # TODO: rename to `axes`?
185        """Tuple of dimension names associated with this tensor."""
186        return cast(Tuple[AxisId, ...], self._data.dims)
187
188    @property
189    def tagged_shape(self):
190        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
191        return self.sizes
192
193    @property
194    def shape_tuple(self):
195        """Tuple of tensor axes lengths"""
196        return self._data.shape
197
198    @property
199    def size(self):
200        """Number of elements in the tensor.
201
202        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
203        """
204        return self._data.size
205
206    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
207        """Reduce this Tensor's data by applying sum along some dimension(s)."""
208        return self.__class__.from_xarray(self._data.sum(dim=dim))
209
210    @property
211    def ndim(self):
212        """Number of tensor dimensions."""
213        return self._data.ndim
214
215    @property
216    def dtype(self) -> DTypeStr:
217        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
218        assert dt in get_args(DTypeStr)
219        return dt  # pyright: ignore[reportReturnType]
220
221    @property
222    def sizes(self):
223        """Ordered, immutable mapping from axis ids to axis lengths."""
224        return cast(Mapping[AxisId, int], self.data.sizes)
225
226    def astype(self, dtype: DTypeStr, *, copy: bool = False):
227        """Return tensor cast to `dtype`
228
229        note: if dtype is already satisfied copy if `copy`"""
230        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
231
232    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
233        """Return a tensor whose values are limited to [min, max].
234        At least one of max or min must be given."""
235        return self.__class__.from_xarray(self._data.clip(min, max))
236
237    def crop_to(
238        self,
239        sizes: PerAxis[int],
240        crop_where: Union[
241            CropWhere,
242            PerAxis[CropWhere],
243        ] = "left_and_right",
244    ) -> Self:
245        """crop to match `sizes`"""
246        if isinstance(crop_where, str):
247            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
248        else:
249            crop_axis_where = crop_where
250
251        slices: Dict[AxisId, SliceInfo] = {}
252
253        for a, s_is in self.sizes.items():
254            if a not in sizes or sizes[a] == s_is:
255                pass
256            elif sizes[a] > s_is:
257                logger.warning(
258                    "Cannot crop axis {} of size {} to larger size {}",
259                    a,
260                    s_is,
261                    sizes[a],
262                )
263            elif a not in crop_axis_where:
264                raise ValueError(
265                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
266                )
267            else:
268                crop_this_axis_where = crop_axis_where[a]
269                if crop_this_axis_where == "left":
270                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
271                elif crop_this_axis_where == "right":
272                    slices[a] = SliceInfo(0, sizes[a])
273                elif crop_this_axis_where == "left_and_right":
274                    slices[a] = SliceInfo(
275                        start := (s_is - sizes[a]) // 2, sizes[a] + start
276                    )
277                else:
278                    assert_never(crop_this_axis_where)
279
280        return self[slices]
281
282    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
283        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
284
285    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
286        return self.__class__.from_xarray(self._data.mean(dim=dim))
287
288    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
289        return self.__class__.from_xarray(self._data.std(dim=dim))
290
291    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
292        return self.__class__.from_xarray(self._data.var(dim=dim))
293
294    def pad(
295        self,
296        pad_width: PerAxis[PadWidthLike],
297        mode: PadMode = "symmetric",
298    ) -> Self:
299        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
300        return self.__class__.from_xarray(
301            self._data.pad(pad_width=pad_width, mode=mode)
302        )
303
304    def pad_to(
305        self,
306        sizes: PerAxis[int],
307        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
308        mode: PadMode = "symmetric",
309    ) -> Self:
310        """pad `tensor` to match `sizes`"""
311        if isinstance(pad_where, str):
312            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
313        else:
314            pad_axis_where = pad_where
315
316        pad_width: Dict[AxisId, PadWidth] = {}
317        for a, s_is in self.sizes.items():
318            if a not in sizes or sizes[a] == s_is:
319                pad_width[a] = PadWidth(0, 0)
320            elif s_is > sizes[a]:
321                pad_width[a] = PadWidth(0, 0)
322                logger.warning(
323                    "Cannot pad axis {} of size {} to smaller size {}",
324                    a,
325                    s_is,
326                    sizes[a],
327                )
328            elif a not in pad_axis_where:
329                raise ValueError(
330                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
331                )
332            else:
333                pad_this_axis_where = pad_axis_where[a]
334                d = sizes[a] - s_is
335                if pad_this_axis_where == "left":
336                    pad_width[a] = PadWidth(d, 0)
337                elif pad_this_axis_where == "right":
338                    pad_width[a] = PadWidth(0, d)
339                elif pad_this_axis_where == "left_and_right":
340                    pad_width[a] = PadWidth(left := d // 2, d - left)
341                else:
342                    assert_never(pad_this_axis_where)
343
344        return self.pad(pad_width, mode)
345
346    def quantile(
347        self,
348        q: Union[float, Sequence[float]],
349        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
350    ) -> Self:
351        assert (
352            isinstance(q, (float, int))
353            and q >= 0.0
354            or not isinstance(q, (float, int))
355            and all(qq >= 0.0 for qq in q)
356        )
357        assert (
358            isinstance(q, (float, int))
359            and q <= 1.0
360            or not isinstance(q, (float, int))
361            and all(qq <= 1.0 for qq in q)
362        )
363        assert dim is None or (
364            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
365        )
366        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
367
368    def resize_to(
369        self,
370        sizes: PerAxis[int],
371        *,
372        pad_where: Union[
373            PadWhere,
374            PerAxis[PadWhere],
375        ] = "left_and_right",
376        crop_where: Union[
377            CropWhere,
378            PerAxis[CropWhere],
379        ] = "left_and_right",
380        pad_mode: PadMode = "symmetric",
381    ):
382        """return cropped/padded tensor with `sizes`"""
383        crop_to_sizes: Dict[AxisId, int] = {}
384        pad_to_sizes: Dict[AxisId, int] = {}
385        new_axes = dict(sizes)
386        for a, s_is in self.sizes.items():
387            a = AxisId(str(a))
388            _ = new_axes.pop(a, None)
389            if a not in sizes or sizes[a] == s_is:
390                pass
391            elif s_is > sizes[a]:
392                crop_to_sizes[a] = sizes[a]
393            else:
394                pad_to_sizes[a] = sizes[a]
395
396        tensor = self
397        if crop_to_sizes:
398            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
399
400        if pad_to_sizes:
401            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
402
403        if new_axes:
404            tensor = tensor.expand_dims(new_axes)
405
406        return tensor
407
408    def transpose(
409        self,
410        axes: Sequence[AxisId],
411    ) -> Self:
412        """return a transposed tensor
413
414        Args:
415            axes: the desired tensor axes
416        """
417        # expand missing tensor axes
418        missing_axes = tuple(a for a in axes if a not in self.dims)
419        array = self._data
420        if missing_axes:
421            array = array.expand_dims(missing_axes)
422
423        # transpose to the correct axis order
424        return self.__class__.from_xarray(array.transpose(*axes))
425
426    @classmethod
427    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
428        ndim = array.ndim
429        if ndim == 2:
430            current_axes = (
431                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
432                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
433            )
434        elif ndim == 3 and any(s <= 3 for s in array.shape):
435            current_axes = (
436                v0_5.ChannelAxis(
437                    channel_names=[
438                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
439                    ]
440                ),
441                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
442                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
443            )
444        elif ndim == 3:
445            current_axes = (
446                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
447                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
448                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
449            )
450        elif ndim == 4:
451            current_axes = (
452                v0_5.ChannelAxis(
453                    channel_names=[
454                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
455                    ]
456                ),
457                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
458                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
459                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
460            )
461        elif ndim == 5:
462            current_axes = (
463                v0_5.BatchAxis(),
464                v0_5.ChannelAxis(
465                    channel_names=[
466                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
467                    ]
468                ),
469                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
470                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
471                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
472            )
473        else:
474            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
475
476        return cls(array, dims=tuple(a.id for a in current_axes))
477
478
def _add_singletons(arr: NDArray[Any], axis_infos: Sequence[AxisInfo]):
    """Drop or insert size-1 dimensions so `arr` can match `axis_infos`.

    If `arr` has more dimensions than `axis_infos`, size-1 dimensions are
    removed from left to right until the counts match (or none are left).
    If it has fewer, size-1 dimensions are inserted, from left to right, at the
    positions of entries with `maybe_singleton` set, until the counts match.

    The result may still differ in dimension count from `axis_infos`;
    callers must check.
    """
    n_axes = len(axis_infos)
    if arr.ndim > n_axes:
        # remove singletons; we walk the original shape, so every removal
        # shifts all later dimensions one position to the left
        original_shape = arr.shape
        n_removed = 0
        for i, s in enumerate(original_shape):
            if s == 1:
                arr = np.take(arr, 0, axis=i - n_removed)
                n_removed += 1
                if arr.ndim == n_axes:
                    break

    # add singletons if necessary
    for i, a in enumerate(axis_infos):
        if arr.ndim >= n_axes:
            break

        if a.maybe_singleton:
            arr = np.expand_dims(arr, i)

    return arr
495
496    return arr
497
498
def _get_array_view(
    original_array: NDArray[Any], axis_infos: Sequence[AxisInfo]
) -> Optional[NDArray[Any]]:
    """Find an axis order of `original_array` that fits `axis_infos`.

    Candidate orders are tried as: unchanged, fully reversed (`A.T`), then the
    remaining permutations in lexicographic order. For each candidate the
    array is transposed and singleton dimensions are dropped/added via
    `_add_singletons`. The first candidate is accepted where every size-1
    dimension lands on an axis with `maybe_singleton` set.

    Returns:
        The matching array, or None if the number of dimensions cannot be
        matched or no axis order fits.
    """
    identity = tuple(range(original_array.ndim))
    reverse = identity[::-1]

    def candidate_orders() -> Iterator[Tuple[int, ...]]:
        # generated lazily: materializing all ndim! permutations up front is
        # wasteful when (as usual) one of the first candidates fits
        yield identity
        if reverse != identity:
            yield reverse
        for perm in permutations(identity):
            if perm != identity and perm != reverse:
                yield perm

    for perm in candidate_orders():
        view = _add_singletons(original_array.transpose(perm), axis_infos)
        if view.ndim != len(axis_infos):
            # every permutation has the same singleton count, so no other
            # candidate can reach the right number of dimensions either
            return None

        if all(s != 1 or a.maybe_singleton for s, a in zip(view.shape, axis_infos)):
            return view

    return None
class Tensor(bioimageio.core._magic_tensor_ops.MagicTensorOpsMixin):
 49class Tensor(MagicTensorOpsMixin):
 50    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
 51    and improved type annotations."""
 52
 53    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]
 54
 55    def __init__(
 56        self,
 57        array: NDArray[Any],
 58        dims: Sequence[Union[AxisId, AxisLike]],
 59    ) -> None:
 60        super().__init__()
 61        axes = tuple(
 62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
 63        )
 64        self._data = xr.DataArray(array, dims=axes)
 65
 66    def __array__(self, dtype: DTypeLike = None):
 67        return np.asarray(self._data, dtype=dtype)
 68
 69    def __getitem__(
 70        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
 71    ) -> Self:
 72        if isinstance(key, SliceInfo):
 73            key = slice(*key)
 74        elif isinstance(key, collections.abc.Mapping):
 75            key = {
 76                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
 77                for a, s in key.items()
 78            }
 79        return self.__class__.from_xarray(self._data[key])
 80
 81    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
 82        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
 83        self._data[key] = value._data
 84
 85    def __len__(self) -> int:
 86        return len(self.data)
 87
 88    def _iter(self: Any) -> Iterator[Any]:
 89        for n in range(len(self)):
 90            yield self[n]
 91
 92    def __iter__(self: Any) -> Iterator[Any]:
 93        if self.ndim == 0:
 94            raise TypeError("iteration over a 0-d array")
 95        return self._iter()
 96
 97    def _binary_op(
 98        self,
 99        other: _Compatible,
100        f: Callable[[Any, Any], Any],
101        reflexive: bool = False,
102    ) -> Self:
103        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
104            (other._data if isinstance(other, Tensor) else other),
105            f,
106            reflexive,
107        )
108        return self.__class__.from_xarray(data)
109
110    def _inplace_binary_op(
111        self,
112        other: _Compatible,
113        f: Callable[[Any, Any], Any],
114    ) -> Self:
115        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
116            (
117                other_d
118                if (other_d := getattr(other, "data")) is not None
119                and isinstance(
120                    other_d,
121                    xr.DataArray,
122                )
123                else other
124            ),
125            f,
126        )
127        return self
128
129    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
130        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
131            f, *args, **kwargs
132        )
133        return self.__class__.from_xarray(data)
134
135    @classmethod
136    def from_xarray(cls, data_array: xr.DataArray) -> Self:
137        """create a `Tensor` from an xarray data array
138
139        note for internal use: this factory method is round-trip save
140            for any `Tensor`'s  `data` property (an xarray.DataArray).
141        """
142        return cls(
143            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
144        )
145
146    @classmethod
147    def from_numpy(
148        cls,
149        array: NDArray[Any],
150        *,
151        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
152    ) -> Tensor:
153        """create a `Tensor` from a numpy array
154
155        Args:
156            array: the nd numpy array
157            axes: A description of the array's axes,
158                if None axes are guessed (which might fail and raise a ValueError.)
159
160        Raises:
161            ValueError: if `axes` is None and axes guessing fails.
162        """
163
164        if dims is None:
165            return cls._interprete_array_wo_known_axes(array)
166        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
167            dims = [dims]
168
169        axis_infos = [AxisInfo.create(a) for a in dims]
170        original_shape = tuple(array.shape)
171
172        successful_view = _get_array_view(array, axis_infos)
173        if successful_view is None:
174            raise ValueError(
175                f"Array shape {original_shape} does not map to axes {dims}"
176            )
177
178        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
179
180    @property
181    def data(self):
182        return self._data
183
184    @property
185    def dims(self):  # TODO: rename to `axes`?
186        """Tuple of dimension names associated with this tensor."""
187        return cast(Tuple[AxisId, ...], self._data.dims)
188
189    @property
190    def tagged_shape(self):
191        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
192        return self.sizes
193
194    @property
195    def shape_tuple(self):
196        """Tuple of tensor axes lengths"""
197        return self._data.shape
198
199    @property
200    def size(self):
201        """Number of elements in the tensor.
202
203        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
204        """
205        return self._data.size
206
207    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
208        """Reduce this Tensor's data by applying sum along some dimension(s)."""
209        return self.__class__.from_xarray(self._data.sum(dim=dim))
210
211    @property
212    def ndim(self):
213        """Number of tensor dimensions."""
214        return self._data.ndim
215
216    @property
217    def dtype(self) -> DTypeStr:
218        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
219        assert dt in get_args(DTypeStr)
220        return dt  # pyright: ignore[reportReturnType]
221
222    @property
223    def sizes(self):
224        """Ordered, immutable mapping from axis ids to axis lengths."""
225        return cast(Mapping[AxisId, int], self.data.sizes)
226
227    def astype(self, dtype: DTypeStr, *, copy: bool = False):
228        """Return tensor cast to `dtype`
229
230        note: if dtype is already satisfied copy if `copy`"""
231        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
232
233    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
234        """Return a tensor whose values are limited to [min, max].
235        At least one of max or min must be given."""
236        return self.__class__.from_xarray(self._data.clip(min, max))
237
238    def crop_to(
239        self,
240        sizes: PerAxis[int],
241        crop_where: Union[
242            CropWhere,
243            PerAxis[CropWhere],
244        ] = "left_and_right",
245    ) -> Self:
246        """crop to match `sizes`"""
247        if isinstance(crop_where, str):
248            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
249        else:
250            crop_axis_where = crop_where
251
252        slices: Dict[AxisId, SliceInfo] = {}
253
254        for a, s_is in self.sizes.items():
255            if a not in sizes or sizes[a] == s_is:
256                pass
257            elif sizes[a] > s_is:
258                logger.warning(
259                    "Cannot crop axis {} of size {} to larger size {}",
260                    a,
261                    s_is,
262                    sizes[a],
263                )
264            elif a not in crop_axis_where:
265                raise ValueError(
266                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
267                )
268            else:
269                crop_this_axis_where = crop_axis_where[a]
270                if crop_this_axis_where == "left":
271                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
272                elif crop_this_axis_where == "right":
273                    slices[a] = SliceInfo(0, sizes[a])
274                elif crop_this_axis_where == "left_and_right":
275                    slices[a] = SliceInfo(
276                        start := (s_is - sizes[a]) // 2, sizes[a] + start
277                    )
278                else:
279                    assert_never(crop_this_axis_where)
280
281        return self[slices]
282
283    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
284        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
285
286    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
287        return self.__class__.from_xarray(self._data.mean(dim=dim))
288
289    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
290        return self.__class__.from_xarray(self._data.std(dim=dim))
291
292    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
293        return self.__class__.from_xarray(self._data.var(dim=dim))
294
295    def pad(
296        self,
297        pad_width: PerAxis[PadWidthLike],
298        mode: PadMode = "symmetric",
299    ) -> Self:
300        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
301        return self.__class__.from_xarray(
302            self._data.pad(pad_width=pad_width, mode=mode)
303        )
304
305    def pad_to(
306        self,
307        sizes: PerAxis[int],
308        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
309        mode: PadMode = "symmetric",
310    ) -> Self:
311        """pad `tensor` to match `sizes`"""
312        if isinstance(pad_where, str):
313            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
314        else:
315            pad_axis_where = pad_where
316
317        pad_width: Dict[AxisId, PadWidth] = {}
318        for a, s_is in self.sizes.items():
319            if a not in sizes or sizes[a] == s_is:
320                pad_width[a] = PadWidth(0, 0)
321            elif s_is > sizes[a]:
322                pad_width[a] = PadWidth(0, 0)
323                logger.warning(
324                    "Cannot pad axis {} of size {} to smaller size {}",
325                    a,
326                    s_is,
327                    sizes[a],
328                )
329            elif a not in pad_axis_where:
330                raise ValueError(
331                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
332                )
333            else:
334                pad_this_axis_where = pad_axis_where[a]
335                d = sizes[a] - s_is
336                if pad_this_axis_where == "left":
337                    pad_width[a] = PadWidth(d, 0)
338                elif pad_this_axis_where == "right":
339                    pad_width[a] = PadWidth(0, d)
340                elif pad_this_axis_where == "left_and_right":
341                    pad_width[a] = PadWidth(left := d // 2, d - left)
342                else:
343                    assert_never(pad_this_axis_where)
344
345        return self.pad(pad_width, mode)
346
347    def quantile(
348        self,
349        q: Union[float, Sequence[float]],
350        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
351    ) -> Self:
352        assert (
353            isinstance(q, (float, int))
354            and q >= 0.0
355            or not isinstance(q, (float, int))
356            and all(qq >= 0.0 for qq in q)
357        )
358        assert (
359            isinstance(q, (float, int))
360            and q <= 1.0
361            or not isinstance(q, (float, int))
362            and all(qq <= 1.0 for qq in q)
363        )
364        assert dim is None or (
365            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
366        )
367        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
368
369    def resize_to(
370        self,
371        sizes: PerAxis[int],
372        *,
373        pad_where: Union[
374            PadWhere,
375            PerAxis[PadWhere],
376        ] = "left_and_right",
377        crop_where: Union[
378            CropWhere,
379            PerAxis[CropWhere],
380        ] = "left_and_right",
381        pad_mode: PadMode = "symmetric",
382    ):
383        """return cropped/padded tensor with `sizes`"""
384        crop_to_sizes: Dict[AxisId, int] = {}
385        pad_to_sizes: Dict[AxisId, int] = {}
386        new_axes = dict(sizes)
387        for a, s_is in self.sizes.items():
388            a = AxisId(str(a))
389            _ = new_axes.pop(a, None)
390            if a not in sizes or sizes[a] == s_is:
391                pass
392            elif s_is > sizes[a]:
393                crop_to_sizes[a] = sizes[a]
394            else:
395                pad_to_sizes[a] = sizes[a]
396
397        tensor = self
398        if crop_to_sizes:
399            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
400
401        if pad_to_sizes:
402            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
403
404        if new_axes:
405            tensor = tensor.expand_dims(new_axes)
406
407        return tensor
408
409    def transpose(
410        self,
411        axes: Sequence[AxisId],
412    ) -> Self:
413        """return a transposed tensor
414
415        Args:
416            axes: the desired tensor axes
417        """
418        # expand missing tensor axes
419        missing_axes = tuple(a for a in axes if a not in self.dims)
420        array = self._data
421        if missing_axes:
422            array = array.expand_dims(missing_axes)
423
424        # transpose to the correct axis order
425        return self.__class__.from_xarray(array.transpose(*axes))
426
427    @classmethod
428    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
429        ndim = array.ndim
430        if ndim == 2:
431            current_axes = (
432                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
433                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
434            )
435        elif ndim == 3 and any(s <= 3 for s in array.shape):
436            current_axes = (
437                v0_5.ChannelAxis(
438                    channel_names=[
439                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
440                    ]
441                ),
442                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
443                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
444            )
445        elif ndim == 3:
446            current_axes = (
447                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
448                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
449                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
450            )
451        elif ndim == 4:
452            current_axes = (
453                v0_5.ChannelAxis(
454                    channel_names=[
455                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
456                    ]
457                ),
458                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
459                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
460                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
461            )
462        elif ndim == 5:
463            current_axes = (
464                v0_5.BatchAxis(),
465                v0_5.ChannelAxis(
466                    channel_names=[
467                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
468                    ]
469                ),
470                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
471                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
472                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
473            )
474        else:
475            raise ValueError(f"Could not guess an axis mapping for {array.shape}")
476
477        return cls(array, dims=tuple(a.id for a in current_axes))

A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.

Tensor( array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], dims: Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]])
55    def __init__(
56        self,
57        array: NDArray[Any],
58        dims: Sequence[Union[AxisId, AxisLike]],
59    ) -> None:
60        super().__init__()
61        axes = tuple(
62            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
63        )
64        self._data = xr.DataArray(array, dims=axes)
@classmethod
def from_xarray(cls, data_array: xarray.core.dataarray.DataArray) -> Self:
135    @classmethod
136    def from_xarray(cls, data_array: xr.DataArray) -> Self:
137        """create a `Tensor` from an xarray data array
138
139        note for internal use: this factory method is round-trip save
140            for any `Tensor`'s  `data` property (an xarray.DataArray).
141        """
142        return cls(
143            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
144        )

create a Tensor from an xarray data array

Note for internal use: this factory method is round-trip safe for the data property (an xarray.DataArray) of any Tensor.

@classmethod
def from_numpy( cls, array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], *, dims: Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, 
bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]], NoneType]) -> Tensor:
146    @classmethod
147    def from_numpy(
148        cls,
149        array: NDArray[Any],
150        *,
151        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
152    ) -> Tensor:
153        """create a `Tensor` from a numpy array
154
155        Args:
156            array: the nd numpy array
157            axes: A description of the array's axes,
158                if None axes are guessed (which might fail and raise a ValueError.)
159
160        Raises:
161            ValueError: if `axes` is None and axes guessing fails.
162        """
163
164        if dims is None:
165            return cls._interprete_array_wo_known_axes(array)
166        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
167            dims = [dims]
168
169        axis_infos = [AxisInfo.create(a) for a in dims]
170        original_shape = tuple(array.shape)
171
172        successful_view = _get_array_view(array, axis_infos)
173        if successful_view is None:
174            raise ValueError(
175                f"Array shape {original_shape} does not map to axes {dims}"
176            )
177
178        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

create a Tensor from a numpy array

Arguments:
  • array: the nd numpy array
  • dims: A description of the array's axes; if None, the axes are guessed (which might fail and raise a ValueError).
Raises:
  • ValueError: if dims is None and axes guessing fails, or if the array shape does not map to the given dims.
data
180    @property
181    def data(self):
182        return self._data
dims
184    @property
185    def dims(self):  # TODO: rename to `axes`?
186        """Tuple of dimension names associated with this tensor."""
187        return cast(Tuple[AxisId, ...], self._data.dims)

Tuple of dimension names associated with this tensor.

tagged_shape
189    @property
190    def tagged_shape(self):
191        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
192        return self.sizes

(alias for sizes) Ordered, immutable mapping from axis ids to lengths.

shape_tuple
194    @property
195    def shape_tuple(self):
196        """Tuple of tensor axes lengths"""
197        return self._data.shape

Tuple of tensor axes lengths

size
199    @property
200    def size(self):
201        """Number of elements in the tensor.
202
203        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
204        """
205        return self._data.size

Number of elements in the tensor.

Equal to math.prod(tensor.shape_tuple), i.e., the product of the tensor's axis lengths.

def sum( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
207    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
208        """Reduce this Tensor's data by applying sum along some dimension(s)."""
209        return self.__class__.from_xarray(self._data.sum(dim=dim))

Reduce this Tensor's data by applying sum along some dimension(s).

ndim
211    @property
212    def ndim(self):
213        """Number of tensor dimensions."""
214        return self._data.ndim

Number of tensor dimensions.

dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
216    @property
217    def dtype(self) -> DTypeStr:
218        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
219        assert dt in get_args(DTypeStr)
220        return dt  # pyright: ignore[reportReturnType]
sizes
222    @property
223    def sizes(self):
224        """Ordered, immutable mapping from axis ids to axis lengths."""
225        return cast(Mapping[AxisId, int], self.data.sizes)

Ordered, immutable mapping from axis ids to axis lengths.

def astype( self, dtype: Literal['bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'], *, copy: bool = False):
227    def astype(self, dtype: DTypeStr, *, copy: bool = False):
228        """Return tensor cast to `dtype`
229
230        note: if dtype is already satisfied copy if `copy`"""
231        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

Return tensor cast to dtype

Note: if the tensor already has the requested dtype, its data is copied only if copy is True.

def clip(self, min: Optional[float] = None, max: Optional[float] = None):
233    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
234        """Return a tensor whose values are limited to [min, max].
235        At least one of max or min must be given."""
236        return self.__class__.from_xarray(self._data.clip(min, max))

Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.

def crop_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right') -> Self:
238    def crop_to(
239        self,
240        sizes: PerAxis[int],
241        crop_where: Union[
242            CropWhere,
243            PerAxis[CropWhere],
244        ] = "left_and_right",
245    ) -> Self:
246        """crop to match `sizes`"""
247        if isinstance(crop_where, str):
248            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
249        else:
250            crop_axis_where = crop_where
251
252        slices: Dict[AxisId, SliceInfo] = {}
253
254        for a, s_is in self.sizes.items():
255            if a not in sizes or sizes[a] == s_is:
256                pass
257            elif sizes[a] > s_is:
258                logger.warning(
259                    "Cannot crop axis {} of size {} to larger size {}",
260                    a,
261                    s_is,
262                    sizes[a],
263                )
264            elif a not in crop_axis_where:
265                raise ValueError(
266                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
267                )
268            else:
269                crop_this_axis_where = crop_axis_where[a]
270                if crop_this_axis_where == "left":
271                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
272                elif crop_this_axis_where == "right":
273                    slices[a] = SliceInfo(0, sizes[a])
274                elif crop_this_axis_where == "left_and_right":
275                    slices[a] = SliceInfo(
276                        start := (s_is - sizes[a]) // 2, sizes[a] + start
277                    )
278                else:
279                    assert_never(crop_this_axis_where)
280
281        return self[slices]

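Crop the tensor to match sizes. Axes not listed in sizes are left unchanged; an axis that is smaller than its requested size is also left unchanged and a warning is logged.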

def expand_dims( self, dims: Union[Sequence[bioimageio.spec.model.v0_5.AxisId], Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Self:
283    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
284        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))
def mean( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
286    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
287        return self.__class__.from_xarray(self._data.mean(dim=dim))
def std( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
289    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
290        return self.__class__.from_xarray(self._data.std(dim=dim))
def var( self, dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
292    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
293        return self.__class__.from_xarray(self._data.var(dim=dim))
def pad( self, pad_width: Mapping[bioimageio.spec.model.v0_5.AxisId, Union[int, Tuple[int, int], bioimageio.core.common.PadWidth]], mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
295    def pad(
296        self,
297        pad_width: PerAxis[PadWidthLike],
298        mode: PadMode = "symmetric",
299    ) -> Self:
300        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
301        return self.__class__.from_xarray(
302            self._data.pad(pad_width=pad_width, mode=mode)
303        )
def pad_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric') -> Self:
305    def pad_to(
306        self,
307        sizes: PerAxis[int],
308        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
309        mode: PadMode = "symmetric",
310    ) -> Self:
311        """pad `tensor` to match `sizes`"""
312        if isinstance(pad_where, str):
313            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
314        else:
315            pad_axis_where = pad_where
316
317        pad_width: Dict[AxisId, PadWidth] = {}
318        for a, s_is in self.sizes.items():
319            if a not in sizes or sizes[a] == s_is:
320                pad_width[a] = PadWidth(0, 0)
321            elif s_is > sizes[a]:
322                pad_width[a] = PadWidth(0, 0)
323                logger.warning(
324                    "Cannot pad axis {} of size {} to smaller size {}",
325                    a,
326                    s_is,
327                    sizes[a],
328                )
329            elif a not in pad_axis_where:
330                raise ValueError(
331                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
332                )
333            else:
334                pad_this_axis_where = pad_axis_where[a]
335                d = sizes[a] - s_is
336                if pad_this_axis_where == "left":
337                    pad_width[a] = PadWidth(d, 0)
338                elif pad_this_axis_where == "right":
339                    pad_width[a] = PadWidth(0, d)
340                elif pad_this_axis_where == "left_and_right":
341                    pad_width[a] = PadWidth(left := d // 2, d - left)
342                else:
343                    assert_never(pad_this_axis_where)
344
345        return self.pad(pad_width, mode)

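Pad the tensor to match sizes. Axes not listed in sizes are left unchanged; an axis that is larger than its requested size is also left unchanged and a warning is logged.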

def quantile( self, q: Union[float, Sequence[float]], dim: Union[bioimageio.spec.model.v0_5.AxisId, Sequence[bioimageio.spec.model.v0_5.AxisId], NoneType] = None) -> Self:
347    def quantile(
348        self,
349        q: Union[float, Sequence[float]],
350        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
351    ) -> Self:
352        assert (
353            isinstance(q, (float, int))
354            and q >= 0.0
355            or not isinstance(q, (float, int))
356            and all(qq >= 0.0 for qq in q)
357        )
358        assert (
359            isinstance(q, (float, int))
360            and q <= 1.0
361            or not isinstance(q, (float, int))
362            and all(qq <= 1.0 for qq in q)
363        )
364        assert dim is None or (
365            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
366        )
367        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
def resize_to( self, sizes: Mapping[bioimageio.spec.model.v0_5.AxisId, int], *, pad_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', crop_where: Union[Literal['left', 'right', 'left_and_right'], Mapping[bioimageio.spec.model.v0_5.AxisId, Literal['left', 'right', 'left_and_right']]] = 'left_and_right', pad_mode: Literal['edge', 'reflect', 'symmetric'] = 'symmetric'):
369    def resize_to(
370        self,
371        sizes: PerAxis[int],
372        *,
373        pad_where: Union[
374            PadWhere,
375            PerAxis[PadWhere],
376        ] = "left_and_right",
377        crop_where: Union[
378            CropWhere,
379            PerAxis[CropWhere],
380        ] = "left_and_right",
381        pad_mode: PadMode = "symmetric",
382    ):
383        """return cropped/padded tensor with `sizes`"""
384        crop_to_sizes: Dict[AxisId, int] = {}
385        pad_to_sizes: Dict[AxisId, int] = {}
386        new_axes = dict(sizes)
387        for a, s_is in self.sizes.items():
388            a = AxisId(str(a))
389            _ = new_axes.pop(a, None)
390            if a not in sizes or sizes[a] == s_is:
391                pass
392            elif s_is > sizes[a]:
393                crop_to_sizes[a] = sizes[a]
394            else:
395                pad_to_sizes[a] = sizes[a]
396
397        tensor = self
398        if crop_to_sizes:
399            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)
400
401        if pad_to_sizes:
402            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)
403
404        if new_axes:
405            tensor = tensor.expand_dims(new_axes)
406
407        return tensor

Return a tensor cropped and/or padded to match sizes; axes listed in sizes but missing from the tensor are added as new axes.

def transpose(self, axes: Sequence[bioimageio.spec.model.v0_5.AxisId]) -> Self:
409    def transpose(
410        self,
411        axes: Sequence[AxisId],
412    ) -> Self:
413        """return a transposed tensor
414
415        Args:
416            axes: the desired tensor axes
417        """
418        # expand missing tensor axes
419        missing_axes = tuple(a for a in axes if a not in self.dims)
420        array = self._data
421        if missing_axes:
422            array = array.expand_dims(missing_axes)
423
424        # transpose to the correct axis order
425        return self.__class__.from_xarray(array.transpose(*axes))

Return a tensor transposed to the given axis order; axes not present in the tensor are first inserted as new singleton axes.

Arguments:
  • axes: the desired tensor axes