Coverage for src / bioimageio / spec / _internal / utils.py: 69%

89 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-08 13:04 +0000

1from __future__ import annotations 

2 

3import dataclasses 

4import re 

5import sys 

6from dataclasses import dataclass 

7from functools import wraps 

8from inspect import isfunction, signature 

9from pathlib import Path 

10from typing import ( 

11 Any, 

12 Callable, 

13 Dict, 

14 Set, 

15 Tuple, 

16 Type, 

17 TypeVar, 

18 Union, 

19) 

20 

21import pydantic 

22from ruyaml import Optional 

23from typing_extensions import ParamSpec 

24 

# Extra keyword arguments presumably meant to be unpacked into
# ``dataclasses.dataclass(...)``: its ``slots`` parameter only exists on
# Python >= 3.10, so older interpreters get an empty dict.
SLOTS: Dict[str, bool] = {"slots": True} if sys.version_info >= (3, 10) else {}

29 

30 

# Type variables for the keys (K) and leaf values (V) of nested dicts.
K = TypeVar("K")
V = TypeVar("V")
# Recursive alias: a dict whose values are either further NestedDicts or leaf
# values of type V; the recursion is expressed via a string forward reference.
# NOTE(review): at runtime only K is a type parameter of this alias (V appears
# only inside the string), so `NestedDict[K, V]` is presumably meant for
# annotations only, which this module keeps unevaluated via
# `from __future__ import annotations` -- confirm nothing subscripts it at runtime.
NestedDict = Dict[K, "NestedDict[K, V] | V"]

34 

if sys.version_info < (3, 9):  # pragma: no cover
    from functools import lru_cache as _lru_cache

    # `functools.cache` (Python >= 3.9) is unbounded, whereas a bare
    # `lru_cache` evicts entries beyond its default maxsize of 128. Use
    # maxsize=None so the fallback caches exactly like `functools.cache`.
    cache = _lru_cache(maxsize=None)

    def files(package_name: str):
        """Minimal stand-in for `importlib.resources.files` (Python < 3.9).

        Only this package is supported; returns the parent of the directory
        that contains this module file.
        """
        assert package_name == "bioimageio.spec", package_name
        return Path(__file__).parent.parent

else:
    from functools import cache as cache
    from importlib.resources import files as files

45 

46 

def get_format_version_tuple(format_version: Any) -> Optional[Tuple[int, int, int]]:
    """Parse a ``"MAJOR.MINOR.PATCH"`` version string into a tuple of ints.

    Args:
        format_version: candidate version; anything that is not a string of
            exactly three dot-separated groups of ASCII digits is rejected.

    Returns:
        ``(major, minor, patch)``, or ``None`` if ``format_version`` does not
        have that form.
    """
    if not isinstance(format_version, str):
        return None

    # `str.isdigit` also accepts characters such as "²" that `int()` cannot
    # parse (which used to raise ValueError), so match ASCII digits explicitly.
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", format_version, flags=re.ASCII)
    if match is None:
        return None

    major, minor, patch = (int(group) for group in match.groups())
    return major, minor, patch

58 

59 

def nest_dict(flat_dict: Dict[Tuple[K, ...], V]) -> NestedDict[K, V]:
    """Turn a dict keyed by key paths (tuples) into a nested dict.

    Example: ``{("a", "b"): 1, ("a", "c"): 2, ("d",): 3}`` becomes
    ``{"a": {"b": 1, "c": 2}, "d": 3}``.

    Args:
        flat_dict: mapping from key paths to leaf values.

    Returns:
        The nested dict.

    Raises:
        ValueError: on a nesting level collision, i.e. when a key path has to
            descend through an entry that is not a dict, or when a leaf would
            replace an entry that another flat key already put at that spot.
    """
    res: NestedDict[K, V] = {}
    for k, v in flat_dict.items():
        node: Union[Dict[K, Union[NestedDict[K, V], V]], NestedDict[K, V]] = res
        for kk in k[:-1]:
            if not isinstance(node, dict):
                raise ValueError(f"nesting level collision for flat key {k} at {kk}")
            d: NestedDict[K, V] = {}
            node = node.setdefault(kk, d)  # type: ignore

        if not isinstance(node, dict):
            raise ValueError(f"nesting level collision for flat key {k}")

        # Flat keys are unique, so an existing entry here was put there by a
        # different flat key (e.g. a longer key path sharing this prefix).
        # Overwriting it would silently drop that data, so reject it just like
        # the opposite key order is rejected above.
        if k[-1] in node:
            raise ValueError(f"nesting level collision for flat key {k}")

        node[k[-1]] = v

    return res

76 

77 

# Type of the root-level keys accepted by `nest_dict_with_narrow_first_key`.
FirstK = TypeVar("FirstK")


def nest_dict_with_narrow_first_key(
    flat_dict: Dict[Tuple[K, ...], V], first_k: Type[FirstK]
) -> Dict[FirstK, "NestedDict[K, V] | V"]:
    """Nest `flat_dict` via `nest_dict` and check the type of its root keys.

    A typing convenience for a NestedDict whose root-level keys are of a
    narrower type than the keys on deeper levels.

    Args:
        flat_dict: mapping from key paths (tuples) to leaf values.
        first_k: type that every root-level key must be an instance of.

    Returns:
        The nested dict produced by `nest_dict`.

    Raises:
        ValueError: if any root-level key is not an instance of `first_k`.
    """
    result = nest_dict(flat_dict)
    bad_root_keys = [root_key for root_key in result if not isinstance(root_key, first_k)]
    if not bad_root_keys:
        return result  # type: ignore

    raise ValueError(f"Invalid root level keys: {bad_root_keys}")

93 

94 

def unindent(text: str, ignore_first_line: bool = False):
    """Remove the common leading spaces from the lines of `text`.

    Lines that are empty or contain only whitespace do not take part in
    determining the common indentation (as with `textwrap.dedent`), so blank
    lines with leftover spaces do not reduce the amount removed. If fewer
    than two lines have content, every line (including the first) is simply
    stripped instead.

    Args:
        text: indented text
        ignore_first_line: exclude the first line from the indentation
            computation and do not unindent it; allows to correctly unindent
            doc strings, whose first line directly follows the opening quotes.

    Returns:
        The unindented lines, joined by newlines.
    """
    first = int(ignore_first_line)
    lines = text.split("\n")
    # Whitespace-only lines carry no content; counting them would let their
    # (possibly shorter) run of spaces dictate the minimum indentation.
    filled_lines = [line for line in lines[first:] if line.strip()]
    if len(filled_lines) < 2:
        return "\n".join(line.strip() for line in lines)

    indent = min(len(line) - len(line.lstrip(" ")) for line in filled_lines)
    return "\n".join(lines[:first] + [line[indent:] for line in lines[first:]])

110 

111 

112T = TypeVar("T") 

113P = ParamSpec("P") 

114 

115 

116def assert_all_params_set_explicitly(fn: Callable[P, T]) -> Callable[P, T]: 

117 @wraps(fn) 

118 def wrapper(*args: P.args, **kwargs: P.kwargs): 

119 n_args = len(args) 

120 missing: Set[str] = set() 

121 

122 for p in signature(fn).parameters.values(): 

123 if p.kind == p.POSITIONAL_ONLY: 

124 if n_args == 0: 

125 missing.add(p.name) 

126 else: 

127 n_args -= 1 # 'use' positional arg 

128 elif p.kind == p.POSITIONAL_OR_KEYWORD: 

129 if n_args == 0: 

130 if p.name not in kwargs: 

131 missing.add(p.name) 

132 else: 

133 n_args -= 1 # 'use' positional arg 

134 elif p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): 

135 pass 

136 elif p.kind == p.KEYWORD_ONLY: 

137 if p.name not in kwargs: 

138 missing.add(p.name) 

139 

140 assert not missing, f"parameters {missing} of {fn} are not set explicitly" 

141 

142 return fn(*args, **kwargs) 

143 

144 return wrapper 

145 

146 

def get_os_friendly_file_name(name: str) -> str:
    """Return a sanitized version of `name` for use as a file name.

    Every run of non-word characters becomes a single underscore, and an
    underscore is prepended when `name` starts with a digit. `\\w` is
    Unicode-aware for str patterns, so non-ASCII letters and digits are kept.
    """
    pattern = r"\W+|^(?=\d)"
    replacement = "_"
    return re.sub(pattern, replacement, name)

149 

150 

@dataclass
class _PrettyDataClassReprMixin:
    """A mixin that provides a pretty __repr__ for dataclasses

    - leaving out fields that are None
    - leaving out memory locations of functions

    The pretty ``__repr__`` is installed directly on every subclass that does
    not define its own ``__repr__``. Plain inheritance is not enough: if the
    mixin is listed after a dataclass base, as in
    ``class X(SomeDataclass, _PrettyDataClassReprMixin)``, the ``__repr__``
    that ``@dataclass`` generates for that base comes earlier in the MRO and
    would shadow the mixin's. Installing it at subclass creation also keeps
    ``@dataclass`` on a subclass from generating a replacement, since
    ``@dataclass`` leaves an already-defined ``__repr__`` alone.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if "__repr__" not in cls.__dict__:
            setattr(cls, "__repr__", _PrettyDataClassReprMixin.__repr__)

    def __repr__(self) -> str:
        field_values = {
            f.name: v
            for f in dataclasses.fields(self)
            if (v := getattr(self, f.name)) is not None
        }
        field_str = ", ".join(
            f"{k}=" + (f"<function {v.__name__}>" if isfunction(v) else repr(v))
            for k, v in field_values.items()
        )
        return f"{self.__class__.__name__}({field_str})"

170 

171 

class PrettyPlainSerializer(pydantic.PlainSerializer, _PrettyDataClassReprMixin):
    """`pydantic.PlainSerializer` combined with `_PrettyDataClassReprMixin`."""

174 

175 

class PrettyWrapSerializer(pydantic.WrapSerializer, _PrettyDataClassReprMixin):
    """`pydantic.WrapSerializer` combined with `_PrettyDataClassReprMixin`."""