Coverage for bioimageio/core/axis.py: 75%

68 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from typing import Literal, Mapping, Optional, TypeVar, Union 

5 

6from typing_extensions import assert_never 

7 

8from bioimageio.spec.model import v0_5 

9 

10 

11def _guess_axis_type(a: str): 

12 if a in ("b", "batch"): 

13 return "batch" 

14 elif a in ("t", "time"): 

15 return "time" 

16 elif a in ("i", "index"): 

17 return "index" 

18 elif a in ("c", "channel"): 

19 return "channel" 

20 elif a in ("x", "y", "z"): 

21 return "space" 

22 else: 

23 raise ValueError( 

24 f"Failed to infer axis type for axis id '{a}'." 

25 + " Consider using one of: '" 

26 + "', '".join( 

27 ["b", "batch", "t", "time", "i", "index", "c", "channel", "x", "y", "z"] 

28 ) 

29 + "'. Or creating an `Axis` object instead." 

30 ) 

31 

32 

# NOTE(review): `S` is not referenced in this part of the module; confirm it is
# still needed (e.g. by importers) before removing it.
S = TypeVar("S", bound=str)


AxisId = v0_5.AxisId
"""An axis identifier, e.g. 'batch', 'channel', 'z', 'y', 'x'"""

T = TypeVar("T")
# Generic mapping from an axis id to a per-axis value of type `T`.
PerAxis = Mapping[AxisId, T]

# A batch size is represented as a plain integer.
BatchSize = int

# Single-letter axis ids; `_guess_axis_type` maps each of these to an axis type
# (b -> batch, i -> index, t -> time, c -> channel, z/y/x -> space).
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# Everything `Axis.create` accepts: an axis id, an axis letter, a v0_5 axis
# description, or an existing `Axis` (forward reference, defined below).
AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"]

46 

47 

@dataclass
class Axis:
    """An axis identified by its `id` and classified by its `type`.

    Batch and channel axes get a canonical id: `__post_init__` overwrites
    `id` with "batch" or "channel" respectively, regardless of the given id.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # batch and channel axes always carry their canonical id
        for canonical in ("batch", "channel"):
            if self.type == canonical:
                self.id = AxisId(canonical)
                break

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Convert any axis-like description into an `Axis`.

        Instances of `cls` are returned as-is. Other `Axis` instances and
        v0_5 axis descriptions are copied into a new `Axis`. Plain strings
        are used as the axis id and the type is guessed via
        `_guess_axis_type` (which raises `ValueError` for unknown ids).
        """
        if isinstance(axis, cls):
            return axis

        if isinstance(axis, Axis):
            # only reachable when `cls` is a subclass (e.g. via `super().create`)
            return Axis(id=axis.id, type=axis.type)

        if isinstance(axis, v0_5.AxisBase):
            return Axis(id=AxisId(axis.id), type=axis.type)

        if isinstance(axis, str):
            # build the id before guessing the type, as before
            axis_id = AxisId(axis)
            return Axis(id=axis_id, type=_guess_axis_type(axis))

        assert_never(axis)

71 

72 

@dataclass
class AxisInfo(Axis):
    """An `Axis` annotated with whether it may have size 1 (be a singleton)."""

    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing

    @classmethod
    def create(
        cls, axis: AxisLike, maybe_singleton: Optional[bool] = None
    ) -> AxisInfo:
        """Create an `AxisInfo` from any axis-like description.

        Args:
            axis: an `AxisInfo`, an `Axis`, a v0_5 axis description, or an
                axis id / letter string (converted via `Axis.create`).
            maybe_singleton: explicit singleton flag. If `None`, it is
                inferred: `True` for plain `Axis`/`str` input, otherwise
                derived from the v0_5 axis' `size`
                (see `_infer_maybe_singleton`).

        Returns:
            `axis` itself if it already is an `AxisInfo` and no different
            `maybe_singleton` was requested; otherwise a new `AxisInfo`.
        """
        if isinstance(axis, AxisInfo):
            # Honour an explicit flag instead of silently ignoring it; keep
            # returning the very same object when nothing would change.
            if maybe_singleton is None or maybe_singleton == axis.maybe_singleton:
                return axis

            return AxisInfo(
                id=axis.id, type=axis.type, maybe_singleton=maybe_singleton
            )

        axis_base = super().create(axis)
        if maybe_singleton is None:
            maybe_singleton = cls._infer_maybe_singleton(axis)

        return AxisInfo(
            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
        )

    @staticmethod
    def _infer_maybe_singleton(axis: AxisLike) -> bool:
        """Guess whether `axis` may have size 1, based on its `size` (if any)."""
        if isinstance(axis, (Axis, str)):
            # no size information available -> allow a singleton
            return True

        size = axis.size
        if size is None:
            return True
        elif isinstance(size, int):
            return size == 1
        elif isinstance(size, v0_5.SizeReference):
            return True  # TODO: check if singleton is ok for a `SizeReference`
        elif isinstance(size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)):
            try:
                # TODO: refactor validate_size() to have boolean func here
                maybe_size_one = size.validate_size(1)
            except ValueError:
                return False

            return maybe_size_one == 1
        else:
            assert_never(size)