Coverage for src / bioimageio / core / axis.py: 88%

82 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from typing import Literal, Mapping, Optional, TypeVar, Union 

5 

6from typing_extensions import Protocol, assert_never, runtime_checkable 

7 

8from bioimageio.spec.model import v0_5 

9 

10 

11def _guess_axis_type(a: str): 

12 if a in ("b", "batch"): 

13 return "batch" 

14 elif a in ("t", "time"): 

15 return "time" 

16 elif a in ("i", "index"): 

17 return "index" 

18 elif a in ("c", "channel"): 

19 return "channel" 

20 elif a in ("x", "y", "z"): 

21 return "space" 

22 else: 

23 raise ValueError( 

24 f"Failed to infer axis type for axis id '{a}'." 

25 + " Consider using one of: '" 

26 + "', '".join( 

27 ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"] 

28 ) 

29 + "'. Or creating an `Axis` object instead." 

30 ) 

31 

32 

# NOTE(review): `S` is not referenced in the visible part of this module; it
# may be imported elsewhere — confirm before removing.
S = TypeVar("S", bound=str)


AxisId = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

# Value type parameter of `PerAxis`.
T = TypeVar("T")
# A read-only mapping from axis id to a per-axis value of type `T`.
PerAxis = Mapping[AxisId, T]

# Plain `int` alias used to name batch sizes in signatures.
BatchSize = int

# Single-letter axis names; each is recognized by `_guess_axis_type`.
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# Ways to refer to an axis by name only (no object).
_AxisLikePlain = Union[str, AxisId, AxisLetter]

46 

47 

@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for any object exposing an axis `id` and `type`.

    `@runtime_checkable` allows `isinstance` checks against this protocol;
    per `typing` semantics such checks only test that the attributes exist,
    not that their values are valid.
    """

    id: _AxisLikePlain  # the axis identifier
    type: Literal["batch", "channel", "index", "space", "time"]  # the axis kind

52 

53 

# Everything accepted by `Axis.create` / `AxisInfo.create` to describe an axis:
# a plain name, a protocol-conforming object, a bioimageio.spec v0.5 axis
# description, or an already constructed `Axis` / `AxisInfo`. The last two are
# string forward references because those classes are defined further below.
AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis", "AxisInfo"]

55 

56 

@dataclass
class Axis:
    """An axis, identified by `id` and classified by `type`.

    For batch and channel axes `__post_init__` replaces whatever id was
    passed with ``"batch"`` or ``"channel"`` respectively.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # A batch/channel axis is identified by its type alone, so the given
        # id is normalized to the canonical one.
        if self.type in ("batch", "channel"):
            self.id = AxisId("batch" if self.type == "batch" else "channel")

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Build an `Axis` from any axis-like description.

        Args:
            axis: An instance of `cls` (returned unchanged), a plain string or
                `AxisId` (its type is guessed from the name), or an object
                exposing `type` and/or `id` attributes.

        Returns:
            `axis` itself if it is already an instance of `cls`, otherwise a
            new plain `Axis`.

        Raises:
            ValueError: if the type has to be guessed and the name is not one
                of the recognized axis names.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            # The type is resolved before the id is validated.
            guessed_type = _guess_axis_type(str(axis))
            return Axis(id=AxisId(axis), type=guessed_type)

        axis_type = (
            axis.type if hasattr(axis, "type") else _guess_axis_type(str(axis))
        )
        axis_id = axis.id if hasattr(axis, "id") else axis
        return Axis(id=AxisId(axis_id), type=axis_type)

88 

89 

@dataclass
class AxisSize:
    """Admissible size range of an axis.

    A `None` for `max` or `step` leaves that bound / increment unspecified.
    """

    min: int  # smallest admissible size
    max: Optional[int] = None  # largest admissible size, if known
    step: Optional[int] = None  # size increment, e.g. from a parameterized size

95 

96 

@dataclass
class AxisInfo(Axis):
    """An `Axis` together with the range of sizes it admits."""

    size: AxisSize

    @classmethod
    def create(
        cls, axis: AxisLike, size: Optional[Union[int, AxisSize]] = None
    ) -> AxisInfo:
        """Create an `AxisInfo` from any axis-like description.

        Args:
            axis: The axis to describe. An `AxisInfo` is returned unchanged
                unless an explicit `size` is given.
            size: Explicit size override. An `int` fixes both `min` and `max`
                to that value, an `AxisSize` is used as is, and `None` derives
                the size from `axis` (see `_size_from_axis`).

        Returns:
            The resulting `AxisInfo`.

        Raises:
            ValueError: if the axis type has to be guessed from an unrecognized
                name (raised by `Axis.create`).
        """
        if isinstance(axis, AxisInfo) and size is None:
            # Nothing to derive or override: reuse the given instance.
            return axis

        axis_base = super().create(axis)
        if size is None:
            size = cls._size_from_axis(axis)
        elif isinstance(size, int):
            size = AxisSize(min=size, max=size)

        return AxisInfo(id=axis_base.id, type=axis_base.type, size=size)

    @staticmethod
    def _size_from_axis(axis: AxisLike) -> AxisSize:
        """Derive the admissible size range from an axis description.

        Only bioimageio.spec v0.5 axis descriptions are inspected; any other
        axis, or one without a size, defaults to `AxisSize(min=1)`.
        """
        if not isinstance(axis, v0_5.AxisBase) or axis.size is None:
            return AxisSize(min=1)

        axis_size = axis.size
        if isinstance(axis_size, int):
            # Time/space axes and concatenable axes get no upper bound; all
            # other fixed-size axes are pinned to exactly that size.
            # NOTE(review): IndexOutputAxis is excluded before `concatenable`
            # is read, presumably because it lacks that attribute — confirm.
            unbounded = isinstance(
                axis, (v0_5.TimeAxisBase, v0_5.SpaceAxisBase)
            ) or (not isinstance(axis, v0_5.IndexOutputAxis) and axis.concatenable)
            return AxisSize(min=axis_size, max=None if unbounded else axis_size)

        if isinstance(axis_size, v0_5.SizeReference):
            # A negative offset would otherwise yield a non-positive minimum,
            # which is not a valid axis size.
            return AxisSize(min=max(1, axis_size.offset + 1))

        if isinstance(axis_size, v0_5.ParameterizedSize):
            return AxisSize(min=axis_size.min, step=axis_size.step)

        if isinstance(axis_size, v0_5.DataDependentSize):
            return AxisSize(min=axis_size.min, max=axis_size.max)

        assert_never(axis_size)