Coverage for src / bioimageio / spec / _internal / io_utils.py: 78%

175 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-23 10:51 +0000

1import collections.abc 

2import io 

3import shutil 

4import zipfile 

5from contextlib import nullcontext 

6from difflib import get_close_matches 

7from pathlib import Path 

8from types import MappingProxyType 

9from typing import IO, Any, Dict, Mapping, Union, cast 

10from zipfile import ZipFile 

11 

12import httpx 

13import numpy 

14from loguru import logger 

15from numpy.typing import NDArray 

16from pydantic import BaseModel, FilePath, NewPath, RootModel 

17from ruyaml import YAML 

18from typing_extensions import Unpack 

19 

20from ._settings import settings 

21from .io import ( 

22 BIOIMAGEIO_YAML, 

23 BioimageioYamlContent, 

24 BioimageioYamlContentView, 

25 BytesReader, 

26 FileDescr, 

27 HashKwargs, 

28 LightHttpFileDescr, 

29 OpenedBioimageioYaml, 

30 RelativeFilePath, 

31 YamlValue, 

32 extract_file_name, 

33 find_bioimageio_yaml_file_name, 

34 get_reader, 

35 identify_bioimageio_yaml_file_name, 

36 interprete_file_source, 

37) 

38from .io_basics import AbsoluteDirectory, FileName, ZipPath 

39from .types import FileSource, PermissiveFileSource 

40from .url import HttpUrl, RootHttpUrl 

41from .utils import cache 

42from .validation_context import ValidationContext, get_validation_context 

43 

def _create_yaml_dumper() -> YAML:
    """Build the YAML instance used by `write_yaml` to serialize content.

    Emits YAML 1.2 in block (non-flow) style, with mappings indented by 2,
    sequences by 4 (dash offset 2), and a preferred line width of 88.
    """
    dumper = YAML()
    dumper.indent(mapping=2, sequence=4, offset=2)
    dumper.default_flow_style = False
    dumper.width = 88  # pyright: ignore[reportAttributeAccessIssue]
    dumper.version = (1, 2)  # pyright: ignore[reportAttributeAccessIssue]
    return dumper


# `typ="safe"` restricts parsing to plain YAML data types.
_yaml_load = YAML(typ="safe")
_yaml_dump = _create_yaml_dumper()

51 

52 

def read_yaml(
    file: Union[FilePath, ZipPath, IO[str], IO[bytes], BytesReader, str],
) -> YamlValue:
    """Parse YAML with the safe loader.

    Args:
        file: A local or zip path (its text is read as UTF-8), or anything the
            loader consumes directly: a YAML string (treated as content, not
            as a path) or an open text/binary stream.

    Returns:
        The parsed YAML value.
    """
    source = (
        file.read_text(encoding="utf-8")
        if isinstance(file, (ZipPath, Path))
        else file
    )
    parsed: YamlValue = _yaml_load.load(source)
    return parsed

63 

64 

def write_yaml(
    content: Union[YamlValue, BioimageioYamlContentView, BaseModel],
    /,
    file: Union[NewPath, FilePath, IO[str], IO[bytes], ZipPath],
):
    """Serialize `content` as YAML and write it to `file`.

    Args:
        content: YAML-compatible data. A pydantic `BaseModel` is first
            converted with `model_dump(mode="json")`.
        file: A local path, opened for writing as UTF-8 (an existing file is
            truncated). Alternatively, an already open stream or a zip path,
            which is handed to the YAML dumper unchanged.
    """
    # Convert before touching the destination. A failing `model_dump` then
    # neither leaks an open file handle nor truncates an existing file.
    if isinstance(content, BaseModel):
        content = content.model_dump(mode="json")

    if isinstance(file, Path):
        with file.open("w", encoding="utf-8") as f:
            _yaml_dump.dump(content, f)
    else:
        _yaml_dump.dump(content, file)

80 

81 

def _sanitize_bioimageio_yaml(content: YamlValue) -> BioimageioYamlContent:
    """Check that parsed YAML has the top-level shape of a bioimageio.yaml.

    Raises:
        ValueError: If `content` is not a dict, or if any of its keys is not a
            string. The first offending key is reported.
    """
    if not isinstance(content, dict):
        raise ValueError(
            f"Expected {BIOIMAGEIO_YAML} content to be a mapping (got {type(content)})."
        )

    non_str_keys = [k for k in content if not isinstance(k, str)]
    if non_str_keys:
        offending = non_str_keys[0]
        raise ValueError(
            f"Expected all keys (field names) in a {BIOIMAGEIO_YAML} "
            f"to be strings (got '{offending}' of type {type(offending)})."
        )

    return cast(BioimageioYamlContent, content)

96 

97 

def _open_bioimageio_rdf_in_zip(
    path: ZipPath,
    *,
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile],
    original_source_name: str,
) -> OpenedBioimageioYaml:
    """Read and parse the bioimageio.yaml stored at `path` inside a zip archive.

    The raw UTF-8 text is kept next to the parsed content as `unparsed_content`.
    """
    with path.open("rb") as stream:
        # Type narrowing: binary mode must not yield a text wrapper.
        assert not isinstance(stream, io.TextIOWrapper)
        raw = stream.read()

    text = raw.decode(encoding="utf-8")
    parsed = _sanitize_bioimageio_yaml(read_yaml(io.StringIO(text)))
    return OpenedBioimageioYaml(
        parsed,
        original_root=original_root,
        original_file_name=extract_file_name(path),
        original_source_name=original_source_name,
        unparsed_content=text,
    )

117 

118 

def _open_bioimageio_zip(
    source: ZipFile,
    *,
    original_source_name: str,
) -> OpenedBioimageioYaml:
    """Locate the bioimageio.yaml among the members of `source` and open it."""
    member_names = [info.filename for info in source.filelist]
    yaml_path = ZipPath(source, identify_bioimageio_yaml_file_name(member_names))
    return _open_bioimageio_rdf_in_zip(
        yaml_path,
        original_root=source,
        original_source_name=original_source_name,
    )

132 

133 

def _huggingface_source_to_url(source: str) -> HttpUrl:
    """Translate a `huggingface/...` shorthand into a URL.

    Supported forms (the optional last segment selects the branch):
        huggingface/{user_or_org}/{repo_name}            -> branch "main"
        huggingface/{user_or_org}/{repo_name}/           -> branch "main"
        huggingface/{user_or_org}/{repo_name}/{version}  -> branch "v{version}"
            if `version` starts with a digit, otherwise `version` itself
    The URL is built from `settings.huggingface_http_pattern`.
    """
    repo_path = source[len("huggingface/") :]
    if repo_path.count("/") == 1:
        repo_id, branch = repo_path, "main"
    else:
        repo_id, version = repo_path.rsplit("/", 1)
        if not version:
            branch = "main"
        elif version[0].isdigit():
            branch = f"v{version}"
        else:
            branch = version

    return HttpUrl(
        settings.huggingface_http_pattern.format(repo_id=repo_id, branch=branch)
    )


def _open_from_collection_url(source: str) -> Union[OpenedBioimageioYaml, None]:
    """Try to download the bioimageio.yaml of resource id `source` directly.

    The URL is built from `settings.collection_http_pattern`. Returns None if
    no pattern is configured. Also returns None, after logging a warning, if
    the download or parsing fails.
    """
    if not settings.collection_http_pattern:
        return None

    with ValidationContext(perform_io_checks=False):
        url = HttpUrl(settings.collection_http_pattern.format(bioimageio_id=source))

    try:
        r = httpx.get(url, timeout=settings.http_timeout, follow_redirects=True)
        _ = r.raise_for_status()
        unparsed_content = r.content.decode(encoding="utf-8")
        content = _sanitize_bioimageio_yaml(read_yaml(unparsed_content))
    except Exception as e:
        logger.warning("Failed to get bioimageio.yaml from {}: {}", url, e)
        return None

    logger.info("loaded {} from {}", source, url)
    original_file_name = "rdf.yaml" if url.path is None else url.path.split("/")[-1]
    return OpenedBioimageioYaml(
        content=content,
        original_root=url.parent,
        original_file_name=original_file_name,
        original_source_name=str(url),  # str, like every other call site
        unparsed_content=unparsed_content,
    )


def open_bioimageio_yaml(
    source: Union[PermissiveFileSource, ZipFile, ZipPath],
    /,
    **kwargs: Unpack[HashKwargs],
) -> OpenedBioimageioYaml:
    """Open and parse a bioimageio.yaml from one of several kinds of source.

    `source` is resolved in this order:
      - a `huggingface/{user_or_org}/{repo}[/{version}]` shorthand is turned
        into a URL (see `_huggingface_source_to_url`);
      - a `ZipFile`, or a `ZipPath` inside one, is read from the archive;
      - a directory is searched for its bioimageio yaml file;
      - anything else goes through `interprete_file_source` / `get_reader`.
    If that fails and `source` is a string, it is tried as a resource id: first
    via `settings.collection_http_pattern`, then via the id map.

    Args:
        source: The file source to open.
        kwargs: Hash arguments forwarded to `get_reader`.

    Raises:
        FileNotFoundError: If `source` is not in the id map but close matches
            exist (they are suggested in the message).
        Exception: Otherwise, the original error raised while resolving or
            reading `source`.
    """
    if (
        isinstance(source, str)
        and source.startswith("huggingface/")
        and source.count("/") >= 2
    ):
        source = _huggingface_source_to_url(source)

    if isinstance(source, RelativeFilePath):
        source = source.absolute()

    if isinstance(source, ZipFile):
        return _open_bioimageio_zip(source, original_source_name=str(source))
    elif isinstance(source, ZipPath):
        return _open_bioimageio_rdf_in_zip(
            source, original_root=source.root, original_source_name=str(source)
        )

    try:
        if isinstance(source, (Path, str)) and (source_dir := Path(source)).is_dir():
            # open bioimageio yaml from a folder
            src = source_dir / find_bioimageio_yaml_file_name(source_dir)
        else:
            src = interprete_file_source(source)

        reader = get_reader(src, **kwargs)

    except Exception:
        # `source` might be a resource (collection) id rather than a path/URL.
        # Only try that if an id map URL is configured.
        if (
            not isinstance(source, str)
            or not isinstance(settings.id_map, str)
            or "/" not in settings.id_map
        ):
            raise

        opened = _open_from_collection_url(source)
        if opened is not None:
            return opened

        id_map = get_id_map()
        if source not in id_map:
            # An empty id map (e.g. it could not be fetched) has no close
            # matches. The original error is then re-raised instead of a KeyError.
            close_matches = get_close_matches(source, id_map)
            if not close_matches:
                raise  # re-raises the original error from the `try` above

            if len(close_matches) == 1:
                did_you_mean = f" Did you mean '{close_matches[0]}'?"
            else:
                did_you_mean = f" Did you mean any of {close_matches}?"

            raise FileNotFoundError(f"'{source}' not found.{did_you_mean}")

        entry = id_map[source]
        logger.info("loading {} from {}", source, entry.source)
        reader = entry.get_reader()
        with get_validation_context().replace(perform_io_checks=False):
            src = HttpUrl(entry.source)

    if reader.is_zipfile:
        return _open_bioimageio_zip(ZipFile(reader), original_source_name=str(src))

    unparsed_content = reader.read().decode(encoding="utf-8")
    content = _sanitize_bioimageio_yaml(read_yaml(unparsed_content))

    if isinstance(src, RelativeFilePath):
        src = src.absolute()

    root = src.root if isinstance(src, ZipPath) else src.parent

    return OpenedBioimageioYaml(
        content,
        original_root=root,
        original_source_name=str(src),
        original_file_name=extract_file_name(src),
        unparsed_content=unparsed_content,
    )

257 

258 

# Pydantic root model that validates a downloaded id map: a JSON object
# mapping resource ids to `LightHttpFileDescr` entries.
_IdMap = RootModel[Dict[str, LightHttpFileDescr]]


def _get_id_map_impl(url: str) -> Dict[str, LightHttpFileDescr]:
    """Download and validate the id map JSON at `url`.

    Returns an empty dict (after logging an error) if `url` is not a plausible
    URL, or if the request, its HTTP status or the JSON decoding fails.

    Raises:
        pydantic.ValidationError: If the downloaded JSON does not match `_IdMap`.
    """
    if not isinstance(url, str) or "/" not in url:
        logger.opt(depth=1).error("invalid id map url: {}", url)
        return {}  # do not attempt a request with an invalid URL

    try:
        response = httpx.get(url, timeout=settings.http_timeout, follow_redirects=True)
        # An error response (e.g. 404 with a JSON body) must not be treated
        # as an id map.
        _ = response.raise_for_status()
        id_map_raw: Any = response.json()
    except Exception as e:
        logger.opt(depth=1).error("failed to get {}: {}", url, e)
        return {}

    return _IdMap.model_validate(id_map_raw).root

275 

276 

@cache
def get_id_map() -> Mapping[str, LightHttpFileDescr]:
    """Return a read-only mapping from resource ids to file descriptions.

    Entries from `settings.id_map` are merged over those from
    `settings.id_map_draft`. The draft map is only fetched when
    `settings.resolve_draft` is set. Any error is logged and yields an empty
    mapping. The result is memoized by the `cache` decorator; presumably it
    is computed once per process (confirm against `.utils.cache`).
    """
    try:
        combined: Dict[str, LightHttpFileDescr] = (
            _get_id_map_impl(settings.id_map_draft) if settings.resolve_draft else {}
        )
        combined.update(_get_id_map_impl(settings.id_map))
    except Exception as e:
        logger.error("failed to get resource id mapping: {}", e)
        combined = {}

    return MappingProxyType(combined)

292 

293 

def write_content_to_zip(
    content: Mapping[
        FileName,
        Union[
            str, FilePath, ZipPath, BioimageioYamlContentView, FileDescr, BytesReader
        ],
    ],
    zip: zipfile.ZipFile,
):
    """Write strings as text, mappings as YAML and files into an open ZipFile.

    Args:
        content: Maps archive member names to their content:
            - `str`: written as UTF-8 text (the string is content, not a path);
            - a mapping (e.g. bioimageio.yaml content): serialized as YAML;
            - anything else (local path, `ZipPath`, `FileDescr`, `BytesReader`):
              its bytes are copied. A `BytesReader` is used as is, everything
              else is opened with `get_reader`.
        zip: The archive to write to, opened in a writable mode.

    A member whose source is the very same member of the target archive is
    skipped rather than copied onto itself.
    """
    for arc_name, file in content.items():
        if isinstance(file, collections.abc.Mapping):
            buf = io.StringIO()
            write_yaml(file, buf)
            file = buf.getvalue()

        if isinstance(file, str):
            zip.writestr(arc_name, file.encode("utf-8"))
            continue

        reader = file if isinstance(file, BytesReader) else get_reader(file)

        root = reader.original_root
        # In-memory archives all have `filename is None`. Only treat them as
        # the same archive when they are the same object; otherwise copying
        # between two different in-memory zips would be skipped silently.
        same_archive = isinstance(root, ZipFile) and (
            root is zip or (root.filename is not None and root.filename == zip.filename)
        )
        if same_archive and reader.original_file_name == arc_name:
            logger.debug(
                "Not copying {}/{} to itself.", root, reader.original_file_name
            )
            continue

        with zip.open(arc_name, "w") as dest:
            shutil.copyfileobj(reader, dest, 1024 * 8)

335 

336 

def write_zip(
    path: Union[FilePath, IO[bytes]],
    content: Mapping[
        FileName, Union[str, FilePath, ZipPath, BioimageioYamlContentView, BytesReader]
    ],
    *,
    compression: int,
    compression_level: int,
) -> None:
    """Create a zip archive at `path` holding `content`.

    Args:
        path: Destination file path (missing parent directories are created)
            or a writable binary stream.
        content: Maps archive member names to local file paths, strings
            (written as text files) or mappings (written as YAML files).
        compression: Numeric constant of the compression method.
        compression_level: Compression level used when writing files to the
            archive.
            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
    """
    if isinstance(path, Path):
        path.parent.mkdir(parents=True, exist_ok=True)

    archive = ZipFile(
        path, "w", compression=compression, compresslevel=compression_level
    )
    with archive:
        write_content_to_zip(content, archive)

363 

364 

def load_array(source: Union[FileSource, FileDescr, ZipPath]) -> NDArray[Any]:
    """Load a numpy ndarray from a `.npy` file.

    Pickled (object) arrays are only accepted when `settings.allow_pickle` is
    set. In that case a warning is logged, because unpickling untrusted data
    can execute arbitrary code.
    """
    reader = get_reader(source)
    pickle_allowed = settings.allow_pickle
    if pickle_allowed:
        logger.warning("Loading numpy array with `allow_pickle=True`.")

    return numpy.load(reader, allow_pickle=pickle_allowed)

372 

373 

def save_array(path: Union[Path, ZipPath], array: NDArray[Any]) -> None:
    """Save a numpy ndarray to a `.npy` file at `path`, with pickling disabled."""
    with path.open(mode="wb") as stream:
        # Type narrowing: binary mode must not yield a text wrapper.
        assert not isinstance(stream, io.TextIOWrapper)
        numpy.save(stream, array, allow_pickle=False)