Coverage for bioimageio/spec/_internal/utils.py: 35%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-02-05 13:53 +0000

1from __future__ import annotations 

2 

3import re 

4import sys 

5from functools import wraps 

6from inspect import signature 

7from pathlib import Path 

8from typing import ( 

9 Any, 

10 Callable, 

11 Dict, 

12 Set, 

13 Tuple, 

14 Type, 

15 TypeVar, 

16 Union, 

17) 

18 

19from ruyaml import Optional 

20from typing_extensions import ParamSpec 

21 

K = TypeVar("K")  # key type used at every level of a NestedDict
V = TypeVar("V")  # leaf value type of a NestedDict
# A dict whose values are either leaf values of type V or further NestedDicts
# with the same key/value types (recursive alias, hence the string forward
# reference).
NestedDict = Dict[K, "NestedDict[K, V] | V"]

25 

26 

if sys.version_info >= (3, 9):
    from functools import cache as cache
    from importlib.resources import files as files
else:
    from functools import lru_cache as cache

    def files(package_name: str):
        """Minimal stand-in for `importlib.resources.files` on Python < 3.9.

        Only this package is supported; the returned path is the parent of
        the directory that contains this module.
        """
        assert package_name == "bioimageio.spec", package_name
        return Path(__file__).parent.parent

37 

38 

def get_format_version_tuple(format_version: Any) -> Optional[Tuple[int, int, int]]:
    """Parse a ``MAJOR.MINOR.PATCH`` format version string into a tuple of ints.

    Args:
        format_version: value to parse; anything that is not a string made of
            exactly three dot-separated decimal numbers yields None.

    Returns:
        ``(major, minor, patch)``, or None if `format_version` is not a valid
        three-part version string.
    """
    if not isinstance(format_version, str):
        return None

    parts = format_version.split(".")
    # `str.isdigit` also accepts characters such as superscripts ("³") that
    # `int()` cannot parse; `str.isdecimal` accepts exactly the characters
    # `int()` understands, so invalid input returns None instead of raising.
    if len(parts) != 3 or not all(part.isdecimal() for part in parts):
        return None

    major, minor, patch = (int(part) for part in parts)
    return major, minor, patch

50 

51 

def nest_dict(flat_dict: Dict[Tuple[K, ...], V]) -> NestedDict[K, V]:
    """Turn a dict keyed by key tuples into a nested dict.

    ``{("a", "b"): 1, ("a", "c"): 2, ("d",): 3}`` becomes
    ``{"a": {"b": 1, "c": 2}, "d": 3}``.

    Args:
        flat_dict: mapping from non-empty key paths to leaf values. Leaf
            values are inserted as-is (not copied) and are never modified.

    Returns:
        The nested dict.

    Raises:
        ValueError: if a key path is empty, or if one key path is a proper
            prefix of another (one slot would have to hold both a leaf value
            and a sub-dict). Detection does not depend on item order.
    """
    res: NestedDict[K, V] = {}
    # ids of the intermediate dicts created below; used to tell them apart
    # from leaf values that happen to be dicts themselves (those must neither
    # be descended into nor mutated).
    branch_ids: Set[int] = {id(res)}
    for k, v in flat_dict.items():
        if not k:
            raise ValueError(f"empty flat key {k!r} cannot be nested")

        node: Dict[Any, Any] = res
        for kk in k[:-1]:
            if kk not in node:
                new_branch: Dict[Any, Any] = {}
                branch_ids.add(id(new_branch))
                node[kk] = new_branch

            child = node[kk]
            if id(child) not in branch_ids:
                # `kk` already holds a leaf value from a shorter key path
                raise ValueError(f"nesting level collision for flat key {k} at {kk}")

            node = child

        if k[-1] in node:
            # flat keys are unique, so only a sub-dict created for a longer
            # key path can already occupy this slot
            raise ValueError(f"nesting level collision for flat key {k}")

        node[k[-1]] = v

    return res

68 

69 

FirstK = TypeVar("FirstK")


def nest_dict_with_narrow_first_key(
    flat_dict: Dict[Tuple[K, ...], V], first_k: Type[FirstK]
) -> Dict[FirstK, "NestedDict[K, V] | V"]:
    """Nest `flat_dict` with `nest_dict` and check the root-level key type.

    Convenience function to annotate a special kind of NestedDict whose
    root-level keys are of the narrower type `first_k`, while deeper keys
    may be of the wider type K.

    Raises:
        ValueError: if any root-level key is not an instance of `first_k`
            (any ValueError raised by `nest_dict` also propagates).
    """
    result = nest_dict(flat_dict)
    bad_keys = list(filter(lambda key: not isinstance(key, first_k), result))
    if not bad_keys:
        return result  # type: ignore

    raise ValueError(f"Invalid root level keys: {bad_keys}")

85 

86 

def unindent(text: str, ignore_first_line: bool = False):
    """Remove the common count of leading spaces from each line of `text`.

    Args:
        text: indented text
        ignore_first_line: leave the first line out of the indentation
            computation and keep it as-is; allows to correctly unindent doc
            strings, whose first line directly follows the opening quotes.

    Returns:
        The unindented text. If fewer than two of the considered lines
        contain non-whitespace characters, every line (including an ignored
        first line) is instead stripped of surrounding whitespace.
    """
    first = int(ignore_first_line)
    lines = text.split("\n")
    # Whitespace-only lines must not take part in the minimum-indent
    # computation; otherwise e.g. a blank line with a few trailing spaces
    # would cap how far every other line is unindented.
    filled_lines = [line for line in lines[first:] if line.strip()]
    if len(filled_lines) < 2:
        return "\n".join(line.strip() for line in lines)

    indent = min(len(line) - len(line.lstrip(" ")) for line in filled_lines)
    return "\n".join(lines[:first] + [line[indent:] for line in lines[first:]])

102 

103 

# return type of the function decorated by `assert_all_params_set_explicitly`
T = TypeVar("T")
# parameter specification of that decorated function, so the wrapper keeps
# its signature for type checkers
P = ParamSpec("P")

106 

107 

def assert_all_params_set_explicitly(fn: Callable[P, T]) -> Callable[P, T]:
    """Decorator asserting that every named parameter of `fn` is passed explicitly.

    Parameters that have default values must still be given at the call
    site, either positionally or by keyword; `*args`/`**kwargs` catch-alls
    are exempt. A violating call fails with an AssertionError before `fn`
    runs (being a plain `assert`, the check is skipped under ``python -O``).

    Args:
        fn: function to wrap.

    Returns:
        A wrapper with the same signature that performs the check and then
        calls `fn`.
    """
    # Inspect the signature once at decoration time instead of on every call.
    params = tuple(signature(fn).parameters.values())

    @wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        n_args = len(args)
        missing: Set[str] = set()

        for p in params:
            if p.kind == p.POSITIONAL_ONLY:
                if n_args == 0:
                    missing.add(p.name)
                else:
                    n_args -= 1  # 'use' positional arg
            elif p.kind == p.POSITIONAL_OR_KEYWORD:
                if n_args == 0:
                    if p.name not in kwargs:
                        missing.add(p.name)
                else:
                    n_args -= 1  # 'use' positional arg
            elif p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
                pass  # catch-alls have nothing to set
            elif p.kind == p.KEYWORD_ONLY:
                if p.name not in kwargs:
                    missing.add(p.name)

        assert not missing, f"parameters {missing} of {fn} are not set explicitly"

        return fn(*args, **kwargs)

    return wrapper

137 

138 

def get_os_friendly_file_name(name: str) -> str:
    """Turn `name` into a string usable as a file name.

    Each run of non-word characters (the Unicode-aware ``\\W`` of `re`) is
    collapsed into a single "_", and a "_" is prepended when `name` starts
    with a digit.
    """
    unsafe_or_leading_digit = re.compile(r"\W+|^(?=\d)")
    return unsafe_or_leading_digit.sub("_", name)