Coverage for src/bioimageio/core/axis.py: 88%
82 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from typing import (
5 Literal,
6 Mapping,
7 Optional,
8 TypeVar,
9 Union,
10)
12from typing_extensions import Protocol, TypeAlias, assert_never, runtime_checkable
14from bioimageio.spec.model import v0_5
# Lookup table from every axis id that can be resolved automatically to its
# axis type. Insertion order matters: it is also the order in which the ids
# are suggested in the error message of `_guess_axis_type`.
_AXIS_TYPE_BY_ID: Mapping[
    str, Literal["batch", "channel", "index", "space", "time"]
] = {
    "b": "batch",
    "batch": "batch",
    "t": "time",
    "time": "time",
    "i": "index",
    "index": "index",
    "c": "channel",
    "channel": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}


def _guess_axis_type(a: str):
    """Infer the axis type from a plain axis id.

    Args:
        a: Axis id, either a single letter ('b', 't', 'i', 'c', 'x', 'y', 'z')
            or one of the full names 'batch', 'time', 'index', 'channel'.

    Returns:
        One of "batch", "time", "index", "channel" or "space".

    Raises:
        ValueError: If `a` is none of the known axis ids.
    """
    guessed = _AXIS_TYPE_BY_ID.get(a)
    if guessed is not None:
        return guessed

    known_ids = "', '".join(_AXIS_TYPE_BY_ID)
    raise ValueError(
        f"Failed to infer axis type for axis id '{a}'."
        + " Consider using one of: '"
        + known_ids
        + "'. Or creating an `Axis` object instead."
    )
# Type variable with upper bound `str`; not referenced in the visible part of
# this module.
S = TypeVar("S", bound=str)


# Alias of the spec's `v0_5.AxisId`.
AxisId: TypeAlias = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

# Value type parameter of `PerAxis`.
_T = TypeVar("_T")
# Generic mapping from an axis id to a per-axis value of type `_T`.
PerAxis = Mapping[AxisId, _T]


# Plain alias of `int`.
# NOTE(review): presumably the number of samples in a batch -- confirm with callers.
BatchSize = int

# Single-letter axis ids; the same letters `_guess_axis_type` resolves to an
# axis type.
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# Ways to refer to an axis by a plain string value (as opposed to an object
# with `id`/`type` attributes).
_AxisLikePlain = Union[str, AxisId, AxisLetter]
@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for objects that describe an axis through `id` and `type`.

    The protocol is decorated with `runtime_checkable`, so it may be used in
    `isinstance` checks.
    """

    # axis identifier given as a plain string, `AxisId` or single axis letter
    id: _AxisLikePlain
    # axis category
    type: Literal["batch", "channel", "index", "space", "time"]
# Everything `Axis.create` and `AxisInfo.create` accept as their `axis`
# argument: a plain id, an object exposing `id`/`type`, a spec axis
# description, or an already created `Axis`/`AxisInfo`.
# `"Axis"` and `"AxisInfo"` are string forward references because both classes
# are defined further down in this module.
AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis", "AxisInfo"]
@dataclass
class Axis:
    """An axis identified by `id` and categorized by `type`.

    For axes of type "batch" or "channel" the given `id` is replaced by the
    type name right after initialization.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # batch and channel axes always carry the id named after their type
        if self.type in ("batch", "channel"):
            self.id = AxisId(self.type)

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Create an `Axis` from anything axis-like.

        Args:
            axis: An instance of `cls` (returned unchanged), a plain axis id
                (its type is guessed from the id), or an object whose `type`
                and `id` attributes are used when present. A missing `type`
                is guessed from `str(axis)`; a missing `id` falls back to
                `axis` itself.

        Returns:
            `axis` itself if it already is an instance of `cls`, otherwise a
            new `Axis` (always the base class, even when called on a subclass).

        Raises:
            ValueError: If the axis type has to be guessed and cannot be
                inferred.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            # guess the type before building the id so that an unknown id is
            # reported by `_guess_axis_type` first
            guessed_type = _guess_axis_type(str(axis))
            return Axis(id=AxisId(axis), type=guessed_type)

        # arbitrary object: prefer its own attributes, fall back to guessing
        resolved_type = (
            axis.type if hasattr(axis, "type") else _guess_axis_type(str(axis))
        )
        resolved_id = axis.id if hasattr(axis, "id") else axis
        return Axis(id=AxisId(resolved_id), type=resolved_type)
@dataclass
class AxisSize:
    """Size information of an axis given as `min` with optional `max` and `step`."""

    # minimum size; the only required field
    min: int
    # maximum size; `None` when no maximum is recorded
    # NOTE(review): presumably `None` means "no upper bound" -- confirm with consumers
    max: Optional[int] = None
    # `AxisInfo.create` fills this from a parameterized size's `step`; `None` otherwise
    # NOTE(review): presumably the increment between valid sizes -- confirm
    step: Optional[int] = None
@dataclass
class AxisInfo(Axis):
    """An `Axis` extended by size information."""

    size: AxisSize

    @classmethod
    def create(
        cls, axis: AxisLike, size: Optional[Union[int, AxisSize]] = None
    ) -> AxisInfo:
        """Create an `AxisInfo` from anything axis-like.

        Args:
            axis: The axis to wrap. An existing `AxisInfo` is returned
                unchanged; in that case `size` is not looked at.
            size: Explicit size. An `int` fixes both `min` and `max` to that
                value, an `AxisSize` is used as is, and `None` derives the
                size from `axis` (see `_default_size`).

        Returns:
            `axis` itself if it already is an `AxisInfo`, otherwise a new
            `AxisInfo` with id and type resolved by `Axis.create`.
        """
        if isinstance(axis, AxisInfo):
            return axis

        # resolve id/type first; this may raise for axis ids of unknown type
        base = super().create(axis)

        if size is None:
            resolved_size = cls._default_size(axis)
        elif isinstance(size, int):
            resolved_size = AxisSize(min=size, max=size)
        else:
            resolved_size = size

        return AxisInfo(id=base.id, type=base.type, size=resolved_size)

    @staticmethod
    def _default_size(axis: AxisLike) -> AxisSize:
        """Derive an `AxisSize` from `axis` when no explicit size was given."""
        if not isinstance(axis, v0_5.AxisBase):
            # nothing but an id/type is known
            return AxisSize(min=1)

        if axis.size is None:
            return AxisSize(min=1)

        if isinstance(axis.size, int):
            # a fixed size is only kept as maximum for axes that are neither
            # time/space axes nor concatenable
            return AxisSize(
                min=axis.size,
                max=(
                    None
                    if isinstance(axis, (v0_5.TimeAxisBase, v0_5.SpaceAxisBase))
                    or (
                        not isinstance(axis, v0_5.IndexOutputAxis)
                        and axis.concatenable
                    )
                    else axis.size
                ),
            )

        if isinstance(axis.size, v0_5.SizeReference):
            return AxisSize(min=axis.size.offset + 1)

        if isinstance(axis.size, v0_5.ParameterizedSize):
            return AxisSize(min=axis.size.min, step=axis.size.step)

        if isinstance(axis.size, v0_5.DataDependentSize):
            return AxisSize(min=axis.size.min, max=axis.size.max)

        # every known size description is handled above
        assert_never(axis.size)