bioimageio.core.axis

  1from __future__ import annotations
  2
  3from dataclasses import dataclass
  4from typing import Literal, Mapping, Optional, TypeVar, Union
  5
  6from bioimageio.spec.model import v0_5
  7from typing_extensions import Protocol, assert_never, runtime_checkable
  8
  9
 10def _guess_axis_type(a: str):
 11    if a in ("b", "batch"):
 12        return "batch"
 13    elif a in ("t", "time"):
 14        return "time"
 15    elif a in ("i", "index"):
 16        return "index"
 17    elif a in ("c", "channel"):
 18        return "channel"
 19    elif a in ("x", "y", "z"):
 20        return "space"
 21    else:
 22        raise ValueError(
 23            f"Failed to infer axis type for axis id '{a}'."
 24            + " Consider using one of: '"
 25            + "', '".join(
 26                ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"]
 27            )
 28            + "'. Or creating an `Axis` object instead."
 29        )
 30
 31
 32S = TypeVar("S", bound=str)
 33
 34
 35AxisId = v0_5.AxisId
 36"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""
 37
 38T = TypeVar("T")
 39PerAxis = Mapping[AxisId, T]
 40
 41BatchSize = int
 42
 43AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
 44_AxisLikePlain = Union[str, AxisId, AxisLetter]
 45
 46
 47@runtime_checkable
 48class AxisDescrLike(Protocol):
 49    id: _AxisLikePlain
 50    type: Literal["batch", "channel", "index", "space", "time"]
 51
 52
 53AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis"]
 54
 55
 56@dataclass
 57class Axis:
 58    id: AxisId
 59    type: Literal["batch", "channel", "index", "space", "time"]
 60
 61    def __post_init__(self):
 62        if self.type == "batch":
 63            self.id = AxisId("batch")
 64        elif self.type == "channel":
 65            self.id = AxisId("channel")
 66
 67    @classmethod
 68    def create(cls, axis: AxisLike) -> Axis:
 69        if isinstance(axis, cls):
 70            return axis
 71
 72        if isinstance(axis, (AxisId, str)):
 73            axis_id = axis
 74            axis_type = _guess_axis_type(str(axis))
 75        else:
 76            if hasattr(axis, "type"):
 77                axis_type = axis.type
 78            else:
 79                axis_type = _guess_axis_type(str(axis))
 80
 81            if hasattr(axis, "id"):
 82                axis_id = axis.id
 83            else:
 84                axis_id = axis
 85
 86        return Axis(id=AxisId(axis_id), type=axis_type)
 87
 88
 89@dataclass
 90class AxisInfo(Axis):
 91    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing
 92
 93    @classmethod
 94    def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
 95        if isinstance(axis, AxisInfo):
 96            return axis
 97
 98        axis_base = super().create(axis)
 99        if maybe_singleton is None:
100            if not isinstance(axis, v0_5.AxisBase):
101                maybe_singleton = True
102            else:
103                if axis.size is None:
104                    maybe_singleton = True
105                elif isinstance(axis.size, int):
106                    maybe_singleton = axis.size == 1
107                elif isinstance(axis.size, v0_5.SizeReference):
108                    maybe_singleton = (
109                        True  # TODO: check if singleton is ok for a `SizeReference`
110                    )
111                elif isinstance(
112                    axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
113                ):
114                    try:
115                        maybe_size_one = axis.size.validate_size(
116                            1
117                        )  # TODO: refactor validate_size() to have boolean func here
118                    except ValueError:
119                        maybe_singleton = False
120                    else:
121                        maybe_singleton = maybe_size_one == 1
122                else:
123                    assert_never(axis.size)
124
125        return AxisInfo(
126            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
127        )
class AxisId(bioimageio.spec._internal.types.LowerCaseIdentifier):
class AxisId(LowerCaseIdentifier):
    """An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'."""

    # The pydantic root model used to validate the string: the lower-case
    # identifier constraints, a maximum length of 16, then normalization
    # by `_normalize_axis_id`.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]

An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

The pydantic root model used to validate the string.

PerAxis = typing.Mapping[AxisId, ~T]
BatchSize = <class 'int'>
AxisLetter = typing.Literal['b', 'i', 't', 'c', 'z', 'y', 'x']
@runtime_checkable
class AxisDescrLike(typing_extensions.Protocol):
@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for axis descriptions that expose an `id` and a `type`.

    As a runtime-checkable protocol, `isinstance` checks only verify that
    these attributes are present, not their types.
    """

    id: _AxisLikePlain
    type: Literal["batch", "channel", "index", "space", "time"]

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as::

class GenProto[T](Protocol):
    def meth(self) -> T:
        ...
AxisDescrLike(*args, **kwargs)
641    def _no_init(self, *args, **kwargs):
642        if type(self)._is_protocol:
643            raise TypeError('Protocols cannot be instantiated')
id: Union[str, AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x']]
type: Literal['batch', 'channel', 'index', 'space', 'time']
AxisLike = typing.Union[str, AxisId, typing.Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], AxisDescrLike, typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], ForwardRef('Axis')]
@dataclass
class Axis:
@dataclass
class Axis:
    """Minimal axis description: an axis id together with its axis type.

    Batch and channel axes always carry the canonical ids "batch" and
    "channel", respectively; any other id given for them is replaced in
    `__post_init__`.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # batch and channel axes are identified by their type alone
        if self.type in ("batch", "channel"):
            self.id = AxisId(self.type)

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Build an `Axis` from any supported axis specification.

        Instances of `cls` are returned unchanged. Strings and `AxisId`s
        become the id, and their type is guessed from that id. Other
        objects contribute their `type` and `id` attributes where present;
        a missing `type` is guessed from `str(axis)`, and a missing `id`
        falls back to the object itself.

        Raises:
            ValueError: if the type has to be guessed and the id is not one
                of the ids recognized by `_guess_axis_type`.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            # guess first so an unknown id fails with the guessing error
            guessed_type = _guess_axis_type(str(axis))
            return Axis(id=AxisId(axis), type=guessed_type)

        # description-like object: prefer its own attributes
        axis_type = (
            axis.type if hasattr(axis, "type") else _guess_axis_type(str(axis))
        )
        axis_id = axis.id if hasattr(axis, "id") else axis
        return Axis(id=AxisId(axis_id), type=axis_type)
Axis( id: AxisId, type: Literal['batch', 'channel', 'index', 'space', 'time'])
id: AxisId
type: Literal['batch', 'channel', 'index', 'space', 'time']
@classmethod
def create( cls, axis: Union[str, AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], AxisDescrLike, Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Axis]) -> Axis:
68    @classmethod
69    def create(cls, axis: AxisLike) -> Axis:
70        if isinstance(axis, cls):
71            return axis
72
73        if isinstance(axis, (AxisId, str)):
74            axis_id = axis
75            axis_type = _guess_axis_type(str(axis))
76        else:
77            if hasattr(axis, "type"):
78                axis_type = axis.type
79            else:
80                axis_type = _guess_axis_type(str(axis))
81
82            if hasattr(axis, "id"):
83                axis_id = axis.id
84            else:
85                axis_id = axis
86
87        return Axis(id=AxisId(axis_id), type=axis_type)
@dataclass
class AxisInfo(Axis):
 90@dataclass
 91class AxisInfo(Axis):
 92    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing
 93
 94    @classmethod
 95    def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
 96        if isinstance(axis, AxisInfo):
 97            return axis
 98
 99        axis_base = super().create(axis)
100        if maybe_singleton is None:
101            if not isinstance(axis, v0_5.AxisBase):
102                maybe_singleton = True
103            else:
104                if axis.size is None:
105                    maybe_singleton = True
106                elif isinstance(axis.size, int):
107                    maybe_singleton = axis.size == 1
108                elif isinstance(axis.size, v0_5.SizeReference):
109                    maybe_singleton = (
110                        True  # TODO: check if singleton is ok for a `SizeReference`
111                    )
112                elif isinstance(
113                    axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
114                ):
115                    try:
116                        maybe_size_one = axis.size.validate_size(
117                            1
118                        )  # TODO: refactor validate_size() to have boolean func here
119                    except ValueError:
120                        maybe_singleton = False
121                    else:
122                        maybe_singleton = maybe_size_one == 1
123                else:
124                    assert_never(axis.size)
125
126        return AxisInfo(
127            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
128        )
AxisInfo( id: AxisId, type: Literal['batch', 'channel', 'index', 'space', 'time'], maybe_singleton: bool)
maybe_singleton: bool
@classmethod
def create( cls, axis: Union[str, AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], AxisDescrLike, Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Axis], maybe_singleton: Optional[bool] = None) -> AxisInfo:
 94    @classmethod
 95    def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
 96        if isinstance(axis, AxisInfo):
 97            return axis
 98
 99        axis_base = super().create(axis)
100        if maybe_singleton is None:
101            if not isinstance(axis, v0_5.AxisBase):
102                maybe_singleton = True
103            else:
104                if axis.size is None:
105                    maybe_singleton = True
106                elif isinstance(axis.size, int):
107                    maybe_singleton = axis.size == 1
108                elif isinstance(axis.size, v0_5.SizeReference):
109                    maybe_singleton = (
110                        True  # TODO: check if singleton is ok for a `SizeReference`
111                    )
112                elif isinstance(
113                    axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
114                ):
115                    try:
116                        maybe_size_one = axis.size.validate_size(
117                            1
118                        )  # TODO: refactor validate_size() to have boolean func here
119                    except ValueError:
120                        maybe_singleton = False
121                    else:
122                        maybe_singleton = maybe_size_one == 1
123                else:
124                    assert_never(axis.size)
125
126        return AxisInfo(
127            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
128        )
Inherited Members
Axis
id
type