Coverage for src / bioimageio / spec / _internal / io_utils.py: 76%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 08:44 +0000

1import collections.abc 

2import io 

3import shutil 

4import zipfile 

5from contextlib import nullcontext 

6from difflib import get_close_matches 

7from pathlib import Path 

8from types import MappingProxyType 

9from typing import IO, Any, Dict, Mapping, Union, cast 

10from zipfile import ZipFile 

11 

12import httpx 

13import numpy 

14from loguru import logger 

15from numpy.typing import NDArray 

16from pydantic import BaseModel, FilePath, NewPath, RootModel 

17from ruyaml import YAML 

18from typing_extensions import Unpack 

19 

20from ._settings import settings 

21from .io import ( 

22 BIOIMAGEIO_YAML, 

23 BioimageioYamlContent, 

24 BioimageioYamlContentView, 

25 BytesReader, 

26 FileDescr, 

27 HashKwargs, 

28 LightHttpFileDescr, 

29 OpenedBioimageioYaml, 

30 RelativeFilePath, 

31 YamlValue, 

32 extract_file_name, 

33 find_bioimageio_yaml_file_name, 

34 get_reader, 

35 identify_bioimageio_yaml_file_name, 

36 interprete_file_source, 

37) 

38from .io_basics import AbsoluteDirectory, FileName, ZipPath 

39from .types import FileSource, PermissiveFileSource 

40from .url import HttpUrl, RootHttpUrl 

41from .utils import cache 

42from .validation_context import ValidationContext, get_validation_context 

43 

# Shared loader for every read in this module; restricted to ruyaml's "safe" type.
_yaml_load = YAML(typ="safe")

# Shared dumper for every write in this module: YAML version (1, 2), no default
# flow style, a line width of 88 and the mapping/sequence indentation set below.
_yaml_dump = YAML()
_yaml_dump.width = 88  # pyright: ignore[reportAttributeAccessIssue]
_yaml_dump.indent(mapping=2, sequence=4, offset=2)
_yaml_dump.default_flow_style = False
_yaml_dump.version = (1, 2)  # pyright: ignore[reportAttributeAccessIssue]

51 

52 

def read_yaml(
    file: Union[FilePath, ZipPath, IO[str], IO[bytes], BytesReader, str],
) -> YamlValue:
    """Parse YAML with the module's safe loader and return the parsed value.

    Args:
        file: What to parse. A `Path` or `ZipPath` is first read as UTF-8 text;
            any other value (open stream, reader or string) is handed to the
            YAML loader unchanged.

    Returns:
        The value produced by the YAML loader.
    """
    stream = (
        file.read_text(encoding="utf-8")
        if isinstance(file, (ZipPath, Path))
        else file
    )
    parsed: YamlValue = _yaml_load.load(stream)
    return parsed

63 

64 

def write_yaml(
    content: Union[YamlValue, BioimageioYamlContentView, BaseModel],
    /,
    file: Union[NewPath, FilePath, IO[str], IO[bytes], ZipPath],
):
    """Serialize `content` as YAML to `file` using the module's dumper.

    Args:
        content: Data to write. A pydantic `BaseModel` is converted with
            `model_dump(mode="json")` before it is dumped.
        file: Destination. A `Path` or `ZipPath` is opened for writing as UTF-8
            text and closed again afterwards; an already open stream is written
            to as is and left open for the caller.
    """
    if isinstance(file, (Path, ZipPath)):
        # Open path-like targets here so the text encoding is always UTF-8,
        # which is the encoding `read_yaml` uses to decode such paths.
        cm = file.open("w", encoding="utf-8")
    else:
        # The caller owns an already open stream; `nullcontext` does not close it.
        cm = nullcontext(file)

    if isinstance(content, BaseModel):
        content = content.model_dump(mode="json")

    with cm as f:
        _yaml_dump.dump(content, f)

80 

81 

def _sanitize_bioimageio_yaml(content: YamlValue) -> BioimageioYamlContent:
    """Check that parsed YAML is a dict whose keys are all strings.

    Args:
        content: A parsed YAML value.

    Returns:
        The same `content` object, typed as `BioimageioYamlContent`.

    Raises:
        ValueError: If `content` is not a dict, or if it has a key that is not a
            string (the first such key is reported).
    """
    if not isinstance(content, dict):
        raise ValueError(
            f"Expected {BIOIMAGEIO_YAML} content to be a mapping (got {type(content)})."
        )

    non_string_keys = [k for k in content if not isinstance(k, str)]
    if non_string_keys:
        key = non_string_keys[0]
        raise ValueError(
            f"Expected all keys (field names) in a {BIOIMAGEIO_YAML} "
            + f"to be strings (got '{key}' of type {type(key)})."
        )

    return cast(BioimageioYamlContent, content)

96 

97 

def _open_bioimageio_rdf_in_zip(
    path: ZipPath,
    *,
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile],
    original_source_name: str,
) -> OpenedBioimageioYaml:
    """Read, parse and sanitize the YAML file at `path` inside a zip archive.

    Args:
        path: Location of the YAML file within the archive.
        original_root: Root to record on the returned object.
        original_source_name: Source name to record on the returned object.

    Returns:
        An `OpenedBioimageioYaml` carrying both the parsed and the raw
        (UTF-8 decoded) content.
    """
    with path.open("rb") as binary_file:
        assert not isinstance(binary_file, io.TextIOWrapper)
        raw_bytes = binary_file.read()

    unparsed_content = raw_bytes.decode(encoding="utf-8")
    parsed = read_yaml(io.StringIO(unparsed_content))
    return OpenedBioimageioYaml(
        _sanitize_bioimageio_yaml(parsed),
        original_root=original_root,
        original_file_name=extract_file_name(path),
        original_source_name=original_source_name,
        unparsed_content=unparsed_content,
    )

117 

118 

def _open_bioimageio_zip(
    source: ZipFile,
    *,
    original_source_name: str,
) -> OpenedBioimageioYaml:
    """Locate the bioimageio YAML file among an archive's members and open it.

    Args:
        source: The zip archive to search; it is recorded as the original root.
        original_source_name: Source name to record on the returned object.

    Returns:
        The opened bioimageio YAML found in `source`.
    """
    member_names = source.namelist()
    yaml_path = ZipPath(source, identify_bioimageio_yaml_file_name(member_names))
    return _open_bioimageio_rdf_in_zip(
        yaml_path,
        original_root=source,
        original_source_name=original_source_name,
    )

132 

133 

def open_bioimageio_yaml(
    source: Union[PermissiveFileSource, ZipFile, ZipPath],
    /,
    **kwargs: Unpack[HashKwargs],
) -> OpenedBioimageioYaml:
    """Open a bioimageio YAML file from one of several kinds of sources.

    Sources are tried in this order:

    1. A string `huggingface/{user_or_org}/{repo_name}[/[{version}]]` is turned
       into a URL via `settings.huggingface_http_pattern`.
    2. A `ZipFile` or `ZipPath` is read directly from the archive.
    3. A directory is searched for a bioimageio YAML file; any other source is
       interpreted as a file source and opened with `get_reader`.
    4. If step 3 fails and `source` is a string, it is treated as a resource id:
       first `settings.collection_http_pattern` is tried (if set), then the id
       map returned by `get_id_map()`.

    Args:
        source: Where to load the bioimageio YAML from (see above).
        **kwargs: Forwarded to `get_reader`.

    Returns:
        The opened YAML with its parsed content, raw content and origin.

    Raises:
        ValueError: If a string `source` cannot be opened and cannot be resolved
            as a resource id either (chained to the original error).
        Exception: Whatever opening `source` raised, if `source` is not a string.
    """
    if (
        isinstance(source, str)
        and source.startswith("huggingface/")
        and source.count("/") >= 2
    ):
        source = _huggingface_source_to_url(source)

    if isinstance(source, RelativeFilePath):
        source = source.absolute()

    if isinstance(source, ZipFile):
        return _open_bioimageio_zip(source, original_source_name=str(source))

    if isinstance(source, ZipPath):
        return _open_bioimageio_rdf_in_zip(
            source, original_root=source.root, original_source_name=str(source)
        )

    try:
        # `ZipFile`/`ZipPath` sources have already returned above
        if isinstance(source, FileDescr):
            src = source
        elif isinstance(source, (Path, str)) and (source_dir := Path(source)).is_dir():
            # open bioimageio yaml from a folder
            src = source_dir / find_bioimageio_yaml_file_name(source_dir)
        else:
            src = interprete_file_source(source)

        reader = get_reader(src, **kwargs)

    except Exception as e:
        # check if `source` is a collection id
        if not isinstance(source, str):
            raise

        if settings.collection_http_pattern:
            with ValidationContext(perform_io_checks=False):
                url = HttpUrl(
                    settings.collection_http_pattern.format(bioimageio_id=source)
                )

            try:
                # honor the configured timeout, like the id map request does
                r = httpx.get(
                    url, timeout=settings.http_timeout, follow_redirects=True
                )
                _ = r.raise_for_status()
                unparsed_content = r.content.decode(encoding="utf-8")
                content = _sanitize_bioimageio_yaml(read_yaml(unparsed_content))
            except Exception as e_coll_pattern:
                # remember the failure so it can be appended to a later id map error
                collection_pattern_error_msg = f"BIOIMAGEIO_COLLECTION_HTTP_PATTERN: Failed to get bioimageio.yaml from {url}: {e_coll_pattern}"
                logger.warning(collection_pattern_error_msg)
                collection_pattern_error_msg = "\n" + collection_pattern_error_msg
            else:
                logger.info("loaded {} from {}", source, url)
                original_file_name = (
                    "rdf.yaml" if url.path is None else url.path.split("/")[-1]
                )
                return OpenedBioimageioYaml(
                    content=content,
                    original_root=url.parent,
                    original_file_name=original_file_name,
                    original_source_name=url,
                    unparsed_content=unparsed_content,
                )
        else:
            collection_pattern_error_msg = ""

        if not isinstance(settings.id_map, str) or "/" not in settings.id_map:
            raise ValueError(
                f"BIOIMAGEIO_ID_MAP: Invalid id map url {settings.id_map}.{collection_pattern_error_msg}"
            ) from e

        id_map = get_id_map()
        if not id_map:
            raise ValueError(
                f"BIOIMAGEIO_ID_MAP: Empty (or unavailable) id map from {settings.id_map}.{collection_pattern_error_msg}"
            ) from e

        if source not in id_map:
            close_matches = get_close_matches(source, id_map)
            if not close_matches:
                did_you_mean = ""
            elif len(close_matches) == 1:
                did_you_mean = f" Did you mean '{close_matches[0]}'?"
            else:
                did_you_mean = f" Did you mean any of {close_matches}?"

            raise ValueError(
                f"BIOIMAGEIO_ID_MAP: '{source}' not found in {settings.id_map}.{did_you_mean}{collection_pattern_error_msg}"
            ) from e

        entry = id_map[source]
        logger.info("loading {} from {}", source, entry.source)
        reader = entry.get_reader()
        with get_validation_context().replace(perform_io_checks=False):
            src = HttpUrl(entry.source)

    if reader.is_zipfile:
        return _open_bioimageio_zip(ZipFile(reader), original_source_name=str(src))

    unparsed_content = reader.read().decode(encoding="utf-8")
    content = _sanitize_bioimageio_yaml(read_yaml(unparsed_content))

    if isinstance(src, RelativeFilePath):
        src = src.absolute()

    # determine the root that the opened file is recorded as originating from
    if isinstance(src, ZipPath):
        root = src.root
    elif isinstance(src, FileDescr):
        file_source = src.source.absolute()
        if isinstance(file_source, ZipPath):
            root = file_source.root
        else:
            root = file_source.parent
    else:
        root = src.parent

    return OpenedBioimageioYaml(
        content,
        original_root=root,
        original_source_name=str(src),
        original_file_name=extract_file_name(src),
        unparsed_content=unparsed_content,
    )


def _huggingface_source_to_url(source: str) -> HttpUrl:
    """Build the URL for a `huggingface/{user_or_org}/{repo_name}[/[{version}]]` source."""
    reference = source[len("huggingface/") :]
    if source.count("/") == 2:
        # huggingface/{user_or_org}/{repo_name}
        repo_id = reference
        branch = "main"
    else:
        # huggingface/{user_or_org}/{repo_name}/
        # huggingface/{user_or_org}/{repo_name}/{version}
        repo_id, version = reference.rsplit("/", 1)
        if not version:
            branch = "main"
        elif version[0].isdigit():
            # a version starting with a digit maps to a "v"-prefixed branch name
            branch = f"v{version}"
        else:
            branch = version

    return HttpUrl(
        settings.huggingface_http_pattern.format(repo_id=repo_id, branch=branch)
    )

279 

280 

# pydantic model used to validate a downloaded id map (resource id -> file description)
_IdMap = RootModel[Dict[str, LightHttpFileDescr]]


def _get_id_map_impl(url: str) -> Dict[str, LightHttpFileDescr]:
    """Download the JSON id map at `url` and validate it.

    Args:
        url: Location of the id map; it has to be a string containing a "/".

    Returns:
        The validated id map, or an empty dict if `url` is invalid or the
        download/JSON decoding fails (the failure is logged as an error).
    """
    if not isinstance(url, str) or "/" not in url:
        logger.opt(depth=1).error("invalid id map url: {}", url)
        # Skip the request for a URL already rejected as invalid; otherwise the
        # same problem would be logged a second time as a download failure.
        return {}

    try:
        id_map_raw: Any = httpx.get(
            url, timeout=settings.http_timeout, follow_redirects=True
        ).json()
    except Exception as e:
        logger.opt(depth=1).error("failed to get {}: {}", url, e)
        return {}

    id_map = _IdMap.model_validate(id_map_raw)
    return id_map.root

297 

298 

@cache
def get_id_map() -> Mapping[str, LightHttpFileDescr]:
    """Return a read-only view of the resource id map.

    If `settings.resolve_draft` is set, the map from `settings.id_map_draft` is
    loaded first and then updated with the entries from `settings.id_map`.
    Any exception raised while loading is logged and yields an empty map.
    """
    merged: Dict[str, LightHttpFileDescr] = {}
    try:
        if settings.resolve_draft:
            merged = _get_id_map_impl(settings.id_map_draft)

        # entries of the regular id map override draft entries with the same id
        merged.update(_get_id_map_impl(settings.id_map))
    except Exception as e:
        logger.error("failed to get resource id map: {}", e)
        merged = {}

    return MappingProxyType(merged)

314 

315 

def write_content_to_zip(
    content: Mapping[
        FileName,
        Union[
            str, FilePath, ZipPath, BioimageioYamlContentView, FileDescr, BytesReader
        ],
    ],
    zip: zipfile.ZipFile,
):
    """Add the given entries to an open zip archive.

    Mappings are serialized as YAML, strings are stored as UTF-8 encoded text
    and every other value is streamed into the archive through a reader.
    An entry whose reader reports `zip` itself as its original root is skipped.

    Args:
        content: Maps archive names to the data to store under that name.
        zip: The archive to write to.
    """
    for arc_name, entry in content.items():
        if isinstance(entry, collections.abc.Mapping):
            # render the mapping as YAML text; it is then stored like any string
            yaml_text = io.StringIO()
            write_yaml(entry, yaml_text)
            entry = yaml_text.getvalue()

        if isinstance(entry, str):
            zip.writestr(arc_name, entry.encode("utf-8"))
            continue

        reader = entry if isinstance(entry, BytesReader) else get_reader(entry)
        root = reader.original_root
        if isinstance(root, ZipFile) and root is zip:
            archive_label = "zip file" if root.filename is None else root.filename
            logger.debug(
                f"Not copying {reader.original_file_name} in "
                + archive_label
                + " to itself."
            )
            continue

        with zip.open(arc_name, "w") as dest:
            shutil.copyfileobj(reader, dest, 1024 * 8)

362 

363 

def write_zip(
    path: Union[FilePath, IO[bytes]],
    content: Mapping[
        FileName, Union[str, FilePath, ZipPath, BioimageioYamlContentView, BytesReader]
    ],
    *,
    compression: int,
    compression_level: int,
) -> None:
    """Create a zip archive at `path` and fill it with `content`.

    Args:
        path: Output path (missing parent folders are created) or a binary
            stream to write the archive to.
        content: Maps archive names to local file paths, strings (for text
            files), or dicts (for yaml files).
        compression: The numeric constant of the compression method.
        compression_level: Compression level to use when writing files to the
            archive.
            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
    """
    if isinstance(path, Path):
        # make sure the folder that will hold the new archive exists
        path.parent.mkdir(parents=True, exist_ok=True)

    archive = ZipFile(
        path, "w", compression=compression, compresslevel=compression_level
    )
    with archive:
        write_content_to_zip(content, archive)

390 

391 

def load_array(source: Union[FileSource, FileDescr, ZipPath]) -> NDArray[Any]:
    """Read a numpy ndarray stored in .npy format from `source`.

    Pickled data is only accepted if `settings.allow_pickle` is set, in which
    case a warning is logged before loading.
    """
    reader = get_reader(source)
    allow_pickle = settings.allow_pickle
    if allow_pickle:
        # NOTE(review): unpickling can run arbitrary code; enable only for trusted sources
        logger.warning("Loading numpy array with `allow_pickle=True`.")

    return numpy.load(reader, allow_pickle=allow_pickle)

399 

400 

def save_array(path: Union[Path, ZipPath], array: NDArray[Any]) -> None:
    """Store `array` at `path` in numpy's .npy format with pickling disabled."""
    with path.open(mode="wb") as binary_file:
        # "wb" must yield a binary handle, never a text wrapper
        assert not isinstance(binary_file, io.TextIOWrapper)
        numpy.save(binary_file, array, allow_pickle=False)