Coverage for src/bioimageio/core/axis.py: 77%

77 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-14 08:35 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from typing import Literal, Mapping, Optional, TypeVar, Union 

5 

6from bioimageio.spec.model import v0_5 

7from typing_extensions import Protocol, assert_never, runtime_checkable 

8 

9 

10def _guess_axis_type(a: str): 

11 if a in ("b", "batch"): 

12 return "batch" 

13 elif a in ("t", "time"): 

14 return "time" 

15 elif a in ("i", "index"): 

16 return "index" 

17 elif a in ("c", "channel"): 

18 return "channel" 

19 elif a in ("x", "y", "z"): 

20 return "space" 

21 else: 

22 raise ValueError( 

23 f"Failed to infer axis type for axis id '{a}'." 

24 + " Consider using one of: '" 

25 + "', '".join( 

26 ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"] 

27 ) 

28 + "'. Or creating an `Axis` object instead." 

29 ) 

30 

31 

# Generic type variables.
S = TypeVar("S", bound=str)
T = TypeVar("T")

AxisId = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

# A read-only mapping from axis id to some per-axis value of type T.
PerAxis = Mapping[AxisId, T]

# Plain integer alias for a batch size.
BatchSize = int

# Single-letter axis ids.
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]

# Axis-like values that are just an id (no type information attached).
_AxisLikePlain = Union[str, AxisId, AxisLetter]


@runtime_checkable
class AxisDescrLike(Protocol):
    """Structural type for any object exposing an axis ``id`` and ``type``."""

    id: _AxisLikePlain
    type: Literal["batch", "channel", "index", "space", "time"]


# Everything accepted where an axis is expected: a plain id, a protocol-conforming
# description, a bioimageio.spec v0.5 axis, or an `Axis` (forward reference).
AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis"]

54 

55 

@dataclass
class Axis:
    """An axis described by its id and its type.

    Batch and channel axes are normalized in ``__post_init__`` so that their
    id always equals their type (``"batch"`` / ``"channel"``).
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # Batch and channel axes always carry their canonical id; the first
        # matching canonical name wins (same precedence as checking batch first).
        for canonical in ("batch", "channel"):
            if self.type == canonical:
                self.id = AxisId(canonical)
                break

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Build an `Axis` from any axis-like value.

        Args:
            axis: an instance of ``cls`` (returned unchanged), a plain axis id
                string (its type is guessed from the id), or an object
                optionally exposing ``type`` and/or ``id`` attributes.

        Returns:
            ``axis`` itself if it is already an instance of ``cls``; otherwise
            a new plain `Axis`. Note: this is deliberately ``Axis`` and not
            ``cls``, because subclasses call this via ``super().create()`` and
            then add their own fields.

        Raises:
            ValueError: from ``_guess_axis_type`` when the type must be guessed
                and the id is not a recognized one.
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, (AxisId, str)):
            # Guess the type before constructing the id, matching the original
            # order in which errors could surface.
            guessed_type = _guess_axis_type(str(axis))
            return Axis(id=AxisId(axis), type=guessed_type)

        # Prefer explicit attributes; otherwise fall back to the string form /
        # the object itself.
        axis_type = axis.type if hasattr(axis, "type") else _guess_axis_type(str(axis))
        axis_id = axis.id if hasattr(axis, "id") else axis
        return Axis(id=AxisId(axis_id), type=axis_type)

87 

88 

@dataclass
class AxisInfo(Axis):
    """An `Axis` that also records whether the axis may have size 1."""

    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing

    @staticmethod
    def _infer_maybe_singleton(axis: AxisLike) -> bool:
        """Decide from an axis description whether a size of 1 is possible.

        Anything that is not a ``v0_5.AxisBase`` carries no size information
        here and is treated as possibly singleton.
        """
        if not isinstance(axis, v0_5.AxisBase):
            return True

        size = axis.size
        if size is None:
            return True
        if isinstance(size, int):
            return size == 1
        if isinstance(size, v0_5.SizeReference):
            return True  # TODO: check if singleton is ok for a `SizeReference`
        if isinstance(size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)):
            try:
                # TODO: refactor validate_size() to have boolean func here
                validated = size.validate_size(1)
            except ValueError:
                return False
            return validated == 1

        assert_never(size)

    @classmethod
    def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
        """Build an `AxisInfo` from any axis-like value.

        Args:
            axis: the axis to describe. An existing `AxisInfo` is returned as-is.
            maybe_singleton: explicit override; when ``None`` it is inferred
                from the axis' ``size`` (if ``axis`` is a ``v0_5.AxisBase``),
                otherwise it defaults to ``True``.

        Returns:
            An `AxisInfo` with id/type resolved by `Axis.create`.

        Raises:
            ValueError: propagated from `Axis.create` when the axis type cannot
                be inferred.
        """
        if isinstance(axis, AxisInfo):
            # NOTE(review): an explicitly passed `maybe_singleton` is ignored in
            # this case; confirm that callers rely on / accept this.
            return axis

        base = super().create(axis)
        if maybe_singleton is None:
            maybe_singleton = cls._infer_maybe_singleton(axis)

        return AxisInfo(id=base.id, type=base.type, maybe_singleton=maybe_singleton)