Coverage for src / bioimageio / spec / _internal / utils.py: 65%
113 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 15:21 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 15:21 +0000
1from __future__ import annotations
3import dataclasses
4import re
5import sys
6from dataclasses import dataclass
7from functools import wraps
8from inspect import isfunction, signature
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Set,
16 Tuple,
17 Type,
18 TypeVar,
19 Union,
20)
22import pydantic
23from exceptiongroup import ExceptionGroup
24from ruyaml import Optional
25from typing_extensions import ParamSpec
if sys.version_info >= (3, 10):
    # `dataclasses.dataclass` accepts `slots=` only from Python 3.10 on;
    # presumably this dict is splatted into such decorators elsewhere
    # (`**SLOTS`) — verify against callers
    SLOTS: Dict[str, bool] = {"slots": True}
else:  # pragma: no cover
    SLOTS = {}

K = TypeVar("K")  # key type of a NestedDict
V = TypeVar("V")  # leaf value type of a NestedDict
# a dict whose values are either leaves of type V or, recursively, NestedDicts
NestedDict = Dict[K, "NestedDict[K, V] | V"]
if sys.version_info >= (3, 9):
    # the `name as name` form marks these as intentional re-exports
    from functools import cache as cache
    from importlib.resources import files as files
else:  # pragma: no cover
    from functools import lru_cache as cache
    from pathlib import Path

    def files(package_name: str):
        """Minimal stand-in for `importlib.resources.files` on Python < 3.9.

        Only supports this package: returns the directory two levels above
        this module file.
        """
        assert package_name == "bioimageio.spec", package_name
        return Path(__file__).parent.parent
def get_format_version_tuple(
    format_version: Any,
) -> Union[Tuple[int, int, int], None]:
    """Parse a "MAJOR.MINOR.PATCH" version string into a tuple of three ints.

    Args:
        format_version: value to parse; anything that is not a `str` is rejected.

    Returns:
        The parsed `(major, minor, patch)` tuple, or None if `format_version`
        is not a string made of exactly three dot-separated, non-empty runs of
        decimal digits.
    """
    if not isinstance(format_version, str):
        return None

    parts = format_version.split(".")
    # `isdecimal` (unlike `isdigit`) accepts exactly the characters `int()`
    # can parse: e.g. "²".isdigit() is True, but int("²") raises ValueError.
    if len(parts) != 3 or not all(part.isdecimal() for part in parts):
        return None

    major, minor, patch = map(int, parts)
    return major, minor, patch
def nest_dict(flat_dict: Dict[Tuple[K, ...], V]) -> NestedDict[K, V]:
    """Turn a dict keyed by key paths (tuples) into a nested dict.

    Example: {("a", "b"): 1, ("a", "c"): 2, ("d",): 3}
    becomes {"a": {"b": 1, "c": 2}, "d": 3}.

    Args:
        flat_dict: mapping from key tuples to leaf values. An empty key tuple
            raises IndexError.

    Raises:
        ValueError: if one flat key is a proper prefix of another, i.e. the
            same position would have to hold both a leaf value and a nested
            dict. This is detected regardless of the order of the keys in
            `flat_dict`, and regardless of whether the leaf value is itself a
            dict (which is never mutated).
    """
    res: NestedDict[K, V] = {}
    # ids of the dicts created here; any other object found along a key path
    # is a user-provided leaf value (possibly itself a dict) and must not be
    # nested into
    branch_ids: Set[int] = {id(res)}
    for k, v in flat_dict.items():
        node: Any = res
        for kk in k[:-1]:
            if kk not in node:
                child: NestedDict[K, V] = {}
                branch_ids.add(id(child))
                node[kk] = child
            node = node[kk]
            if id(node) not in branch_ids:
                raise ValueError(f"nesting level collision for flat key {k} at {kk}")

        if k[-1] in node:
            # flat keys are unique, so only a longer flat key sharing this
            # prefix can have created an entry here
            raise ValueError(f"nesting level collision for flat key {k}")

        node[k[-1]] = v

    return res
FirstK = TypeVar("FirstK")  # narrower type of the root-level keys


def nest_dict_with_narrow_first_key(
    flat_dict: Dict[Tuple[K, ...], V], first_k: Type[FirstK]
) -> Dict[FirstK, "NestedDict[K, V] | V"]:
    """Nest `flat_dict` with `nest_dict` and check the type of the root-level keys.

    Convenience function to annotate a special version of a NestedDict whose
    root-level keys are of a narrower type (`first_k`) than the nested keys.

    Raises:
        ValueError: if any root-level key is not an instance of `first_k`.
            Errors raised by `nest_dict` propagate unchanged.
    """
    nested = nest_dict(flat_dict)
    bad_keys = list(filter(lambda key: not isinstance(key, first_k), nested))
    if not bad_keys:
        return nested  # type: ignore

    raise ValueError(f"Invalid root level keys: {bad_keys}")
def unindent(text: str, ignore_first_line: bool = False) -> str:
    """Remove the smallest common count of leading spaces from each line.

    Args:
        text: indented text
        ignore_first_line: leave the first line as-is and exclude it from the
            indentation computation; allows to correctly unindent doc strings,
            whose first line usually directly follows the opening quotes.

    Returns:
        The unindented text. If at most one line (not counting an ignored
        first line) contains non-whitespace characters, every line (including
        an ignored first line) is stripped of leading and trailing whitespace
        instead.
    """
    first = int(ignore_first_line)
    lines = text.split("\n")
    # whitespace-only lines carry no indentation information (as in
    # `textwrap.dedent`); counting them would shrink the common indent
    filled_lines = [line for line in lines[first:] if line.strip()]
    if len(filled_lines) < 2:
        return "\n".join(line.strip() for line in lines)

    indent = min(len(line) - len(line.lstrip(" ")) for line in filled_lines)
    return "\n".join(lines[:first] + [line[indent:] for line in lines[first:]])
115T = TypeVar("T")
116P = ParamSpec("P")
119def assert_all_params_set_explicitly(fn: Callable[P, T]) -> Callable[P, T]:
120 @wraps(fn)
121 def wrapper(*args: P.args, **kwargs: P.kwargs):
122 n_args = len(args)
123 missing: Set[str] = set()
125 for p in signature(fn).parameters.values():
126 if p.kind == p.POSITIONAL_ONLY:
127 if n_args == 0:
128 missing.add(p.name)
129 else:
130 n_args -= 1 # 'use' positional arg
131 elif p.kind == p.POSITIONAL_OR_KEYWORD:
132 if n_args == 0:
133 if p.name not in kwargs:
134 missing.add(p.name)
135 else:
136 n_args -= 1 # 'use' positional arg
137 elif p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
138 pass
139 elif p.kind == p.KEYWORD_ONLY:
140 if p.name not in kwargs:
141 missing.add(p.name)
143 assert not missing, f"parameters {missing} of {fn} are not set explicitly"
145 return fn(*args, **kwargs)
147 return wrapper
def get_os_friendly_file_name(name: str) -> str:
    """Return `name` with characters that are awkward in file names replaced.

    Each run of non-word characters becomes a single "_", and a leading
    digit gets an "_" prefix.
    """
    # `\W+` matches a run of non-word characters; `^(?=\d)` is a zero-width
    # match at the start of `name` when it begins with a digit
    unfriendly = re.compile(r"\W+|^(?=\d)")
    return unfriendly.sub("_", name)
class _PrettyDataClassReprMixin:
    """A mixin that provides a pretty __repr__ for dataclasses.

    - leaving out fields that are None
    - leaving out memory locations of functions

    It must come BEFORE the dataclass base in the list of bases: a dataclass
    gets a generated `__repr__`, which would otherwise precede this one in the
    MRO. It is deliberately not decorated with `@dataclass` itself. Doing so
    would generate its own (field-less) `__init__`, `__eq__` and `__hash__`,
    and these would shadow the dataclass base's versions once the mixin comes
    first.
    """

    def __repr__(self):
        field_values = {
            f.name: v
            for f in dataclasses.fields(self)  # `self` must be a dataclass instance
            if (v := getattr(self, f.name)) is not None
        }
        field_str = ", ".join(
            f"{k}=" + (f"<function {v.__name__}>" if isfunction(v) else repr(v))
            for k, v in field_values.items()
        )
        return f"{self.__class__.__name__}({field_str})"


class PrettyPlainSerializer(_PrettyDataClassReprMixin, pydantic.PlainSerializer):
    """`pydantic.PlainSerializer` with a more readable `__repr__`."""


class PrettyWrapSerializer(_PrettyDataClassReprMixin, pydantic.WrapSerializer):
    """`pydantic.WrapSerializer` with a more readable `__repr__`."""
def try_all(
    funcs: Iterable[Callable[P, T]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Return the result of the first of `funcs` that succeeds with the given arguments.

    Raises:
        ExceptionGroup: bundling the errors collected by `_try_all` when no
            call succeeded.
    """
    result, failures = _try_all(funcs, False, *args, **kwargs)
    if not failures:
        assert not isinstance(result, _AllFailedSentinel)
        return result

    raise ExceptionGroup("All functions raised", failures)
def try_all_raise_last(
    funcs: Iterable[Callable[P, T]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Return the result of the first of `funcs` that succeeds with the given arguments.

    Raises:
        Exception: the last error collected by `_try_all` when no call
            succeeded (the other errors are discarded).
    """
    result, failures = _try_all(funcs, True, *args, **kwargs)
    if not failures:
        assert not isinstance(result, _AllFailedSentinel)
        return result

    last_failure = failures[-1]
    raise last_failure
209class _AllFailedSentinel:
210 pass
213def _try_all(
214 funcs: Iterable[Callable[P, T]],
215 raise_last_only: bool,
216 *args: P.args,
217 **kwargs: P.kwargs,
218) -> Tuple[Union[_AllFailedSentinel, T], List[Exception]]:
219 """Try to call each of the functions `funcs` with the given arguments.
221 If all raise, raise an exception group (or the last).
223 Returns:
224 Result of the first successful call.
225 """
226 errors: List[Exception] = []
227 for c in funcs:
228 try:
229 return c(*args, **kwargs), []
230 except Exception as e:
231 errors.append(e)
233 if errors:
234 errors.append(RuntimeError("No functions provided to try."))
236 return _AllFailedSentinel(), errors