Coverage for src / bioimageio / spec / _internal / io_utils.py: 77%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-31 13:09 +0000

1import collections.abc 

2import io 

3import shutil 

4import zipfile 

5from contextlib import nullcontext 

6from difflib import get_close_matches 

7from pathlib import Path 

8from types import MappingProxyType 

9from typing import IO, Any, Dict, Mapping, Union, cast 

10from zipfile import ZipFile 

11 

12import httpx 

13import numpy 

14from loguru import logger 

15from numpy.typing import NDArray 

16from pydantic import BaseModel, FilePath, NewPath, RootModel 

17from ruyaml import YAML 

18from typing_extensions import Unpack 

19 

20from ._settings import settings 

21from .io import ( 

22 BIOIMAGEIO_YAML, 

23 BioimageioYamlContent, 

24 BioimageioYamlContentView, 

25 BytesReader, 

26 FileDescr, 

27 HashKwargs, 

28 LightHttpFileDescr, 

29 OpenedBioimageioYaml, 

30 RelativeFilePath, 

31 YamlValue, 

32 extract_file_name, 

33 find_bioimageio_yaml_file_name, 

34 get_reader, 

35 identify_bioimageio_yaml_file_name, 

36 interprete_file_source, 

37) 

38from .io_basics import AbsoluteDirectory, FileName, ZipPath 

39from .types import FileSource, PermissiveFileSource 

40from .url import HttpUrl, RootHttpUrl 

41from .utils import cache 

42from .validation_context import ValidationContext, get_validation_context 

43 

# Loader used by `read_yaml`; created with ruyaml's "safe" loader type.
_yaml_load = YAML(typ="safe")

# Dumper used by `write_yaml`; the settings below only affect the output layout.
_yaml_dump = YAML()
_yaml_dump.default_flow_style = False
_yaml_dump.version = (1, 2)  # pyright: ignore[reportAttributeAccessIssue]
_yaml_dump.width = 88  # pyright: ignore[reportAttributeAccessIssue]
_yaml_dump.indent(mapping=2, sequence=4, offset=2)

51 

52 

def read_yaml(
    file: Union[FilePath, ZipPath, IO[str], IO[bytes], BytesReader, str],
) -> YamlValue:
    """Parse YAML from a path, an open stream/reader, or a string.

    Args:
        file: A file system path or a path inside a zip archive (read as UTF-8
            text), or any other object (stream, reader, string), which is handed
            to the YAML loader unchanged.

    Returns:
        The parsed YAML content.
    """
    raw = (
        file.read_text(encoding="utf-8")
        if isinstance(file, (ZipPath, Path))
        else file
    )
    parsed: YamlValue = _yaml_load.load(raw)
    return parsed

63 

64 

def write_yaml(
    content: Union[YamlValue, BioimageioYamlContentView, BaseModel],
    /,
    file: Union[NewPath, FilePath, IO[str], IO[bytes], ZipPath],
):
    """Serialize `content` as YAML to `file`.

    Args:
        content: The data to dump. A pydantic `BaseModel` is first converted
            with `model_dump(mode="json")`.
        file: Destination. A `pathlib.Path` is opened for writing as UTF-8 text
            and closed afterwards; any other object is passed to the YAML dumper
            as is and is not closed here.
    """
    # Convert the content *before* opening the destination: if `model_dump`
    # raises, no file has been created/truncated and no handle is left open.
    if isinstance(content, BaseModel):
        content = content.model_dump(mode="json")

    if isinstance(file, Path):
        cm = file.open("w", encoding="utf-8")
    else:
        # Not ours to open or close; hand it through unchanged.
        cm = nullcontext(file)

    with cm as f:
        _yaml_dump.dump(content, f)

80 

81 

def _sanitize_bioimageio_yaml(content: YamlValue) -> BioimageioYamlContent:
    """Check that parsed YAML content is a mapping with string keys only.

    Args:
        content: Parsed YAML content.

    Returns:
        The same `content` object, typed as `BioimageioYamlContent`.

    Raises:
        ValueError: If `content` is not a `dict` or any of its top-level keys is
            not a string.
    """
    if not isinstance(content, dict):
        raise ValueError(
            f"Expected {BIOIMAGEIO_YAML} content to be a mapping (got {type(content)})."
        )

    for field_name in content:
        if isinstance(field_name, str):
            continue

        raise ValueError(
            f"Expected all keys (field names) in a {BIOIMAGEIO_YAML} "
            + f"to be strings (got '{field_name}' of type {type(field_name)})."
        )

    return cast(BioimageioYamlContent, content)

96 

97 

def _open_bioimageio_rdf_in_zip(
    path: ZipPath,
    *,
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile],
    original_source_name: str,
) -> OpenedBioimageioYaml:
    """Read and parse the bioimageio YAML file at `path` inside a zip archive.

    Args:
        path: Location of the YAML file within the archive.
        original_root: Root to record on the returned object.
        original_source_name: Source name to record on the returned object.

    Returns:
        The parsed content together with the raw (UTF-8 decoded) text.
    """
    with path.open("rb") as stream:
        assert not isinstance(stream, io.TextIOWrapper)
        raw_bytes = stream.read()

    yaml_text = raw_bytes.decode(encoding="utf-8")
    parsed = read_yaml(io.StringIO(yaml_text))

    return OpenedBioimageioYaml(
        _sanitize_bioimageio_yaml(parsed),
        original_root=original_root,
        original_file_name=extract_file_name(path),
        original_source_name=original_source_name,
        unparsed_content=yaml_text,
    )

117 

118 

def _open_bioimageio_zip(
    source: ZipFile,
    *,
    original_source_name: str,
) -> OpenedBioimageioYaml:
    """Open the bioimageio YAML file contained in a zip archive.

    The YAML file is selected from the archive's member names by
    `identify_bioimageio_yaml_file_name`; the archive itself is recorded as the
    original root.

    Args:
        source: The opened zip archive.
        original_source_name: Source name to record on the returned object.
    """
    member_names = [member.filename for member in source.filelist]
    rdf_name = identify_bioimageio_yaml_file_name(member_names)
    rdf_path = ZipPath(source, rdf_name)
    return _open_bioimageio_rdf_in_zip(
        rdf_path,
        original_root=source,
        original_source_name=original_source_name,
    )

132 

133 

def open_bioimageio_yaml(
    source: Union[PermissiveFileSource, ZipFile, ZipPath],
    /,
    **kwargs: Unpack[HashKwargs],
) -> OpenedBioimageioYaml:
    """Open and parse a bioimageio YAML file from a variety of sources.

    `source` is resolved in this order:

    1. A string of the form ``huggingface/{user_or_org}/{repo_name}[/{version}]``
       is converted to a URL using `settings.huggingface_http_pattern`.
       Without a version the branch ``main`` is used; a version starting with a
       digit is mapped to the branch ``v{version}``; any other version is used as
       the branch name as is.
    2. A `ZipFile` or `ZipPath` is read directly from the archive.
    3. A directory is searched for its bioimageio YAML file; anything else is
       interpreted with `interprete_file_source`. The result is opened with
       `get_reader`, to which `kwargs` are forwarded.
    4. If step 3 fails and `source` is a string, it is treated as a collection
       id: first `settings.collection_http_pattern` (if set) is tried, then the
       id map returned by `get_id_map`.

    Args:
        source: Where to load the bioimageio YAML file from (see above).
        **kwargs: Hash related keyword arguments forwarded to `get_reader`.

    Returns:
        The parsed YAML content along with its origin and raw text.

    Raises:
        ValueError: If a string `source` could neither be opened directly nor be
            resolved as a collection id (chained to the original error).
        Exception: Whatever opening `source` raised, if `source` is not a string.
    """
    if (
        isinstance(source, str)
        and source.startswith("huggingface/")
        and source.count("/") >= 2
    ):
        if source.count("/") == 2:
            # huggingface/{user_or_org}/{repo_name}
            repo_id = source[len("huggingface/") :]
            branch = "main"
        else:
            # huggingface/{user_or_org}/{repo_id}/
            # huggingface/{user_or_org}/{repo_id}/version
            repo_id, version = source[len("huggingface/") :].rsplit("/", 1)
            if len(version) == 0:
                branch = "main"
            elif version[0].isdigit():
                branch = f"v{version}"
            else:
                branch = version

        source = HttpUrl(
            settings.huggingface_http_pattern.format(repo_id=repo_id, branch=branch)
        )

    if isinstance(source, RelativeFilePath):
        source = source.absolute()

    if isinstance(source, ZipFile):
        return _open_bioimageio_zip(source, original_source_name=str(source))
    elif isinstance(source, ZipPath):
        return _open_bioimageio_rdf_in_zip(
            source, original_root=source.root, original_source_name=str(source)
        )

    try:
        if isinstance(source, (Path, str)) and (source_dir := Path(source)).is_dir():
            # open bioimageio yaml from a folder
            src = source_dir / find_bioimageio_yaml_file_name(source_dir)
        else:
            src = interprete_file_source(source)

        reader = get_reader(src, **kwargs)

    except Exception as e:
        # check if `source` is a collection id
        if not isinstance(source, str):
            # only strings can be ids; re-raise with the original traceback
            raise

        if settings.collection_http_pattern:
            # only build the URL here; do not let URL validation perform IO
            with ValidationContext(perform_io_checks=False):
                url = HttpUrl(
                    settings.collection_http_pattern.format(bioimageio_id=source)
                )

            try:
                # honor the configured timeout, like the id map request does
                r = httpx.get(
                    url, timeout=settings.http_timeout, follow_redirects=True
                )
                _ = r.raise_for_status()
                unparsed_content = r.content.decode(encoding="utf-8")
                content = _sanitize_bioimageio_yaml(read_yaml(unparsed_content))
            except Exception as e_coll_pattern:
                # remember the failure; it is appended to any id map error below
                collection_pattern_error_msg = f"BIOIMAGEIO_COLLECTION_HTTP_PATTERN: Failed to get bioimageio.yaml from {url}: {e_coll_pattern}"
                logger.warning(collection_pattern_error_msg)
                collection_pattern_error_msg = "\n" + collection_pattern_error_msg
            else:
                logger.info("loaded {} from {}", source, url)
                original_file_name = (
                    "rdf.yaml" if url.path is None else url.path.split("/")[-1]
                )
                return OpenedBioimageioYaml(
                    content=content,
                    original_root=url.parent,
                    original_file_name=original_file_name,
                    original_source_name=url,
                    unparsed_content=unparsed_content,
                )
        else:
            collection_pattern_error_msg = ""

        # fall back to the id map
        if not isinstance(settings.id_map, str) or "/" not in settings.id_map:
            raise ValueError(
                f"BIOIMAGEIO_ID_MAP: Invalid id map url {settings.id_map}.{collection_pattern_error_msg}"
            ) from e

        id_map = get_id_map()
        if not id_map:
            raise ValueError(
                f"BIOIMAGEIO_ID_MAP: Empty (or unavailable) id map from {settings.id_map}.{collection_pattern_error_msg}"
            ) from e

        # `id_map` is known to be non-empty at this point
        if source not in id_map:
            close_matches = get_close_matches(source, id_map)
            if len(close_matches) == 0:
                raise ValueError(
                    f"BIOIMAGEIO_ID_MAP: '{source}' not found in {settings.id_map}.{collection_pattern_error_msg}"
                ) from e

            if len(close_matches) == 1:
                did_you_mean = f" Did you mean '{close_matches[0]}'?"
            else:
                did_you_mean = f" Did you mean any of {close_matches}?"

            raise ValueError(
                f"BIOIMAGEIO_ID_MAP: '{source}' not found in {settings.id_map}.{did_you_mean}{collection_pattern_error_msg}"
            ) from e

        entry = id_map[source]
        logger.info("loading {} from {}", source, entry.source)
        reader = entry.get_reader()
        # the reader above already accessed the source; skip IO checks for the URL
        with get_validation_context().replace(perform_io_checks=False):
            src = HttpUrl(entry.source)

    if reader.is_zipfile:
        return _open_bioimageio_zip(ZipFile(reader), original_source_name=str(src))

    unparsed_content = reader.read().decode(encoding="utf-8")
    content = _sanitize_bioimageio_yaml(read_yaml(unparsed_content))

    if isinstance(src, RelativeFilePath):
        src = src.absolute()

    if isinstance(src, ZipPath):
        root = src.root
    else:
        root = src.parent

    return OpenedBioimageioYaml(
        content,
        original_root=root,
        original_source_name=str(src),
        original_file_name=extract_file_name(src),
        unparsed_content=unparsed_content,
    )

271 

272 

# Validation model for the id map JSON document: maps ids to file descriptions.
_IdMap = RootModel[Dict[str, LightHttpFileDescr]]


def _get_id_map_impl(url: str) -> Dict[str, LightHttpFileDescr]:
    """Download and validate the id map found at `url`.

    Args:
        url: HTTP(S) location of the id map JSON document.

    Returns:
        The validated id map, or an empty dict if `url` is invalid or the
        download (including an HTTP error status or undecodable JSON) fails.
        Such failures are logged, not raised.

    Raises:
        Exception: Whatever `_IdMap.model_validate` raises if the downloaded
            JSON does not match the expected structure.
    """
    if not isinstance(url, str) or "/" not in url:
        logger.opt(depth=1).error("invalid id map url: {}", url)
        # A request to such a url cannot succeed; skip it (and a second error log).
        return {}

    try:
        response = httpx.get(
            url, timeout=settings.http_timeout, follow_redirects=True
        )
        # Do not try to interpret the body of an error response as an id map.
        _ = response.raise_for_status()
        id_map_raw: Any = response.json()
    except Exception as e:
        logger.opt(depth=1).error("failed to get {}: {}", url, e)
        return {}

    id_map = _IdMap.model_validate(id_map_raw)
    return id_map.root

289 

290 

@cache
def get_id_map() -> Mapping[str, LightHttpFileDescr]:
    """Return a read-only mapping of resource ids to their file descriptions.

    If `settings.resolve_draft` is set, the map from `settings.id_map_draft` is
    loaded first; entries from `settings.id_map` are then merged on top of it.
    Any exception is logged and results in an empty mapping.
    """
    try:
        merged = (
            _get_id_map_impl(settings.id_map_draft) if settings.resolve_draft else {}
        )
        merged.update(_get_id_map_impl(settings.id_map))
    except Exception as e:
        logger.error("failed to get resource id map: {}", e)
        merged = {}

    return MappingProxyType(merged)

306 

307 

def write_content_to_zip(
    content: Mapping[
        FileName,
        Union[
            str, FilePath, ZipPath, BioimageioYamlContentView, FileDescr, BytesReader
        ],
    ],
    zip: zipfile.ZipFile,
):
    """Add the given entries to an open zip archive.

    Mappings are dumped as YAML, strings are stored as UTF-8 text and everything
    else is copied in chunks from a reader. Entries whose reader originates from
    `zip` itself are skipped.

    Args:
        content: Maps archive names to strings (for text files), mappings (for
            yaml files), or file sources/readers to copy.
        zip: The archive to write to.
    """
    for arc_name, item in content.items():
        if isinstance(item, collections.abc.Mapping):
            # dump to yaml text; stored by the string branch right below
            yaml_buffer = io.StringIO()
            write_yaml(item, yaml_buffer)
            item = yaml_buffer.getvalue()

        if isinstance(item, str):
            zip.writestr(arc_name, item.encode("utf-8"))
            continue

        reader = item if isinstance(item, BytesReader) else get_reader(item)
        root = reader.original_root
        if isinstance(root, ZipFile) and root is zip:
            archive_label = "zip file" if root.filename is None else root.filename
            logger.debug(
                f"Not copying {reader.original_file_name} in "
                + archive_label
                + " to itself."
            )
            continue

        with zip.open(arc_name, "w") as dest:
            shutil.copyfileobj(reader, dest, 1024 * 8)

354 

355 

def write_zip(
    path: Union[FilePath, IO[bytes]],
    content: Mapping[
        FileName, Union[str, FilePath, ZipPath, BioimageioYamlContentView, BytesReader]
    ],
    *,
    compression: int,
    compression_level: int,
) -> None:
    """Create a zip archive holding `content`.

    Args:
        path: Output path (missing parent directories are created) or a binary
            stream to write the archive to.
        content: Maps archive names to local file paths, strings (for text
            files), or dicts (for yaml files).
        compression: The numeric constant of the compression method.
        compression_level: Compression level to use when writing files to the
            archive.
            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
    """
    if isinstance(path, Path):
        path.parent.mkdir(parents=True, exist_ok=True)

    archive = ZipFile(
        path, "w", compression=compression, compresslevel=compression_level
    )
    with archive:
        write_content_to_zip(content, archive)

382 

383 

def load_array(source: Union[FileSource, FileDescr, ZipPath]) -> NDArray[Any]:
    """Load a numpy ndarray from a .npy file.

    Pickled object arrays are only loaded if `settings.allow_pickle` is set, in
    which case a warning is logged.
    """
    reader = get_reader(source)
    allow_pickle = settings.allow_pickle
    if allow_pickle:
        logger.warning("Loading numpy array with `allow_pickle=True`.")

    return numpy.load(reader, allow_pickle=allow_pickle)

391 

392 

def save_array(path: Union[Path, ZipPath], array: NDArray[Any]) -> None:
    """Write a numpy ndarray to `path` in .npy format (pickling disabled)."""
    with path.open(mode="wb") as sink:
        assert not isinstance(sink, io.TextIOWrapper)
        numpy.save(sink, array, allow_pickle=False)

397 return numpy.save(f, array, allow_pickle=False)