Coverage for src/bioimageio/core/axis.py: 88%
82 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from typing import Literal, Mapping, Optional, TypeVar, Union
6from typing_extensions import Protocol, assert_never, runtime_checkable
8from bioimageio.spec.model import v0_5
11def _guess_axis_type(a: str):
12 if a in ("b", "batch"):
13 return "batch"
14 elif a in ("t", "time"):
15 return "time"
16 elif a in ("i", "index"):
17 return "index"
18 elif a in ("c", "channel"):
19 return "channel"
20 elif a in ("x", "y", "z"):
21 return "space"
22 else:
23 raise ValueError(
24 f"Failed to infer axis type for axis id '{a}'."
25 + " Consider using one of: '"
26 + "', '".join(
27 ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"]
28 )
29 + "'. Or creating an `Axis` object instead."
30 )
# Generic type variable bound to `str`.
S = TypeVar("S", bound=str)

AxisId = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

# Per-axis mappings: axis id -> value of type `T`.
T = TypeVar("T")
PerAxis = Mapping[AxisId, T]

# A batch size is given as a plain integer.
BatchSize = int

# One-letter axis abbreviations.
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# Plain (string-based) ways of referring to an axis.
_AxisLikePlain = Union[str, AxisId, AxisLetter]
@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for axis descriptions that expose an `id` and a `type`.

    The protocol is runtime-checkable, so `isinstance(obj, AxisDescrLike)`
    checks only that `obj` provides both attributes.
    """

    # identifier of the axis
    id: _AxisLikePlain
    # kind of the axis
    type: Literal["batch", "channel", "index", "space", "time"]
# Input type accepted by `Axis.create` and `AxisInfo.create`: plain ids or
# letters, axis description objects, or existing `Axis`/`AxisInfo` instances.
AxisLike = Union[
    _AxisLikePlain,
    AxisDescrLike,
    v0_5.AnyAxis,
    "Axis",
    "AxisInfo",
]
@dataclass
class Axis:
    """An axis, identified by its `id` and its `type`.

    Batch and channel axes always get the canonical ids "batch" and
    "channel", respectively (see `__post_init__`).
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # Normalize batch and channel axes to their canonical ids,
        # overriding whatever id was passed in.
        if self.type == "batch":
            self.id = AxisId("batch")
        elif self.type == "channel":
            self.id = AxisId("channel")

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Create an `Axis` from an axis-like object.

        Args:
            axis: One of the following.
                - An instance of `cls`: returned as is.
                - A plain axis id: its type is guessed from the id.
                - An object exposing `id` and/or `type` attributes. A missing
                  `id` falls back to the object itself. A missing `type` is
                  guessed from the resolved id.

        Returns:
            `axis` itself if it already is an instance of `cls`; otherwise a new
            plain `Axis`. The result is deliberately not `cls`:
            `AxisInfo.create` calls this via `super()` and reads only `id` and
            `type` from the result, and constructing `cls` there would require
            a `size`.

        Raises:
            ValueError: if no type is given and none can be guessed from the id.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            axis_id = axis
            axis_type = _guess_axis_type(str(axis))
        else:
            axis_id = getattr(axis, "id", axis)
            if hasattr(axis, "type"):
                axis_type = axis.type
            else:
                # Guess from the resolved id, not from `str(axis)`: for an
                # object that has an `id` but no `type`, its string form is
                # not an axis id. Without an `id` attribute, `axis_id` is
                # `axis` itself, so this matches guessing from `str(axis)`.
                axis_type = _guess_axis_type(str(axis_id))

        return Axis(id=AxisId(axis_id), type=axis_type)
@dataclass
class AxisSize:
    """Size constraints of an axis.

    Only `min` is required; `max` and `step` default to `None`.
    """

    # lower bound of the axis size
    min: int
    # upper bound of the axis size; `None` if no upper bound is set
    max: Optional[int] = None
    # size increment (e.g. taken from a parameterized size); `None` if not set
    step: Optional[int] = None
@dataclass
class AxisInfo(Axis):
    """An `Axis` together with its `AxisSize`."""

    size: AxisSize

    @classmethod
    def create(
        cls, axis: AxisLike, size: Optional[Union[int, AxisSize]] = None
    ) -> AxisInfo:
        """Create an `AxisInfo` from an axis-like object and an optional size.

        Args:
            axis: An existing `AxisInfo`, or anything accepted by `Axis.create`.
            size: An explicit size, handled as follows.
                - An `int` means a fixed size (`min == max`).
                - An `AxisSize` is used as is.
                - If `None` and `axis` is a `v0_5` axis description, the size
                  is derived from that description.
                - If `None` otherwise, the size defaults to `AxisSize(min=1)`.

        Returns:
            `axis` itself if it is an `AxisInfo` and no `size` is given;
            otherwise a new `AxisInfo`.

        Raises:
            ValueError: propagated from `Axis.create` if the axis type cannot
                be determined.
        """
        # Return an existing AxisInfo unchanged only if no explicit size was
        # requested. Previously an explicit `size` was silently dropped here,
        # while it is honored for every other kind of `axis`.
        if isinstance(axis, AxisInfo) and size is None:
            return axis

        axis_base = super().create(axis)
        if size is None:
            if not isinstance(axis, v0_5.AxisBase):
                size = AxisSize(min=1)
            else:
                if axis.size is None:
                    size = AxisSize(min=1)
                elif isinstance(axis.size, int):
                    # Time/space axes and concatenable axes (other than index
                    # output axes) take only their minimum from the description.
                    # All other axes get a fixed size.
                    size = AxisSize(
                        min=axis.size,
                        max=None
                        if isinstance(axis, (v0_5.TimeAxisBase, v0_5.SpaceAxisBase))
                        or (
                            not isinstance(axis, v0_5.IndexOutputAxis)
                            and axis.concatenable
                        )
                        else axis.size,
                    )
                elif isinstance(axis.size, v0_5.SizeReference):
                    # The referenced size is not resolved here; only
                    # `offset + 1` is used as a lower bound.
                    size = AxisSize(min=axis.size.offset + 1)
                elif isinstance(axis.size, v0_5.ParameterizedSize):
                    size = AxisSize(min=axis.size.min, step=axis.size.step)
                elif isinstance(axis.size, v0_5.DataDependentSize):
                    size = AxisSize(min=axis.size.min, max=axis.size.max)
                else:
                    assert_never(axis.size)
        elif isinstance(size, int):
            size = AxisSize(min=size, max=size)

        return AxisInfo(id=axis_base.id, type=axis_base.type, size=size)