Coverage for src/bioimageio/core/axis.py: 88%

82 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from typing import ( 

5 Literal, 

6 Mapping, 

7 Optional, 

8 TypeVar, 

9 Union, 

10) 

11 

12from typing_extensions import Protocol, TypeAlias, assert_never, runtime_checkable 

13 

14from bioimageio.spec.model import v0_5 

15 

16 

17def _guess_axis_type(a: str): 

18 if a in ("b", "batch"): 

19 return "batch" 

20 elif a in ("t", "time"): 

21 return "time" 

22 elif a in ("i", "index"): 

23 return "index" 

24 elif a in ("c", "channel"): 

25 return "channel" 

26 elif a in ("x", "y", "z"): 

27 return "space" 

28 else: 

29 raise ValueError( 

30 f"Failed to infer axis type for axis id '{a}'." 

31 + " Consider using one of: '" 

32 + "', '".join( 

33 ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"] 

34 ) 

35 + "'. Or creating an `Axis` object instead." 

36 ) 

37 

38 

# NOTE(review): `S` is not referenced in this module chunk; presumably it is
# imported elsewhere — confirm before removing.
S = TypeVar("S", bound=str)


AxisId: TypeAlias = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

_T = TypeVar("_T")
# Generic (read-only) mapping from axis id to some per-axis value of type `_T`.
PerAxis = Mapping[AxisId, _T]


# Plain `int` alias, presumably used to annotate batch sizes — confirm at use sites.
BatchSize = int

# Single-letter axis ids; `_guess_axis_type` recognizes each of these.
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# String-based ways of naming an axis (no separate type information attached).
_AxisLikePlain = Union[str, AxisId, AxisLetter]

53 

54 

@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for any object exposing an axis ``id`` and ``type``.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    AxisDescrLike)`` only checks that these attributes are present; it does
    not validate their values.
    """

    id: _AxisLikePlain
    type: Literal["batch", "channel", "index", "space", "time"]


# Every form accepted as the ``axis`` argument of `Axis.create` / `AxisInfo.create`.
AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis", "AxisInfo"]

62 

63 

@dataclass
class Axis:
    """An axis described by its id and its type.

    Batch and channel axes have a canonical id: ``__post_init__`` overwrites
    ``id`` with ``"batch"`` / ``"channel"`` for axes of those types.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # Normalize the id of batch/channel axes (e.g. "b" -> "batch").
        if self.type == "batch":
            self.id = AxisId("batch")
        elif self.type == "channel":
            self.id = AxisId("channel")

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Create an `Axis` from an axis-like specification.

        Args:
            axis: An instance of ``cls`` (returned unchanged), a plain axis id
                string (its type is guessed from the id), or any other object;
                its ``id`` attribute is used as the axis id (falling back to
                the object itself), and its ``type`` attribute as the axis
                type. If it has no ``type``, the type is guessed from the id.

        Returns:
            A new plain `Axis` (unless ``axis`` was already an instance of
            ``cls``). Batch/channel ids are normalized by ``__post_init__``.

        Raises:
            ValueError: if the type has to be guessed and the id is not one of
                the recognized axis ids.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            axis_id = axis
            axis_type = _guess_axis_type(str(axis))
        else:
            axis_id = axis.id if hasattr(axis, "id") else axis
            if hasattr(axis, "type"):
                axis_type = axis.type
            else:
                # Guess from the axis id rather than from `str(axis)`: for an
                # object carrying an `id`, its string form is generally not
                # that id. (Without an `id`, `axis_id` is `axis` itself.)
                axis_type = _guess_axis_type(str(axis_id))

        # Always build a plain `Axis`, also when invoked via a subclass such
        # as `AxisInfo`, whose constructor requires additional fields.
        return Axis(id=AxisId(axis_id), type=axis_type)

95 

96 

@dataclass
class AxisSize:
    """Size constraints of an axis: a minimum plus optional maximum and step."""

    min: int
    # `None` when no maximum is set.
    max: Optional[int] = None
    # `None` when no step is set; `AxisInfo.create` fills it from a
    # `v0_5.ParameterizedSize` description.
    step: Optional[int] = None

102 

103 

def _axis_size_from_descr(axis: AxisLike) -> AxisSize:
    """Derive an `AxisSize` for ``axis`` when no explicit size was given.

    Objects that are not a `v0_5.AxisBase`, and descriptions whose ``size`` is
    ``None``, get ``AxisSize(min=1)``.
    """
    if not isinstance(axis, v0_5.AxisBase) or axis.size is None:
        return AxisSize(min=1)

    descr_size = axis.size
    if isinstance(descr_size, int):
        # NOTE(review): time/space axes, and concatenable axes that are not an
        # `IndexOutputAxis`, get no maximum even though the description states
        # a fixed size — presumably because such axes may be larger in
        # practice; confirm against the spec.
        no_max = isinstance(axis, (v0_5.TimeAxisBase, v0_5.SpaceAxisBase)) or (
            not isinstance(axis, v0_5.IndexOutputAxis) and axis.concatenable
        )
        return AxisSize(min=descr_size, max=None if no_max else descr_size)
    if isinstance(descr_size, v0_5.SizeReference):
        return AxisSize(min=descr_size.offset + 1)
    if isinstance(descr_size, v0_5.ParameterizedSize):
        return AxisSize(min=descr_size.min, step=descr_size.step)
    if isinstance(descr_size, v0_5.DataDependentSize):
        return AxisSize(min=descr_size.min, max=descr_size.max)

    assert_never(descr_size)


@dataclass
class AxisInfo(Axis):
    """An `Axis` together with its `AxisSize`."""

    size: AxisSize

    @classmethod
    def create(
        cls, axis: AxisLike, size: Optional[Union[int, AxisSize]] = None
    ) -> AxisInfo:
        """Create an `AxisInfo` from an axis-like specification.

        Args:
            axis: Any axis-like specification accepted by `Axis.create`. An
                existing `AxisInfo` is returned unchanged if ``size`` is None.
            size: Explicit size. An ``int`` means an exact size
                (``min == max == size``); an `AxisSize` is used as-is. If
                None, the size is derived from ``axis`` when it is a
                `v0_5.AxisBase` description, otherwise ``AxisSize(min=1)``.

        Returns:
            The resulting `AxisInfo`.

        Raises:
            ValueError: if the axis type has to be guessed from an
                unrecognized axis id.
        """
        # Only short-circuit when no size is requested; an explicit `size`
        # must not be silently dropped for an existing `AxisInfo`.
        if isinstance(axis, AxisInfo) and size is None:
            return axis

        axis_base = super().create(axis)
        if size is None:
            size = _axis_size_from_descr(axis)
        elif isinstance(size, int):
            size = AxisSize(min=size, max=size)

        return AxisInfo(id=axis_base.id, type=axis_base.type, size=size)