Coverage for bioimageio/spec/_internal/io.py: 81%

457 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-18 12:47 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import hashlib 

5import sys 

6import warnings 

7import zipfile 

8from abc import abstractmethod 

9from contextlib import nullcontext 

10from dataclasses import dataclass, field 

11from datetime import date as _date 

12from datetime import datetime as _datetime 

13from functools import partial 

14from io import TextIOWrapper 

15from pathlib import Path, PurePath, PurePosixPath 

16from tempfile import mktemp 

17from typing import ( 

18 TYPE_CHECKING, 

19 Any, 

20 Callable, 

21 Dict, 

22 Generic, 

23 Iterable, 

24 List, 

25 Mapping, 

26 Optional, 

27 Sequence, 

28 Set, 

29 Tuple, 

30 Type, 

31 TypedDict, 

32 TypeVar, 

33 Union, 

34 overload, 

35) 

36from urllib.parse import urlparse, urlsplit, urlunsplit 

37from zipfile import ZipFile 

38 

39import httpx 

40import pydantic 

41from genericache import NoopCache 

42from genericache.digest import ContentDigest, UrlDigest 

43from pydantic import ( 

44 AnyUrl, 

45 DirectoryPath, 

46 Field, 

47 GetCoreSchemaHandler, 

48 PrivateAttr, 

49 RootModel, 

50 TypeAdapter, 

51 model_serializer, 

52 model_validator, 

53) 

54from pydantic_core import core_schema 

55from tqdm import tqdm 

56from typing_extensions import ( 

57 Annotated, 

58 LiteralString, 

59 NotRequired, 

60 Self, 

61 TypeGuard, 

62 Unpack, 

63 assert_never, 

64) 

65from typing_extensions import TypeAliasType as _TypeAliasType 

66 

67from ._settings import settings 

68from .io_basics import ( 

69 ALL_BIOIMAGEIO_YAML_NAMES, 

70 ALTERNATIVE_BIOIMAGEIO_YAML_NAMES, 

71 BIOIMAGEIO_YAML, 

72 AbsoluteDirectory, 

73 AbsoluteFilePath, 

74 BytesReader, 

75 FileName, 

76 FilePath, 

77 Sha256, 

78 ZipPath, 

79 get_sha256, 

80) 

81from .node import Node 

82from .progress import Progressbar 

83from .root_url import RootHttpUrl 

84from .type_guards import is_dict, is_list, is_mapping, is_sequence 

85from .url import HttpUrl 

86from .utils import SLOTS 

87from .validation_context import get_validation_context 

88 

# Type variable for the absolute counterpart of a relative path:
# a URL, an absolute local file/directory path, or a path inside a zip archive.
AbsolutePathT = TypeVar(
    "AbsolutePathT",
    bound=Union[HttpUrl, AbsoluteDirectory, AbsoluteFilePath, ZipPath],
)

93 

94 

class LightHttpFileDescr(Node):
    """http source with sha256 value (minimal validation)"""

    source: pydantic.HttpUrl
    """file source"""

    sha256: Sha256
    """SHA256 checksum of the source file"""

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ) -> BytesReader:
        """open the file source (download if needed)"""
        # delegate to the module-level `get_reader`, forwarding the known sha256
        # so the fetched content is verified against it
        return get_reader(self.source, sha256=self.sha256, progressbar=progressbar)

    download = get_reader
    """alias for get_reader() method"""

114 

115 

class RelativePathBase(RootModel[PurePath], Generic[AbsolutePathT], frozen=True):
    """Base for relative paths that are resolved against the ValidationContext
    root at model creation time (see `model_post_init`)."""

    # absolute counterpart, computed once in `model_post_init`
    _absolute: AbsolutePathT = PrivateAttr()

    @property
    def path(self) -> PurePath:
        # the validated relative path itself
        return self.root

    def absolute(  # method not property analog to `pathlib.Path.absolute()`
        self,
    ) -> AbsolutePathT:
        """get the absolute path/url

        (resolved at time of initialization with the root of the ValidationContext)
        """
        return self._absolute

    def model_post_init(self, __context: Any) -> None:
        """set `_absolute` property with validation context at creation time. @private"""
        # reject inputs that are not actually relative paths
        if self.root.is_absolute():
            raise ValueError(f"{self.root} is an absolute path.")

        if self.root.parts and self.root.parts[0] in ("http:", "https:"):
            raise ValueError(f"{self.root} looks like an http url.")

        self._absolute = (  # pyright: ignore[reportAttributeAccessIssue]
            self.get_absolute(get_validation_context().root)
        )
        super().model_post_init(__context)

    def __str__(self) -> str:
        # serialize with forward slashes regardless of platform
        return self.root.as_posix()

    def __repr__(self) -> str:
        return f"RelativePath('{self}')"

    @model_serializer()
    def format(self) -> str:
        return str(self)

    @abstractmethod
    def get_absolute(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> AbsolutePathT: ...

    def _get_absolute_impl(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> Union[Path, HttpUrl, ZipPath]:
        # local directory root: plain filesystem join
        if isinstance(root, Path):
            return (root / self.root).absolute()

        rel_path = self.root.as_posix().strip("/")
        # zip archive root: member path inside the archive
        if isinstance(root, ZipFile):
            return ZipPath(root, rel_path)

        # URL root: append the relative path to the URL path
        parsed = urlsplit(str(root))
        path = list(parsed.path.strip("/").split("/"))
        if (
            parsed.netloc == "zenodo.org"
            and parsed.path.startswith("/api/records/")
            and parsed.path.endswith("/content")
        ):
            # zenodo API content URLs keep their trailing "/content" segment;
            # insert the relative path just before it
            path.insert(-1, rel_path)
        else:
            path.append(rel_path)

        return HttpUrl(
            urlunsplit(
                (
                    parsed.scheme,
                    parsed.netloc,
                    "/".join(path),
                    parsed.query,
                    parsed.fragment,
                )
            )
        )

    @classmethod
    def _validate(cls, value: Union[PurePath, str]):
        # guard against URL strings being parsed as relative paths
        if isinstance(value, str) and (
            value.startswith("https://") or value.startswith("http://")
        ):
            raise ValueError(f"{value} looks like a URL, not a relative path")

        return cls(PurePath(value))

201 

202 

class RelativeFilePath(
    RelativePathBase[Union[AbsoluteFilePath, HttpUrl, ZipPath]], frozen=True
):
    """A path relative to the `rdf.yaml` file (also if the RDF source is a URL)."""

    def model_post_init(self, __context: Any) -> None:
        """add validation @private"""
        if not self.root.parts:  # an empty path can only be a directory
            raise ValueError(f"{self.root} is not a valid file path.")

        super().model_post_init(__context)

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteFilePath | HttpUrl | ZipPath":
        """Resolve against `root`; for local paths optionally verify existence.

        Existence is only checked when io checks are enabled and the file is
        not already listed in the context's `known_files`.
        """
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and (context := get_validation_context()).perform_io_checks
            and str(self.root) not in context.known_files
            and not absolute.is_file()
        ):
            raise ValueError(f"{absolute} does not point to an existing file")

        return absolute

228 

229 

class RelativeDirectory(
    RelativePathBase[Union[AbsoluteDirectory, HttpUrl, ZipPath]], frozen=True
):
    """A directory path relative to the validation context root."""

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteDirectory | HttpUrl | ZipPath":
        """Resolve against `root`; for local paths verify the directory exists
        when io checks are enabled."""
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and get_validation_context().perform_io_checks
            and not absolute.is_dir()
        ):
            raise ValueError(f"{absolute} does not point to an existing directory")

        return absolute

245 

246 

# A strict file source: tried left to right as URL, path relative to the RDF
# root, then absolute file path.
FileSource = Annotated[
    Union[HttpUrl, RelativeFilePath, FilePath],
    Field(union_mode="left_to_right"),
]
# Looser input type that additionally allows plain strings and pydantic URLs;
# normalized to `FileSource` by `interprete_file_source`.
PermissiveFileSource = Union[FileSource, str, pydantic.HttpUrl]

252 

253 

class FileDescr(Node):
    """A file description"""

    source: FileSource
    """File source"""

    sha256: Optional[Sha256] = None
    """SHA256 hash value of the **source** file."""

    @model_validator(mode="after")
    def _validate_sha256(self) -> Self:
        # only touch the file system / network when io checks are enabled
        if get_validation_context().perform_io_checks:
            self.validate_sha256()

        return self

    def validate_sha256(self, force_recompute: bool = False) -> None:
        """validate the sha256 hash value of the **source** file

        Raises:
            ValueError: if a `sha256` is set and does not match the actual hash
                (unless the context requests hash updates).
        """
        context = get_validation_context()
        src_str = str(self.source)
        # reuse a previously computed hash unless a recompute is forced
        if not force_recompute and src_str in context.known_files:
            actual_sha = context.known_files[src_str]
        else:
            reader = get_reader(self.source, sha256=self.sha256)
            if force_recompute:
                actual_sha = get_sha256(reader)
            else:
                # the reader may already carry a verified hash
                actual_sha = reader.sha256

            context.known_files[src_str] = actual_sha

        if actual_sha is None:
            # no hash could be determined; nothing to compare against
            return
        elif self.sha256 == actual_sha:
            pass  # expected hash confirmed
        elif self.sha256 is None or context.update_hashes:
            # adopt the computed hash
            self.sha256 = actual_sha
        elif self.sha256 != actual_sha:
            raise ValueError(
                f"Sha256 mismatch for {self.source}. Expected {self.sha256}, got "
                + f"{actual_sha}. Update expected `sha256` or point to the matching "
                + "file."
            )

    def get_reader(
        self, *, progressbar: Union[Progressbar, Callable[[], Progressbar], bool] = True
    ):
        """open the file source (download if needed)"""
        return get_reader(self.source, progressbar=progressbar, sha256=self.sha256)

    download = get_reader
    """alias for get_reader() method"""

306 

307 

# reusable adapter to validate arbitrary input as file path, directory, or URL
path_or_url_adapter: "TypeAdapter[Union[FilePath, DirectoryPath, HttpUrl]]" = (
    TypeAdapter(Union[FilePath, DirectoryPath, HttpUrl])
)

311 

312 

@dataclass(frozen=True, **SLOTS)
class WithSuffix:
    """Annotation metadata restricting the final suffix of a file source.

    Used as pydantic `Annotated[...]` metadata; validation is delegated to
    `validate_suffix`.
    """

    # a single suffix or a tuple of allowed suffixes, each starting with "."
    suffix: Union[LiteralString, Tuple[LiteralString, ...]]
    case_sensitive: bool

    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ):
        if not self.suffix:
            raise ValueError("suffix may not be empty")

        # run the suffix check after the annotated type's own validation
        schema = handler(source)
        return core_schema.no_info_after_validator_function(
            self.validate,
            schema,
        )

    def validate(
        self, value: Union[FileSource, FileDescr]
    ) -> Union[FileSource, FileDescr]:
        return validate_suffix(value, self.suffix, case_sensitive=self.case_sensitive)

334 

335 

def wo_special_file_name(src: F) -> F:
    """Return `src` unchanged unless its filename is reserved for bioimageio
    YAML files, in which case raise `ValueError`."""
    if not has_valid_bioimageio_yaml_name(src):
        return src

    raise ValueError(
        f"'{src}' not allowed here as its filename is reserved to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )

344 

345 

def has_valid_bioimageio_yaml_name(src: Union[FileSource, FileDescr]) -> bool:
    """Check whether `src`'s file name identifies it as a bioimageio YAML file."""
    file_name = extract_file_name(src)
    return is_valid_bioimageio_yaml_name(file_name)

348 

349 

def is_valid_bioimageio_yaml_name(file_name: FileName) -> bool:
    """Return True if `file_name` equals a recognized bioimageio YAML name or
    ends with "." followed by one (e.g. 'anything.rdf.yaml')."""
    return any(
        file_name == candidate or file_name.endswith("." + candidate)
        for candidate in ALL_BIOIMAGEIO_YAML_NAMES
    )

356 

357 

def identify_bioimageio_yaml_file_name(file_names: Iterable[FileName]) -> FileName:
    """Pick the bioimageio YAML file from `file_names`.

    Candidate names are tried in the priority order of
    `ALL_BIOIMAGEIO_YAML_NAMES`; within one candidate, file names are
    considered in sorted order for deterministic results.

    Raises:
        ValueError: if no file name matches any known bioimageio YAML name.
    """
    file_names = sorted(file_names)
    for bioimageio_name in ALL_BIOIMAGEIO_YAML_NAMES:
        for file_name in file_names:
            if file_name == bioimageio_name or file_name.endswith(
                "." + bioimageio_name
            ):
                return file_name

    # fixed message typos: "or or" -> "or", "alterntive" -> "alternative"
    raise ValueError(
        f"No {BIOIMAGEIO_YAML} found in {file_names}. (Looking for '{BIOIMAGEIO_YAML}'"
        + " or any of the alternative file names:"
        + f" {ALTERNATIVE_BIOIMAGEIO_YAML_NAMES}, or any file with an extension of"
        + f" those, e.g. 'anything.{BIOIMAGEIO_YAML}')."
    )

373 

374 

def find_bioimageio_yaml_file_name(path: Union[Path, ZipFile]) -> FileName:
    """Find the bioimageio YAML file name in a directory, zip archive, or file.

    A regular, non-zip file is returned as-is by name; zip archives and
    directories are searched via `identify_bioimageio_yaml_file_name`.
    """
    if isinstance(path, ZipFile):
        file_names = path.namelist()
    elif path.is_file():
        if not zipfile.is_zipfile(path):
            # a plain file is assumed to be the YAML file itself
            return path.name

        with ZipFile(path, "r") as f:
            file_names = f.namelist()
    else:
        # directory: consider direct children only (non-recursive)
        file_names = [p.name for p in path.glob("*")]

    return identify_bioimageio_yaml_file_name(
        file_names
    )  # TODO: try/except with better error message for dir

390 

391 

def ensure_has_valid_bioimageio_yaml_name(src: FileSource) -> FileSource:
    """Return `src` unchanged; raise `ValueError` if its file name does not
    identify a bioimageio YAML file."""
    if has_valid_bioimageio_yaml_name(src):
        return src

    raise ValueError(
        f"'{src}' does not have a valid filename to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )

400 

401 

def ensure_is_valid_bioimageio_yaml_name(file_name: FileName) -> FileName:
    """Return `file_name` unchanged; raise `ValueError` if it is not a valid
    bioimageio YAML file name."""
    if is_valid_bioimageio_yaml_name(file_name):
        return file_name

    raise ValueError(
        f"'{file_name}' is not a valid filename to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )

410 

411 

# types as loaded from YAML 1.2 (with ruyaml)
YamlLeafValue = Union[
    bool, _date, _datetime, int, float, str, None
]  # note: order relevant for deserializing
YamlKey = Union[  # YAML Arrays are cast to tuples if used as key in mappings
    YamlLeafValue, Tuple[YamlLeafValue, ...]  # (nesting is not allowed though)
]
if TYPE_CHECKING:
    # plain recursive Unions keep pyright happy during static analysis
    YamlValue = Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]]
    YamlValueView = Union[
        YamlLeafValue, Sequence["YamlValueView"], Mapping[YamlKey, "YamlValueView"]
    ]
else:
    # for pydantic validation we need to use `TypeAliasType`,
    # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
    # however this results in a partially unknown type with the current pyright 1.1.388
    YamlValue = _TypeAliasType(
        "YamlValue",
        Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]],
    )
    YamlValueView = _TypeAliasType(
        "YamlValueView",
        Union[
            YamlLeafValue,
            Sequence["YamlValueView"],
            Mapping[YamlKey, "YamlValueView"],
        ],
    )

# mutable and read-only views of a parsed bioimageio YAML document
BioimageioYamlContent = Dict[str, YamlValue]
BioimageioYamlContentView = Mapping[str, YamlValueView]
# anything a bioimageio YAML document can be loaded from
BioimageioYamlSource = Union[
    PermissiveFileSource, ZipFile, BioimageioYamlContent, BioimageioYamlContentView
]

446 

447 

@overload
def deepcopy_yaml_value(value: BioimageioYamlContentView) -> BioimageioYamlContent: ...


@overload
def deepcopy_yaml_value(value: YamlValueView) -> YamlValue: ...


def deepcopy_yaml_value(
    value: Union[BioimageioYamlContentView, YamlValueView],
) -> Union[BioimageioYamlContent, YamlValue]:
    """Recursively copy a YAML value, materializing mappings as `dict`s and
    sequences as `list`s; leaf values are returned as-is."""
    # strings are immutable leaves; check first since `str` is also a Sequence
    if isinstance(value, str):
        return value
    if isinstance(value, collections.abc.Mapping):
        return {key: deepcopy_yaml_value(item) for key, item in value.items()}
    if isinstance(value, collections.abc.Sequence):
        return [deepcopy_yaml_value(item) for item in value]
    return value

467 

468 

def is_yaml_leaf_value(value: Any) -> TypeGuard[YamlLeafValue]:
    """TypeGuard: `value` is a primitive YAML scalar (or None)."""
    leaf_types = (bool, _date, _datetime, int, float, str, type(None))
    return isinstance(value, leaf_types)

471 

472 

def is_yaml_list(value: Any) -> TypeGuard[List[YamlValue]]:
    """TypeGuard: `value` is a list whose items are all YAML values."""
    if not is_list(value):
        return False

    return all(is_yaml_value(item) for item in value)

475 

476 

def is_yaml_sequence(value: Any) -> TypeGuard[List[YamlValueView]]:
    """TypeGuard: `value` is a (possibly read-only) sequence of YAML values."""
    if not is_sequence(value):
        return False

    return all(is_yaml_value(item) for item in value)

479 

480 

def is_yaml_dict(value: Any) -> TypeGuard[BioimageioYamlContent]:
    """TypeGuard: `value` is a dict with str keys and YAML values."""
    if not is_dict(value):
        return False

    return all(
        isinstance(key, str) and is_yaml_value(item) for key, item in value.items()
    )

485 

486 

def is_yaml_mapping(value: Any) -> TypeGuard[BioimageioYamlContentView]:
    """TypeGuard: `value` is a mapping with str keys and read-only YAML values."""
    if not is_mapping(value):
        return False

    return all(
        isinstance(key, str) and is_yaml_value_read_only(item)
        for key, item in value.items()
    )

492 

493 

def is_yaml_value(value: Any) -> TypeGuard[YamlValue]:
    """TypeGuard: `value` is a mutable YAML value (leaf, list, or dict)."""
    # checks run in order, short-circuiting like the original `or` chain
    return any(
        check(value) for check in (is_yaml_leaf_value, is_yaml_list, is_yaml_dict)
    )

496 

497 

def is_yaml_value_read_only(value: Any) -> TypeGuard[YamlValueView]:
    """TypeGuard: `value` is a read-only YAML value (leaf, sequence, or mapping)."""
    # checks run in order, short-circuiting like the original `or` chain
    return any(
        check(value)
        for check in (is_yaml_leaf_value, is_yaml_sequence, is_yaml_mapping)
    )

502 

503 

@dataclass(frozen=True, **SLOTS)
class OpenedBioimageioYaml:
    """A parsed bioimageio YAML document together with its origin."""

    content: BioimageioYamlContent = field(repr=False)  # parsed YAML mapping
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]  # where it was loaded from
    original_file_name: FileName
    unparsed_content: str = field(repr=False)  # raw YAML text as read

510 

511 

@dataclass(frozen=True, **SLOTS)
class LocalFile:
    """A file available on the local file system, with its original origin."""

    path: FilePath
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]
    original_file_name: FileName

517 

518 

@dataclass(frozen=True, **SLOTS)
class FileInZip:
    """A file inside a zip archive, with its original origin."""

    path: ZipPath
    original_root: Union[RootHttpUrl, ZipFile]
    original_file_name: FileName

524 

525 

class HashKwargs(TypedDict):
    """Keyword arguments shared by the file-opening helpers."""

    # optional expected SHA256 of the source; `None` disables verification
    sha256: NotRequired[Optional[Sha256]]

528 

529 

# module-level adapter (built once) used by `interprete_file_source` to
# coerce permissive input into a strict `FileSource`
_file_source_adapter: TypeAdapter[Union[HttpUrl, RelativeFilePath, FilePath]] = (
    TypeAdapter(FileSource)
)

533 

534 

def interprete_file_source(file_source: PermissiveFileSource) -> FileSource:
    """Normalize a permissive file source (str, pydantic URL, path) to a
    strict `FileSource`.

    Raises:
        FileNotFoundError: if the source resolves to a directory.
    """
    if isinstance(file_source, Path):
        if file_source.is_dir():
            raise FileNotFoundError(
                f"{file_source} is a directory, but expected a file."
            )
        return file_source

    if isinstance(file_source, HttpUrl):
        return file_source

    if isinstance(file_source, pydantic.AnyUrl):
        # re-validate generic pydantic URLs through the adapter below
        file_source = str(file_source)

    # io checks are disabled here; existence is checked separately below
    with get_validation_context().replace(perform_io_checks=False):
        strict = _file_source_adapter.validate_python(file_source)
    if isinstance(strict, Path) and strict.is_dir():
        raise FileNotFoundError(f"{strict} is a directory, but expected a file.")

    return strict

555 

556 

def extract(
    source: Union[FilePath, ZipFile, ZipPath],
    folder: Optional[DirectoryPath] = None,
    overwrite: bool = False,
) -> DirectoryPath:
    """Extract a zip archive (or a single member of one) to `folder`.

    If `folder` is None it defaults to `<archive>.unzip` next to the archive
    (or a temporary path when the archive has no file name). When an existing
    extraction is found and `overwrite` is False, its content is reused; if it
    is missing expected files, extraction retries into a version-bumped folder
    name to avoid overwriting.
    """
    extract_member = None
    if isinstance(source, ZipPath):
        # extracting a single member: remember its name, work on the archive
        extract_member = source.at
        source = source.root

    if isinstance(source, ZipFile):
        # caller owns the ZipFile; do not close it here
        zip_context = nullcontext(source)
        if folder is None:
            if source.filename is None:
                folder = Path(mktemp())
            else:
                zip_path = Path(source.filename)
                folder = zip_path.with_suffix(zip_path.suffix + ".unzip")
    else:
        zip_context = ZipFile(source, "r")
        if folder is None:
            folder = source.with_suffix(source.suffix + ".unzip")

    if overwrite and folder.exists():
        warnings.warn(f"Overwriting existing unzipped archive at {folder}")

    with zip_context as f:
        if extract_member is not None:
            extracted_file_path = folder / extract_member
            if extracted_file_path.exists() and not overwrite:
                warnings.warn(f"Found unzipped {extracted_file_path}.")
            else:
                _ = f.extract(extract_member, folder)

            return folder

        elif overwrite or not folder.exists():
            f.extractall(folder)
            return folder

        # folder exists and we may not overwrite: verify it is complete
        # NOTE(review): glob("*") is non-recursive, so nested archive members
        # may appear "missing" even when present — confirm intended behavior
        found_content = {p.relative_to(folder).as_posix() for p in folder.glob("*")}
        expected_content = {info.filename for info in f.filelist}
        if expected_missing := expected_content - found_content:
            # bump a numeric version marker in the folder name, e.g.
            # "name_1.suffix" -> "name_2.suffix", or prepend "1." if absent
            parts = folder.name.split("_")
            nr, *suffixes = parts[-1].split(".")
            if nr.isdecimal():
                nr = str(int(nr) + 1)
            else:
                nr = f"1.{nr}"

            parts[-1] = ".".join([nr, *suffixes])
            out_path_new = folder.with_name("_".join(parts))
            warnings.warn(
                f"Unzipped archive at {folder} is missing expected files"
                + f" {expected_missing}."
                + f" Unzipping to {out_path_new} instead to avoid overwriting."
            )
            return extract(f, out_path_new, overwrite=overwrite)
        else:
            warnings.warn(
                f"Found unzipped archive with all expected files at {folder}."
            )
            return folder

620 

621 

def get_reader(
    source: Union[PermissiveFileSource, FileDescr, ZipPath],
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Open a file `source` (download if needed)"""
    # normalize the many accepted source types down to HttpUrl | ZipPath | Path
    if isinstance(source, FileDescr):
        if "sha256" not in kwargs:
            # an explicit sha256 kwarg takes precedence over the descriptor's
            kwargs["sha256"] = source.sha256

        source = source.source
    elif isinstance(source, str):
        source = interprete_file_source(source)

    if isinstance(source, RelativeFilePath):
        source = source.absolute()
    elif isinstance(source, pydantic.AnyUrl):
        with get_validation_context().replace(perform_io_checks=False):
            source = HttpUrl(source)

    if isinstance(source, HttpUrl):
        # remote sources are handled (with caching) by _open_url
        return _open_url(source, progressbar=progressbar, **kwargs)

    if isinstance(source, ZipPath):
        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open(mode="rb")
        assert not isinstance(f, TextIOWrapper)
        root = source.root
    elif isinstance(source, Path):
        if source.is_dir():
            raise FileNotFoundError(f"{source} is a directory, not a file")

        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open("rb")
        root = source.parent
    else:
        assert_never(source)

    # verify the local content against an expected hash, if one was given
    expected_sha = kwargs.get("sha256")
    if expected_sha is None:
        sha = None
    else:
        sha = get_sha256(f)
        _ = f.seek(0)  # rewind after hashing so callers read from the start
        if sha != expected_sha:
            raise ValueError(
                f"SHA256 mismatch for {source}. Expected {expected_sha}, got {sha}."
            )

    return BytesReader(
        f,
        sha256=sha,
        suffix=source.suffix,
        original_file_name=source.name,
        original_root=root,
        is_zipfile=None,
    )


# module-level alias for get_reader()
download = get_reader

687 

688 

def _open_url(
    source: HttpUrl,
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Fetch `source` through the configured genericache cache and wrap the
    result in a `BytesReader`.

    Raises:
        ValueError: if fetching fails for any reason (original error chained).
    """
    # bypass the disk cache entirely when the validation context disables it
    cache = (
        NoopCache[RootHttpUrl](url_hasher=UrlDigest.from_str)
        if get_validation_context().disable_cache
        else settings.disk_cache
    )
    sha = kwargs.get("sha256")
    # NOTE(review): the expected content digest is passed as `force_refetch`;
    # presumably genericache re-fetches on digest mismatch — confirm API
    digest = False if sha is None else ContentDigest.parse(hexdigest=sha)
    # derive a pseudo path for suffix/name; fall back to the sha or a hash of
    # the URL when the URL has no path component
    source_path = PurePosixPath(
        source.path
        or sha
        or hashlib.sha256(str(source).encode(encoding="utf-8")).hexdigest()
    )

    try:
        reader = cache.fetch(
            source,
            fetcher=partial(_fetch_url, progressbar=progressbar),
            force_refetch=digest,
        )
    except Exception as e:
        raise ValueError(f"Failed to fetch {source}.") from e
    else:
        return BytesReader(
            reader,
            suffix=source_path.suffix,
            sha256=sha,
            original_file_name=source_path.name,
            original_root=source.parent,
            is_zipfile=None,
        )

725 

726 

def _fetch_url(
    source: RootHttpUrl,
    *,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
):
    """Download `source` via httpx, yielding chunks of bytes.

    Returns a generator over the response body; a tqdm progress bar is shown
    unless disabled (progressbar defaults to off in CI, on otherwise).

    Raises:
        NotImplementedError: for non-http(s) schemes.
        httpx.HTTPStatusError: for non-success HTTP responses.
    """
    if source.scheme not in ("http", "https"):
        raise NotImplementedError(source.scheme)

    # a callable produces the progress bar lazily
    if callable(progressbar):
        progressbar = progressbar()

    if settings.CI:
        headers = {"User-Agent": "ci"}
        if progressbar is None:
            progressbar = False  # no progress bar by default in CI
    else:
        headers = {}
        if progressbar is None:
            progressbar = True

    if isinstance(progressbar, bool):
        # setup progressbar
        if progressbar:
            use_ascii = bool(sys.platform == "win32")
            progressbar = tqdm(
                ncols=79,
                ascii=use_ascii,
                unit="B",
                unit_scale=True,
                leave=True,
            )

    # from here on `progressbar` is either False or a progress bar object
    if progressbar is not False:
        progressbar.set_description(f"Downloading {extract_file_name(source)}")

    if settings.user_agent is not None:
        # an explicitly configured user agent overrides the CI default
        headers["User-Agent"] = settings.user_agent

    r = httpx.get(str(source), follow_redirects=True, headers=headers)
    _ = r.raise_for_status()

    # set progressbar.total from the content-length header (str -> int)
    total = r.headers.get("content-length")
    if total is not None and not isinstance(total, int):
        try:
            total = int(total)
        except Exception:
            total = None

    if progressbar is not False:
        if total is None:
            progressbar.total = 0
        else:
            progressbar.total = total

    def iter_content():
        # stream the body in 4 KiB chunks, updating the progress bar per chunk
        for chunk in r.iter_bytes(chunk_size=4096):
            yield chunk
            if progressbar is not False:
                _ = progressbar.update(len(chunk))

        # Make sure the progress bar gets filled even if the actual number
        # of chunks is smaller than expected. This happens when streaming
        # text files that are compressed by the server when sending (gzip).
        # Binary files don't experience this.
        # (adapted from pooch.HttpDownloader)
        if progressbar is not False:
            progressbar.reset()
            if total is not None:
                _ = progressbar.update(total)

            progressbar.close()

    return iter_content()

801 

802 

def extract_file_name(
    src: Union[
        pydantic.HttpUrl, RootHttpUrl, PurePath, RelativeFilePath, ZipPath, FileDescr
    ],
) -> FileName:
    """Derive a file name from a path, URL, zip member, or file description."""
    if isinstance(src, FileDescr):
        src = src.source

    if isinstance(src, ZipPath):
        return src.name or src.root.filename or "bioimageio.zip"

    if isinstance(src, RelativeFilePath):
        return src.path.name

    if isinstance(src, PurePath):
        return src.name

    parsed = urlparse(str(src))
    segments = parsed.path.split("/")
    is_zenodo_content_url = (
        parsed.scheme == "https"
        and parsed.hostname == "zenodo.org"
        and parsed.path.startswith("/api/records/")
        and parsed.path.endswith("/content")
    )
    # zenodo API content URLs end in ".../<file name>/content"
    return segments[-2] if is_zenodo_content_url else segments[-1]

828 

829 

def extract_file_descrs(data: YamlValueView):
    """Collect all `FileDescr`-like entries (mappings with both "source" and
    "sha256") found anywhere in `data`."""
    collected: List[FileDescr] = []
    # io checks and warnings are suppressed: we only want structural matches
    with get_validation_context().replace(perform_io_checks=False, log_warnings=False):
        _extract_file_descrs_impl(data, collected)

    return collected

836 

837 

def _extract_file_descrs_impl(data: YamlValueView, collected: List[FileDescr]):
    """Recursive helper for `extract_file_descrs`; appends matches to `collected`."""
    if isinstance(data, collections.abc.Mapping):
        if "source" in data and "sha256" in data:
            try:
                fd = FileDescr.model_validate(
                    dict(source=data["source"], sha256=data["sha256"])
                )
            except Exception:
                # not a valid FileDescr; ignore and keep traversing
                pass
            else:
                collected.append(fd)

        # also recurse into the values of a matching mapping
        for v in data.values():
            _extract_file_descrs_impl(v, collected)
    elif not isinstance(data, str) and isinstance(data, collections.abc.Sequence):
        # strings are leaves, not sequences to recurse into
        for v in data:
            _extract_file_descrs_impl(v, collected)

855 

856 

# preserves the concrete input type of file-source validators (see `validate_suffix`)
F = TypeVar("F", bound=Union[FileSource, FileDescr])

858 

859 

def validate_suffix(
    value: F, suffix: Union[str, Sequence[str]], case_sensitive: bool
) -> F:
    """check final suffix

    Returns the original `value` unchanged if its final suffix is one of
    `suffix` (a single suffix or a sequence of allowed suffixes, each starting
    with "."); raises `ValueError` otherwise.
    """
    if isinstance(suffix, str):
        suffixes = [suffix]
    else:
        suffixes = suffix

    assert len(suffixes) > 0, "no suffix given"
    assert all(
        suff.startswith(".") for suff in suffixes
    ), "expected suffixes to start with '.'"
    o_value = value
    # normalize to a strict source to inspect its suffix
    if isinstance(value, FileDescr):
        strict = value.source
    else:
        strict = interprete_file_source(value)

    if isinstance(strict, (HttpUrl, AnyUrl)):
        if strict.path is None or "." not in (path := strict.path):
            actual_suffixes = []
        else:
            if (
                strict.host == "zenodo.org"
                and path.startswith("/api/records/")
                and path.endswith("/content")
            ):
                # Zenodo API URLs have a "/content" suffix that should be ignored
                path = path[: -len("/content")]

            actual_suffixes = [f".{path.split('.')[-1]}"]

    elif isinstance(strict, PurePath):
        actual_suffixes = strict.suffixes
    elif isinstance(strict, RelativeFilePath):
        actual_suffixes = strict.path.suffixes
    else:
        assert_never(strict)

    if actual_suffixes:
        actual_suffix = actual_suffixes[-1]
    else:
        actual_suffix = "no suffix"

    # note operator precedence: reads as
    # (case_sensitive and mismatch) or (not case_sensitive and ci_mismatch)
    if (
        case_sensitive
        and actual_suffix not in suffixes
        or not case_sensitive
        and actual_suffix.lower() not in [s.lower() for s in suffixes]
    ):
        if len(suffixes) == 1:
            raise ValueError(f"Expected suffix {suffixes[0]}, but got {actual_suffix}")
        else:
            raise ValueError(
                f"Expected a suffix from {suffixes}, but got {actual_suffix}"
            )

    return o_value

919 

920 

def populate_cache(sources: Sequence[Union[FileDescr, LightHttpFileDescr]]):
    """Download each remote source with a known sha256 once to warm the cache.

    Local paths, sources without a sha256, and duplicate URLs are skipped.
    """
    unique: Set[str] = set()
    for src in sources:
        if src.sha256 is None:
            continue  # not caching without known SHA

        if isinstance(src.source, (HttpUrl, pydantic.AnyUrl)):
            url = str(src.source)
        elif isinstance(src.source, RelativeFilePath):
            # only cache relative paths that resolve to URLs
            if isinstance(absolute := src.source.absolute(), HttpUrl):
                url = str(absolute)
            else:
                continue  # not caching local paths
        elif isinstance(src.source, Path):
            continue  # not caching local paths
        else:
            assert_never(src.source)

        if url in unique:
            continue  # skip duplicate URLs

        unique.add(url)
        _ = src.download()