Coverage for src / bioimageio / spec / _internal / utils.py: 65%

113 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 15:21 +0000

1from __future__ import annotations 

2 

3import dataclasses 

4import re 

5import sys 

6from dataclasses import dataclass 

7from functools import wraps 

8from inspect import isfunction, signature 

9from typing import ( 

10 Any, 

11 Callable, 

12 Dict, 

13 Iterable, 

14 List, 

15 Set, 

16 Tuple, 

17 Type, 

18 TypeVar, 

19 Union, 

20) 

21 

22import pydantic 

23from exceptiongroup import ExceptionGroup 

24from ruyaml import Optional 

25from typing_extensions import ParamSpec 

26 

# Keyword arguments that enable `slots=True` for dataclasses where supported:
# the `slots` parameter of `dataclasses.dataclass` only exists from Python 3.10
# on, so on older interpreters this is an empty dict.
# NOTE(review): presumably unpacked as `dataclass(**SLOTS)` elsewhere — confirm at call sites.
if sys.version_info < (3, 10):  # pragma: no cover
    SLOTS: Dict[str, bool] = {}
else:
    SLOTS = {"slots": True}

31 

32 

# Key and value type variables for the nested-dict helpers in this module.
K = TypeVar("K")
V = TypeVar("V")
# Recursive alias: a dict whose values are either leaf values `V` or further
# NestedDicts. The self-reference is a string (forward reference) because the
# name `NestedDict` is not bound yet while this expression is evaluated.
NestedDict = Dict[K, "NestedDict[K, V] | V"]

36 

# Compatibility shims: `functools.cache` and `importlib.resources.files` are not
# available before Python 3.9, so provide minimal stand-ins there.
if sys.version_info < (3, 9):  # pragma: no cover
    # NOTE: a bare `lru_cache` keeps at most 128 entries, unlike the unbounded
    # `functools.cache` used on newer Pythons.
    from functools import lru_cache as cache
    from pathlib import Path

    def files(package_name: str):
        """Minimal replacement for `importlib.resources.files`.

        Only supports the "bioimageio.spec" package: returns the directory two
        levels above this file, i.e. the `bioimageio/spec` package directory
        (this module lives in `bioimageio/spec/_internal/`).
        """
        assert package_name == "bioimageio.spec", package_name
        return Path(__file__).parent.parent

else:
    from functools import cache as cache
    from importlib.resources import files as files

48 

49 

def get_format_version_tuple(format_version: Any) -> Optional[Tuple[int, int, int]]:
    """Parse a "MAJOR.MINOR.PATCH" format version string into a tuple of ints.

    Args:
        format_version: value to parse; anything that is not a string yields None.

    Returns:
        `(major, minor, patch)`, or None if `format_version` is not a string made
        of exactly three dot-separated, non-empty runs of decimal digits.
    """
    if not isinstance(format_version, str):
        return None

    parts = format_version.split(".")
    # `str.isdecimal` (unlike `str.isdigit`) only accepts characters that `int()`
    # can parse: e.g. "²".isdigit() is True, but int("²") raises ValueError.
    # An empty part (as in "1..2") is not decimal either.
    if len(parts) != 3 or not all(part.isdecimal() for part in parts):
        return None

    major, minor, patch = map(int, parts)
    return major, minor, patch

61 

62 

def nest_dict(flat_dict: Dict[Tuple[K, ...], V]) -> NestedDict[K, V]:
    """Turn a dict keyed by tuples into a nested dict.

    Each key `(k0, k1, ..., kn)` is expanded into the path `res[k0][k1]...[kn]`,
    e.g. `{("a", "b"): 1, ("a", "c"): 2}` becomes `{"a": {"b": 1, "c": 2}}`.

    Note: leaf values that are themselves dicts are indistinguishable from
    nesting levels and may be descended into by longer keys.

    Raises:
        ValueError: if a key is the empty tuple, or if one key is a prefix of
            another (a leaf value and a nested level would share a node),
            regardless of the order in which the keys appear.
    """
    res: NestedDict[K, V] = {}
    for k, v in flat_dict.items():
        if not k:
            raise ValueError("empty flat key ()")

        node: Union[Dict[K, Union[NestedDict[K, V], V]], NestedDict[K, V]] = res
        for kk in k[:-1]:
            if not isinstance(node, dict):
                raise ValueError(f"nesting level collision for flat key {k} at {kk}")
            d: NestedDict[K, V] = {}
            node = node.setdefault(kk, d)  # type: ignore

        if not isinstance(node, dict):
            raise ValueError(f"nesting level collision for flat key {k}")

        # Flat keys are unique, so an existing entry here can only be a nested
        # level created by a longer key seen earlier; overwriting it would
        # silently drop that data.
        if k[-1] in node:
            raise ValueError(f"nesting level collision for flat key {k}")

        node[k[-1]] = v

    return res

79 

80 

FirstK = TypeVar("FirstK")


def nest_dict_with_narrow_first_key(
    flat_dict: Dict[Tuple[K, ...], V], first_k: Type[FirstK]
) -> Dict[FirstK, "NestedDict[K, V] | V"]:
    """Nest `flat_dict` like `nest_dict` and check the type of the root keys.

    Convenience function to annotate a special version of a NestedDict whose
    root level keys are of a narrower type (`first_k`) than the nested keys.

    Raises:
        ValueError: if any root level key is not an instance of `first_k`
            (or if `nest_dict` itself rejects `flat_dict`).
    """
    nested = nest_dict(flat_dict)
    offending = [key for key in nested.keys() if not isinstance(key, first_k)]
    if not offending:
        return nested  # type: ignore

    raise ValueError(f"Invalid root level keys: {offending}")

96 

97 

def unindent(text: str, ignore_first_line: bool = False):
    """Strip the common leading-space indentation from every line of `text`.

    Args:
        text: indented text
        ignore_first_line: leave the first line untouched and do not consider it
            when computing the common indentation (useful for doc strings)

    If fewer than two non-empty lines are considered, every line is stripped of
    surrounding whitespace instead.
    """
    all_lines = text.split("\n")
    skip = int(ignore_first_line)
    head, body = all_lines[:skip], all_lines[skip:]

    non_empty = [ln for ln in body if ln]
    if len(non_empty) <= 1:
        return "\n".join(map(str.strip, all_lines))

    # only spaces count as indentation
    margin = min(len(ln) - len(ln.lstrip(" ")) for ln in non_empty)
    return "\n".join([*head, *(ln[margin:] for ln in body)])

113 

114 

# Generic helpers below: `T` is a callable's return type and `P` (a ParamSpec)
# captures its full parameter list, so wrappers keep the wrapped signature.
T = TypeVar("T")
P = ParamSpec("P")

117 

118 

def assert_all_params_set_explicitly(fn: Callable[P, T]) -> Callable[P, T]:
    """Decorate `fn` so that every call must set each named parameter explicitly.

    Parameter defaults are thus never used silently. Positional-only parameters
    must be given positionally. Positional-or-keyword parameters may be given
    either way, and keyword-only parameters must be given by keyword.
    `*args` / `**kwargs` parameters are exempt.

    Raises:
        AssertionError: at call time, naming the parameters that were not set
            (like any `assert`, this check is skipped under `python -O`).
    """
    # `fn`'s signature does not change, so inspect it once at decoration time
    # instead of on every call.
    params = tuple(signature(fn).parameters.values())

    @wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        n_args = len(args)  # positional arguments not yet matched to a parameter
        missing: Set[str] = set()

        for p in params:
            if p.kind == p.POSITIONAL_ONLY:
                if n_args == 0:
                    missing.add(p.name)
                else:
                    n_args -= 1  # 'use' positional arg
            elif p.kind == p.POSITIONAL_OR_KEYWORD:
                if n_args == 0:
                    if p.name not in kwargs:
                        missing.add(p.name)
                else:
                    n_args -= 1  # 'use' positional arg
            elif p.kind == p.KEYWORD_ONLY:
                if p.name not in kwargs:
                    missing.add(p.name)
            # VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) need no value.

        assert not missing, f"parameters {missing} of {fn} are not set explicitly"

        return fn(*args, **kwargs)

    return wrapper

148 

149 

def get_os_friendly_file_name(name: str) -> str:
    """Turn `name` into a string usable as a file name.

    Each run of non-word characters is collapsed into a single underscore, and
    an underscore is prepended when `name` starts with a digit.
    """
    # `^(?=\d)` is an empty match at the start of a digit-led string, so the
    # substitution inserts "_" there instead of replacing anything.
    non_word_or_leading_digit = r"\W+|^(?=\d)"
    return re.sub(non_word_or_leading_digit, "_", name)

152 

153 

@dataclass
class _PrettyDataClassReprMixin:
    """Mixin giving dataclasses a compact `__repr__`.

    Compared to the default dataclass repr it
    - omits fields whose value is None
    - shows plain functions as `<function name>` instead of with their memory address
    """

    def __repr__(self):
        shown_fields = []
        for fld in dataclasses.fields(self):
            value = getattr(self, fld.name)
            if value is None:
                continue

            if isfunction(value):
                value_repr = f"<function {value.__name__}>"
            else:
                value_repr = repr(value)

            shown_fields.append(f"{fld.name}={value_repr}")

        return f"{self.__class__.__name__}({', '.join(shown_fields)})"

173 

174 

class PrettyPlainSerializer(pydantic.PlainSerializer, _PrettyDataClassReprMixin):
    """`pydantic.PlainSerializer` with the compact repr of `_PrettyDataClassReprMixin`."""

    # `pydantic.PlainSerializer` is itself a dataclass with a generated
    # `__repr__`, and it precedes the mixin in the MRO, so the mixin's
    # `__repr__` would never be found. Bind it explicitly. The base order is
    # kept so that `dataclasses.fields` still resolves to the serializer's
    # fields rather than to the mixin's empty field set.
    __repr__ = _PrettyDataClassReprMixin.__repr__

177 

178 

class PrettyWrapSerializer(pydantic.WrapSerializer, _PrettyDataClassReprMixin):
    """`pydantic.WrapSerializer` with the compact repr of `_PrettyDataClassReprMixin`."""

    # `pydantic.WrapSerializer` is itself a dataclass with a generated
    # `__repr__`, and it precedes the mixin in the MRO, so the mixin's
    # `__repr__` would never be found. Bind it explicitly. The base order is
    # kept so that `dataclasses.fields` still resolves to the serializer's
    # fields rather than to the mixin's empty field set.
    __repr__ = _PrettyDataClassReprMixin.__repr__

181 

182 

def try_all(
    funcs: Iterable[Callable[P, T]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Return the result of the first of `funcs` that does not raise.

    The functions are called in order, each with `*args` and `**kwargs`.

    Raises:
        ExceptionGroup: if every function raised (containing all their
            exceptions in call order), or if `funcs` is empty (containing a
            single RuntimeError).
    """
    ret, errors = _try_all(funcs, False, *args, **kwargs)
    if errors:
        raise ExceptionGroup("All functions raised", errors)

    assert not isinstance(ret, _AllFailedSentinel)
    return ret


def try_all_raise_last(
    funcs: Iterable[Callable[P, T]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Like `try_all`, but on failure re-raise only the last exception.

    Raises:
        Exception: the exception raised by the last function if all of them
            raised, or a RuntimeError if `funcs` is empty.
    """
    ret, errors = _try_all(funcs, True, *args, **kwargs)
    if errors:
        raise errors[-1]

    assert not isinstance(ret, _AllFailedSentinel)
    return ret


class _AllFailedSentinel:
    """Placeholder result returned by `_try_all` when no call succeeded."""


def _try_all(
    funcs: Iterable[Callable[P, T]],
    raise_last_only: bool,
    *args: P.args,
    **kwargs: P.kwargs,
) -> Tuple[Union[_AllFailedSentinel, T], List[Exception]]:
    """Call each of `funcs` with the given arguments until one succeeds.

    Args:
        funcs: functions to try, in order
        raise_last_only: not used here; the callers decide how to report errors
        args: positional arguments passed to every function
        kwargs: keyword arguments passed to every function

    Returns:
        `(result, [])` for the first successful call. Otherwise returns
        `(_AllFailedSentinel(), errors)`, where `errors` holds the raised
        exceptions in call order, or a single RuntimeError if `funcs` was
        empty. `errors` is never empty in the failure case.
    """
    errors: List[Exception] = []
    for c in funcs:
        try:
            return c(*args, **kwargs), []
        except Exception as e:
            errors.append(e)

    if not errors:
        # `funcs` was empty: report that explicitly instead of returning a
        # sentinel without errors (which the callers would treat as success).
        errors.append(RuntimeError("No functions provided to try."))

    return _AllFailedSentinel(), errors