Coverage for bioimageio/spec/_internal/io_utils.py: 84%

146 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-02 14:21 +0000

1import io 

2import zipfile 

3from contextlib import nullcontext 

4from difflib import get_close_matches 

5from pathlib import Path 

6from types import MappingProxyType 

7from typing import ( 

8 IO, 

9 Any, 

10 Dict, 

11 Mapping, 

12 Union, 

13 cast, 

14) 

15from zipfile import ZipFile, is_zipfile 

16 

17import numpy 

18import requests 

19from loguru import logger 

20from numpy.typing import NDArray 

21from pydantic import FilePath, NewPath, RootModel 

22from ruyaml import YAML 

23from typing_extensions import Unpack 

24 

25from ._settings import settings 

26from .io import ( 

27 BIOIMAGEIO_YAML, 

28 BioimageioYamlContent, 

29 FileDescr, 

30 FileInZip, 

31 HashKwargs, 

32 LightHttpFileDescr, 

33 OpenedBioimageioYaml, 

34 YamlValue, 

35 find_bioimageio_yaml_file_name, 

36 identify_bioimageio_yaml_file_name, 

37 resolve, 

38) 

39from .io_basics import FileName, ZipPath 

40from .types import FileSource, PermissiveFileSource 

41from .url import HttpUrl 

42from .utils import cache 

43from .validation_context import ValidationContext 

44 

# YAML instance used for reading; created with the "safe" type.
_yaml_load = YAML(typ="safe")

# Separate YAML instance used for writing, configured below.
_yaml_dump = YAML()
# layout: block style, 2-space mappings, sequences indented by 4 with the
# dash offset by 2, lines wrapped at 88 characters
_yaml_dump.default_flow_style = False
_yaml_dump.indent(mapping=2, sequence=4, offset=2)
_yaml_dump.width = 88  # pyright: ignore[reportAttributeAccessIssue]
# YAML version set on the dumper
_yaml_dump.version = (1, 2)  # pyright: ignore[reportAttributeAccessIssue]

52 

53 

def read_yaml(file: Union[FilePath, ZipPath, IO[str], IO[bytes]]) -> YamlValue:
    """Parse YAML from a path or from an already opened stream.

    Args:
        file: A file system path or a path inside a zip archive (its content
            is read as UTF-8 text), or an open text/binary stream, which is
            handed to the YAML loader unchanged.

    Returns:
        The parsed YAML content.
    """
    yaml_input = (
        file.read_text(encoding="utf-8")
        if isinstance(file, (ZipPath, Path))
        else file
    )
    parsed: YamlValue = _yaml_load.load(yaml_input)
    return parsed

62 

63 

def write_yaml(
    content: Union[YamlValue, BioimageioYamlContent],
    /,
    file: Union[NewPath, FilePath, IO[str], IO[bytes], ZipPath],
):
    """Dump `content` as YAML to `file`.

    Args:
        content: The content to serialize (positional only).
        file: Destination. A `pathlib.Path` is opened for writing as UTF-8
            text and closed again afterwards; any other object is passed to
            the YAML dumper as is and is not closed by this function.
    """
    target = (
        file.open("w", encoding="utf-8")
        if isinstance(file, Path)
        else nullcontext(file)  # caller keeps ownership of the object
    )
    with target as stream:
        _yaml_dump.dump(content, stream)

76 

77 

def _sanitize_bioimageio_yaml(content: YamlValue) -> BioimageioYamlContent:
    """Check that parsed YAML content is a mapping with string keys only.

    Args:
        content: Parsed YAML content.

    Returns:
        The very same `content` object, merely cast to `BioimageioYamlContent`.

    Raises:
        ValueError: If `content` is not a `dict` or if any of its top-level
            keys is not a `str` (the first such key is reported).
    """
    if not isinstance(content, dict):
        raise ValueError(
            f"Expected {BIOIMAGEIO_YAML} content to be a mapping (got {type(content)})."
        )

    non_str_keys = [k for k in content if not isinstance(k, str)]
    if non_str_keys:
        key = non_str_keys[0]
        raise ValueError(
            f"Expected all keys (field names) in a {BIOIMAGEIO_YAML} "
            f"need to be strings (got '{key}' of type {type(key)})."
        )

    return cast(BioimageioYamlContent, content)

92 

93 

def _open_bioimageio_rdf_in_zip(source: ZipFile, rdf_name: str) -> OpenedBioimageioYaml:
    """Read and parse the archive member `rdf_name` of `source`.

    Args:
        source: An open zip archive.
        rdf_name: Name of the archive member holding the YAML document.

    Returns:
        The opened bioimageio YAML with `source` as its root and the archive's
        file name (or "bioimageio.zip" if the archive has none) as its
        original file name.
    """
    raw_text = source.read(rdf_name).decode(encoding="utf-8")
    parsed = _sanitize_bioimageio_yaml(read_yaml(io.StringIO(raw_text)))
    archive_name = source.filename or "bioimageio.zip"
    return OpenedBioimageioYaml(
        parsed,
        source,
        archive_name,
        unparsed_content=raw_text,
    )

106 

107 

def _open_bioimageio_zip(source: ZipFile) -> OpenedBioimageioYaml:
    """Identify the bioimageio YAML member of `source` by name and open it."""
    member_names = source.namelist()
    rdf_member = identify_bioimageio_yaml_file_name(member_names)
    return _open_bioimageio_rdf_in_zip(source, rdf_member)

113 

114 

def open_bioimageio_yaml(
    source: Union[PermissiveFileSource, ZipFile], /, **kwargs: Unpack[HashKwargs]
) -> OpenedBioimageioYaml:
    """Open and parse a bioimageio YAML file.

    Args:
        source: One of
            - an open `ZipFile` containing a bioimageio YAML file,
            - a folder containing a bioimageio YAML file,
            - any file source accepted by `resolve` (a YAML file or a zip
              archive),
            - a string that is looked up as a collection id if it cannot be
              resolved as a file source.
        **kwargs: Passed on to `resolve`.

    Returns:
        The parsed content together with its root, original file name and the
        unparsed text.

    Raises:
        FileNotFoundError: If `source` is an unknown collection id for which
            similar ids exist (they are suggested in the message).
        Exception: The original error raised while resolving `source` is
            re-raised if `source` cannot be interpreted as a collection id.
    """
    if isinstance(source, ZipFile):
        return _open_bioimageio_zip(source)

    try:
        if isinstance(source, (Path, str)) and (source_dir := Path(source)).is_dir():
            # open bioimageio yaml from a folder
            src = source_dir / find_bioimageio_yaml_file_name(source_dir)
        else:
            src = source

        downloaded = resolve(src, **kwargs)

    except Exception:
        # check if `source` is a collection id
        if (
            not isinstance(source, str)
            or not isinstance(settings.id_map, str)
            or "/" not in settings.id_map
        ):
            raise

        if settings.collection_http_pattern:
            # first attempt: build the URL directly from the configured pattern
            with ValidationContext(perform_io_checks=False):
                url = HttpUrl(
                    settings.collection_http_pattern.format(bioimageio_id=source)
                )

            try:
                # bounded timeout so that an unresponsive server cannot block
                # the call forever
                r = requests.get(url, timeout=10)
                r.raise_for_status()
                unparsed_content = r.content.decode(encoding="utf-8")
                content = _sanitize_bioimageio_yaml(
                    read_yaml(io.StringIO(unparsed_content))
                )
            except Exception as e:
                # best effort only; fall through to the id map lookup below
                logger.warning("Failed to get bioimageio.yaml from {}: {}", url, e)
            else:
                original_file_name = (
                    "rdf.yaml" if url.path is None else url.path.split("/")[-1]
                )
                return OpenedBioimageioYaml(
                    content=content,
                    original_root=url.parent,
                    original_file_name=original_file_name,
                    unparsed_content=unparsed_content,
                )

        # second attempt: look `source` up in the id map
        id_map = get_id_map()
        if source not in id_map:
            # This also covers an empty id map: there are no close matches
            # then, so the original error is re-raised instead of failing
            # with an unrelated KeyError on the lookup below.
            close_matches = get_close_matches(source, id_map)
            if not close_matches:
                raise

            if len(close_matches) == 1:
                did_you_mean = f" Did you mean '{close_matches[0]}'?"
            else:
                did_you_mean = f" Did you mean any of {close_matches}?"

            raise FileNotFoundError(f"'{source}' not found.{did_you_mean}")

        entry = id_map[source]
        logger.info("loading {} from {}", source, entry.source)
        downloaded = entry.download()

    local_source = downloaded.path
    if isinstance(local_source, ZipPath):
        return _open_bioimageio_rdf_in_zip(local_source.root, local_source.name)
    elif is_zipfile(local_source):
        # the ZipFile is deliberately left open: it is handed on as the root
        # of the returned object
        return _open_bioimageio_zip(ZipFile(local_source))

    if local_source.is_dir():
        root = local_source
        local_source = local_source / find_bioimageio_yaml_file_name(local_source)
    else:
        root = downloaded.original_root

    # read the file once and parse that text, so that the parsed content and
    # `unparsed_content` are guaranteed to stem from the same bytes
    unparsed_content = local_source.read_text(encoding="utf-8")
    content = _sanitize_bioimageio_yaml(read_yaml(io.StringIO(unparsed_content)))
    return OpenedBioimageioYaml(
        content,
        root.original_root if isinstance(root, FileInZip) else root,
        downloaded.original_file_name,
        unparsed_content=unparsed_content,
    )

201 

202 

# pydantic model used to validate a downloaded id map
_IdMap = RootModel[Dict[str, LightHttpFileDescr]]


def _get_id_map_impl(url: str) -> Dict[str, LightHttpFileDescr]:
    """Download the JSON id map from `url` and validate it.

    Args:
        url: Location of the id map; has to be a string containing a "/".

    Returns:
        The validated mapping, or an empty dict if `url` is invalid, the
        request fails, the server answers with an HTTP error status, or the
        response is not JSON (each case is logged as an error).

    Note:
        Validation errors of a successfully downloaded document are not
        caught here and propagate to the caller.
    """
    if not isinstance(url, str) or "/" not in url:
        logger.opt(depth=1).error("invalid id map url: {}", url)
        # no point in requesting a URL already known to be invalid
        return {}

    try:
        response = requests.get(url, timeout=10)
        # treat HTTP error responses as failures instead of trying to
        # validate an error document as id map
        response.raise_for_status()
        id_map_raw: Any = response.json()
    except Exception as e:
        logger.opt(depth=1).error("failed to get {}: {}", url, e)
        return {}

    id_map = _IdMap.model_validate(id_map_raw)
    return id_map.root

217 

218 

@cache
def get_id_map() -> Mapping[str, LightHttpFileDescr]:
    """Return a read-only view of the resource id mapping.

    If `settings.resolve_draft` is set, the draft id map is loaded first and
    then updated with the entries of the regular id map. Any exception raised
    on the way is logged and results in an empty mapping.
    """
    merged: Dict[str, LightHttpFileDescr] = {}
    try:
        if settings.resolve_draft:
            merged = _get_id_map_impl(settings.id_map_draft)

        # entries of the regular id map take precedence over draft entries
        merged.update(_get_id_map_impl(settings.id_map))
    except Exception as e:
        logger.error("failed to get resource id mapping: {}", e)
        merged = {}

    return MappingProxyType(merged)

234 

235 

def write_content_to_zip(
    content: Mapping[FileName, Union[str, FilePath, ZipPath, Dict[Any, Any]]],
    zip: zipfile.ZipFile,
):
    """Write strings as text, dictionaries as YAML and files to a ZipFile.

    Args:
        content: Maps archive names to the data to store under that name:
            a dict (serialized as YAML), a string (stored as UTF-8 text),
            a path inside another zip archive (its bytes are copied), or a
            local file path (the file is added to the archive).
        zip: The archive, opened for writing, to add the entries to.
    """
    for arc_name, entry in content.items():
        if isinstance(entry, dict):
            yaml_buffer = io.StringIO()
            write_yaml(entry, yaml_buffer)
            payload = yaml_buffer.getvalue().encode("utf-8")
        elif isinstance(entry, str):
            payload = entry.encode("utf-8")
        elif isinstance(entry, ZipPath):
            payload = entry.read_bytes()
        else:
            # a regular file on disk: let the archive read it itself
            zip.write(entry, arcname=arc_name)
            continue

        zip.writestr(arc_name, payload)

258 

259 

def write_zip(
    path: Union[FilePath, IO[bytes]],
    content: Mapping[FileName, Union[str, FilePath, ZipPath, Dict[Any, Any]]],
    *,
    compression: int,
    compression_level: int,
) -> None:
    """Create a zip archive at `path` and fill it with `content`.

    Args:
        path: Output path (or writable binary stream) of the archive.
        content: Maps archive names to local file paths, strings (for text
            files), or dicts (for YAML files).
        compression: The numeric constant of the compression method.
        compression_level: Compression level to use when writing files to the
            archive.
            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
    """
    archive = ZipFile(
        path, mode="w", compression=compression, compresslevel=compression_level
    )
    with archive:
        write_content_to_zip(content, archive)

281 

282 

def load_array(source: Union[FileSource, FileDescr, ZipPath]) -> NDArray[Any]:
    """Resolve `source` and load it with `numpy.load`.

    Whether pickled objects may be loaded is taken from
    `settings.allow_pickle`.
    """
    local_path = resolve(source).path
    with local_path.open(mode="rb") as stream:
        # a binary stream is expected here, never a text wrapper
        assert not isinstance(stream, io.TextIOWrapper)
        array = numpy.load(stream, allow_pickle=settings.allow_pickle)
        return array

288 

289 

290def save_array(path: Union[Path, ZipPath], array: NDArray[Any]) -> None: 

291 with path.open(mode="wb") as f: 

292 assert not isinstance(f, io.TextIOWrapper) 

293 return numpy.save(f, array, allow_pickle=settings.allow_pickle)