Coverage for src/bioimageio/core/axis.py: 77%
77 statements
coverage.py v7.10.7, created at 2025-10-14 08:35 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from typing import Literal, Mapping, Optional, TypeVar, Union
6from bioimageio.spec.model import v0_5
7from typing_extensions import Protocol, assert_never, runtime_checkable
10def _guess_axis_type(a: str):
11 if a in ("b", "batch"):
12 return "batch"
13 elif a in ("t", "time"):
14 return "time"
15 elif a in ("i", "index"):
16 return "index"
17 elif a in ("c", "channel"):
18 return "channel"
19 elif a in ("x", "y", "z"):
20 return "space"
21 else:
22 raise ValueError(
23 f"Failed to infer axis type for axis id '{a}'."
24 + " Consider using one of: '"
25 + "', '".join(
26 ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"]
27 )
28 + "'. Or creating an `Axis` object instead."
29 )
# NOTE(review): `S` is not referenced in the code visible here; confirm it is
# used elsewhere before relying on it.
S = TypeVar("S", bound=str)

# Re-export of the bioimageio.spec v0.5 axis id type.
AxisId = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

T = TypeVar("T")
# Generic mapping from an axis id to some per-axis value of type `T`.
PerAxis = Mapping[AxisId, T]

# Plain `int` alias; presumably the number of entries along the batch axis
# (inferred from the name only -- confirm against callers).
BatchSize = int

# Single-letter axis ids; `_guess_axis_type` maps each of them to an axis type
# ("b" -> batch, "i" -> index, "t" -> time, "c" -> channel, "z"/"y"/"x" -> space).
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# The plain (non-object) forms in which an axis may be specified.
_AxisLikePlain = Union[str, AxisId, AxisLetter]
@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for any axis description exposing an `id` and a `type`.

    Because the protocol is `runtime_checkable`, `isinstance(obj, AxisDescrLike)`
    only checks that `obj` has `id` and `type` attributes; their values are
    not validated.
    """

    # Axis identifier, given in any of the plain axis forms.
    id: _AxisLikePlain
    # Axis category.
    type: Literal["batch", "channel", "index", "space", "time"]
# Every form accepted where an axis is expected (e.g. by `Axis.create`): a plain
# id string / `AxisId` / axis letter, any object with `id` and `type`
# attributes, a bioimageio.spec v0.5 axis description, or an `Axis` instance.
# "Axis" is a string forward reference because the class is defined further down.
AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis"]
@dataclass
class Axis:
    """An axis identified by `id` and classified by `type`.

    Batch and channel axes are normalized on construction: their `id` is
    replaced by ``AxisId("batch")`` or ``AxisId("channel")`` respectively,
    whatever id was passed in.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # Only batch and channel axes have a fixed, canonical id.
        for canonical in ("batch", "channel"):
            if self.type == canonical:
                self.id = AxisId(canonical)
                break

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Turn any axis-like value into an `Axis`.

        - An instance of `cls` is returned unchanged.
        - A string / `AxisId` becomes the id, and the type is guessed from it.
        - Any other object contributes its `type` and `id` attributes when it
          has them. A missing `type` is guessed from `str(axis)`, and a
          missing `id` falls back to the object itself.

        The result is always a plain `Axis`, even when called on a subclass
        (`AxisInfo.create` builds on this).

        Raises:
            ValueError: raised by `_guess_axis_type` when the type has to be
                guessed and the id is not a recognized one.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            # Guess the type before validating the id, so type-guessing
            # errors take precedence.
            guessed = _guess_axis_type(str(axis))
            return Axis(id=AxisId(axis), type=guessed)

        axis_type = (
            axis.type if hasattr(axis, "type") else _guess_axis_type(str(axis))
        )
        raw_id = axis.id if hasattr(axis, "id") else axis
        return Axis(id=AxisId(raw_id), type=axis_type)
def _infer_maybe_singleton(axis: AxisLike) -> bool:
    """Guess whether `axis` may have size 1, judging only by its declared size.

    Anything that is not a bioimageio.spec v0.5 axis description carries no
    size information here and is treated as possibly singleton.

    Raises:
        AssertionError: (via `assert_never`) for an unsupported `size` type.
    """
    if not isinstance(axis, v0_5.AxisBase):
        return True

    size = axis.size
    if size is None:
        return True
    if isinstance(size, int):
        return size == 1
    if isinstance(size, v0_5.SizeReference):
        return True  # TODO: check if singleton is ok for a `SizeReference`
    if isinstance(size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)):
        # A ValueError from validate_size(1) is taken to mean size 1 is not allowed.
        # TODO: refactor validate_size() to have boolean func here
        try:
            validated = size.validate_size(1)
        except ValueError:
            return False
        return validated == 1

    assert_never(size)


@dataclass
class AxisInfo(Axis):
    """An `Axis` annotated with whether it may have size 1 (be a singleton)."""

    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing

    @classmethod
    def create(
        cls, axis: AxisLike, maybe_singleton: Optional[bool] = None
    ) -> AxisInfo:
        """Build an `AxisInfo` from any axis-like value.

        Args:
            axis: Anything `Axis.create` accepts, or an existing `AxisInfo`.
            maybe_singleton: Explicit value for `maybe_singleton`. If None,
                the value is inferred. It is True for anything that is not a
                v0.5 axis description; otherwise it is derived from the
                description's `size`. An existing `AxisInfo` keeps its own
                value.

        Returns:
            `axis` itself if it already is an `AxisInfo` whose
            `maybe_singleton` agrees with the requested value (or none was
            requested). Otherwise a new `AxisInfo`.

        Raises:
            ValueError: propagated from `Axis.create` when the axis type
                cannot be guessed.
        """
        if isinstance(axis, AxisInfo):
            # Honor an explicit override instead of silently ignoring it.
            if maybe_singleton is None or maybe_singleton == axis.maybe_singleton:
                return axis
            return AxisInfo(
                id=axis.id, type=axis.type, maybe_singleton=maybe_singleton
            )

        axis_base = super().create(axis)
        if maybe_singleton is None:
            maybe_singleton = _infer_maybe_singleton(axis)

        return AxisInfo(
            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
        )