Coverage for src / bioimageio / spec / _internal / io.py: 79%

453 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-17 16:08 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import hashlib 

5import sys 

6import warnings 

7import zipfile 

8from abc import abstractmethod 

9from contextlib import nullcontext 

10from dataclasses import dataclass, field 

11from datetime import date as _date 

12from datetime import datetime as _datetime 

13from functools import partial 

14from io import TextIOWrapper 

15from pathlib import Path, PurePath, PurePosixPath 

16from tempfile import mkdtemp 

17from typing import ( 

18 TYPE_CHECKING, 

19 Any, 

20 Callable, 

21 Dict, 

22 Generic, 

23 Iterable, 

24 List, 

25 Mapping, 

26 Optional, 

27 Sequence, 

28 Set, 

29 Tuple, 

30 Type, 

31 TypedDict, 

32 TypeVar, 

33 Union, 

34 overload, 

35) 

36from urllib.parse import urlparse, urlsplit, urlunsplit 

37from zipfile import ZipFile 

38 

39import httpx 

40import pydantic 

41from genericache import NoopCache 

42from genericache.digest import ContentDigest, UrlDigest 

43from pydantic import ( 

44 AnyUrl, 

45 DirectoryPath, 

46 Field, 

47 GetCoreSchemaHandler, 

48 PrivateAttr, 

49 RootModel, 

50 TypeAdapter, 

51 model_serializer, 

52 model_validator, 

53) 

54from pydantic_core import core_schema 

55from tqdm import tqdm 

56from typing_extensions import ( 

57 Annotated, 

58 LiteralString, 

59 NotRequired, 

60 Self, 

61 TypeGuard, 

62 Unpack, 

63 assert_never, 

64) 

65from typing_extensions import TypeAliasType as _TypeAliasType 

66 

67from ._settings import settings 

68from .io_basics import ( 

69 ALL_BIOIMAGEIO_YAML_NAMES, 

70 ALTERNATIVE_BIOIMAGEIO_YAML_NAMES, 

71 BIOIMAGEIO_YAML, 

72 AbsoluteDirectory, 

73 AbsoluteFilePath, 

74 BytesReader, 

75 FileName, 

76 FilePath, 

77 Sha256, 

78 ZipPath, 

79 get_sha256, 

80) 

81from .node import Node 

82from .progress import Progressbar 

83from .root_url import RootHttpUrl 

84from .type_guards import is_dict, is_list, is_mapping, is_sequence 

85from .url import HttpUrl 

86from .utils import SLOTS 

87from .validation_context import get_validation_context 

88 

# TypeVar for the absolute counterpart of a relative path: an http url,
# an absolute local directory/file path, or a path inside a zip archive.
AbsolutePathT = TypeVar(
    "AbsolutePathT",
    bound=Union[HttpUrl, AbsoluteDirectory, AbsoluteFilePath, ZipPath],
)

94 

class LightHttpFileDescr(Node):
    """http source with sha256 value (minimal validation)"""

    source: pydantic.HttpUrl
    """file source"""

    sha256: Sha256
    """SHA256 checksum of the source file"""

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ) -> BytesReader:
        """open the file source (download if needed)"""
        # delegates to the module-level `get_reader` defined below
        return get_reader(self.source, sha256=self.sha256, progressbar=progressbar)

    download = get_reader
    """alias for get_reader() method"""

115 

class RelativePathBase(RootModel[PurePath], Generic[AbsolutePathT], frozen=True):
    """Base class for relative paths that are resolved to an absolute
    path/url at creation time (against the active ValidationContext root)."""

    # resolved absolute counterpart; assigned in `model_post_init`
    _absolute: AbsolutePathT = PrivateAttr()

    @property
    def path(self) -> PurePath:
        """the wrapped (relative) path"""
        return self.root

    def absolute(  # method not property analog to `pathlib.Path.absolute()`
        self,
    ) -> AbsolutePathT:
        """get the absolute path/url

        (resolved at time of initialization with the root of the ValidationContext)
        """
        return self._absolute

    def model_post_init(self, __context: Any) -> None:
        """set `_absolute` property with validation context at creation time. @private"""
        if self.root.is_absolute():
            raise ValueError(f"{self.root} is an absolute path.")

        # guard against urls mis-parsed as paths, e.g. PurePath("http://x/y")
        if self.root.parts and self.root.parts[0] in ("http:", "https:"):
            raise ValueError(f"{self.root} looks like an http url.")

        self._absolute = (  # pyright: ignore[reportAttributeAccessIssue]
            self.get_absolute(get_validation_context().root)
        )
        super().model_post_init(__context)

    def __str__(self) -> str:
        return self.root.as_posix()

    def __repr__(self) -> str:
        return f"RelativePath('{self}')"

    @model_serializer()
    def format(self) -> str:
        # serialize as a posix-style string
        return str(self)

    @abstractmethod
    def get_absolute(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> AbsolutePathT: ...

    def _get_absolute_impl(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> Union[Path, HttpUrl, ZipPath]:
        # local directory root: plain path join
        if isinstance(root, Path):
            return (root / self.root).absolute()

        rel_path = self.root.as_posix().strip("/")
        if isinstance(root, ZipFile):
            return ZipPath(root, rel_path)

        # url root: splice the relative path into the url's path component
        parsed = urlsplit(str(root))
        path = list(parsed.path.strip("/").split("/"))
        if (
            parsed.netloc == "zenodo.org"
            and parsed.path.startswith("/api/records/")
            and parsed.path.endswith("/content")
        ):
            # zenodo api urls end in "/content"; insert before that suffix
            path.insert(-1, rel_path)
        else:
            path.append(rel_path)

        return HttpUrl(
            urlunsplit(
                (
                    parsed.scheme,
                    parsed.netloc,
                    "/".join(path),
                    parsed.query,
                    parsed.fragment,
                )
            )
        )

    @classmethod
    def _validate(cls, value: Union[PurePath, str]):
        # url strings must not validate as relative paths
        if isinstance(value, str) and (
            value.startswith("https://") or value.startswith("http://")
        ):
            raise ValueError(f"{value} looks like a URL, not a relative path")

        return cls(PurePath(value))

202 

class RelativeFilePath(
    RelativePathBase[Union[AbsoluteFilePath, HttpUrl, ZipPath]], frozen=True
):
    """A path relative to the `rdf.yaml` file (also if the RDF source is a URL)."""

    def model_post_init(self, __context: Any) -> None:
        """add validation @private"""
        if not self.root.parts:  # an empty path can only be a directory
            raise ValueError(f"{self.root} is not a valid file path.")

        super().model_post_init(__context)

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteFilePath | HttpUrl | ZipPath":
        """resolve against `root`; local results are checked for existence
        when io checks are active and the file is not already known"""
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and (context := get_validation_context()).perform_io_checks
            and str(self.root) not in context.known_files
            and not absolute.is_file()
        ):
            raise ValueError(f"{absolute} does not point to an existing file")

        return absolute

    @property
    def suffix(self) -> str:
        """the file extension of the wrapped relative path"""
        return self.root.suffix

233 

class RelativeDirectory(
    RelativePathBase[Union[AbsoluteDirectory, HttpUrl, ZipPath]], frozen=True
):
    """A directory path relative to the validation context root."""

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteDirectory | HttpUrl | ZipPath":
        """resolve against `root`; local results are checked for existence
        when io checks are active"""
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and get_validation_context().perform_io_checks
            and not absolute.is_dir()
        ):
            raise ValueError(f"{absolute} does not point to an existing directory")

        return absolute

250 

# A strict file source. Union order matters: HttpUrl is tried first, then
# RelativeFilePath, then FilePath (pydantic's left_to_right union mode).
FileSource = Annotated[
    Union[HttpUrl, RelativeFilePath, FilePath],
    Field(union_mode="left_to_right"),
]
# Permissive variant additionally accepting plain strings and pydantic urls;
# normalized via `interprete_file_source`.
PermissiveFileSource = Union[FileSource, str, pydantic.HttpUrl]

256 

257 

class FileDescr(Node):
    """A file description"""

    source: FileSource
    """File source"""

    sha256: Optional[Sha256] = None
    """SHA256 hash value of the **source** file."""

    @model_validator(mode="after")
    def _validate_sha256(self) -> Self:
        # only touch the file system / network when io checks are enabled
        if get_validation_context().perform_io_checks:
            self.validate_sha256()

        return self

    def validate_sha256(self, force_recompute: bool = False) -> None:
        """validate the sha256 hash value of the **source** file

        Updates `self.sha256` when it is unset (or when the context requests
        hash updates); raises `ValueError` on a mismatch.
        """
        context = get_validation_context()
        src_str = str(self.source)
        if not force_recompute and src_str in context.known_files:
            # reuse a hash computed earlier in this validation context
            actual_sha = context.known_files[src_str]
        else:
            reader = get_reader(self.source, sha256=self.sha256)
            if force_recompute:
                actual_sha = get_sha256(reader)
            else:
                # reader.sha256 may be None if no expectation was given
                actual_sha = reader.sha256

            context.known_files[src_str] = actual_sha

        if actual_sha is None:
            return  # hash could not be determined; nothing to compare
        elif self.sha256 == actual_sha:
            pass  # expected hash matches
        elif self.sha256 is None or context.update_hashes:
            self.sha256 = actual_sha  # adopt the computed hash
        elif self.sha256 != actual_sha:
            raise ValueError(
                f"Sha256 mismatch for {self.source}. Expected {self.sha256}, got "
                + f"{actual_sha}. Update expected `sha256` or point to the matching "
                + "file."
            )

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ):
        """open the file source (download if needed)"""
        return get_reader(self.source, progressbar=progressbar, sha256=self.sha256)

    def download(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ):
        """alias for `.get_reader`"""
        return get_reader(self.source, progressbar=progressbar, sha256=self.sha256)

    @property
    def suffix(self) -> str:
        """the file extension of the source"""
        return self.source.suffix

322 

# reusable adapter to validate a file path, directory path, or http url
path_or_url_adapter: "TypeAdapter[Union[FilePath, DirectoryPath, HttpUrl]]" = (
    TypeAdapter(Union[FilePath, DirectoryPath, HttpUrl])
)

326 

327 

@dataclass(frozen=True, **SLOTS)
class WithSuffix:
    """annotation metadata restricting a file source to the given suffix(es)"""

    # a single suffix or a tuple of allowed suffixes (each starting with ".")
    suffix: Union[LiteralString, Tuple[LiteralString, ...]]
    case_sensitive: bool

    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ):
        if not self.suffix:
            raise ValueError("suffix may not be empty")

        schema = handler(source)
        # run the suffix check after the regular validation of `source`
        return core_schema.no_info_after_validator_function(
            self.validate,
            schema,
        )

    def validate(
        self, value: Union[FileSource, FileDescr]
    ) -> Union[FileSource, FileDescr]:
        """check that `value` carries one of the configured suffixes"""
        return validate_suffix(value, self.suffix, case_sensitive=self.case_sensitive)

350 

def wo_special_file_name(src: F) -> F:
    """Return `src` unchanged unless its filename is a reserved bioimageio YAML name."""
    if not has_valid_bioimageio_yaml_name(src):
        return src

    raise ValueError(
        f"'{src}' not allowed here as its filename is reserved to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )

360 

def has_valid_bioimageio_yaml_name(src: Union[FileSource, FileDescr]) -> bool:
    """Return True if the file name of `src` identifies a bioimageio YAML file."""
    file_name = extract_file_name(src)
    return is_valid_bioimageio_yaml_name(file_name)

364 

def is_valid_bioimageio_yaml_name(file_name: FileName) -> bool:
    """Return True if `file_name` equals, or ends with "." plus, a known bioimageio YAML name."""
    return any(
        file_name == candidate or file_name.endswith("." + candidate)
        for candidate in ALL_BIOIMAGEIO_YAML_NAMES
    )

371 

372 

def identify_bioimageio_yaml_file_name(file_names: Iterable[FileName]) -> FileName:
    """Return the first file name in `file_names` that identifies a bioimageio YAML file.

    Candidate names are checked in the priority order of
    `ALL_BIOIMAGEIO_YAML_NAMES`; a file matches if it equals a candidate
    or ends with "." + candidate.

    Raises:
        ValueError: if no file name matches.
    """
    file_names = sorted(file_names)  # deterministic pick among equally valid names
    for bioimageio_name in ALL_BIOIMAGEIO_YAML_NAMES:
        for file_name in file_names:
            if file_name == bioimageio_name or file_name.endswith(
                "." + bioimageio_name
            ):
                return file_name

    # fixed typos in the original message: "or or" and "alterntive"
    raise ValueError(
        f"No {BIOIMAGEIO_YAML} found in {file_names}. (Looking for '{BIOIMAGEIO_YAML}'"
        + " or any of the alternative file names:"
        + f" {ALTERNATIVE_BIOIMAGEIO_YAML_NAMES}, or any file with an extension of"
        + f" those, e.g. 'anything.{BIOIMAGEIO_YAML}')."
    )

388 

389 

def find_bioimageio_yaml_file_name(path: Union[Path, ZipFile]) -> FileName:
    """Locate the bioimageio YAML file name in a directory, zip archive, or single file."""
    if isinstance(path, ZipFile):
        candidates = path.namelist()
    elif not path.is_file():
        # directory: consider its direct children
        candidates = [p.name for p in path.glob("*")]
    elif not zipfile.is_zipfile(path):
        # a plain file is its own candidate
        return path.name
    else:
        with ZipFile(path, "r") as zf:
            candidates = zf.namelist()

    # TODO: try/except with better error message for dir
    return identify_bioimageio_yaml_file_name(candidates)

405 

406 

def ensure_has_valid_bioimageio_yaml_name(src: FileSource) -> FileSource:
    """Return `src` if its filename identifies a bioimageio YAML file; raise otherwise."""
    if has_valid_bioimageio_yaml_name(src):
        return src

    raise ValueError(
        f"'{src}' does not have a valid filename to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )

415 

416 

def ensure_is_valid_bioimageio_yaml_name(file_name: FileName) -> FileName:
    """Return `file_name` if it identifies a bioimageio YAML file; raise otherwise."""
    if is_valid_bioimageio_yaml_name(file_name):
        return file_name

    raise ValueError(
        f"'{file_name}' is not a valid filename to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )

425 

426 

# types as loaded from YAML 1.2 (with ruyaml)
YamlLeafValue = Union[
    bool, _date, _datetime, int, float, str, None
]  # note: order relevant for deserializing
YamlKey = Union[  # YAML Arrays are cast to tuples if used as key in mappings
    YamlLeafValue, Tuple[YamlLeafValue, ...]  # (nesting is not allowed though)
]
if TYPE_CHECKING:
    # plain recursive aliases suffice for static type checking
    YamlValue = Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]]
    YamlValueView = Union[
        YamlLeafValue, Sequence["YamlValueView"], Mapping[YamlKey, "YamlValueView"]
    ]
else:
    # for pydantic validation we need to use `TypeAliasType`,
    # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
    # however this results in a partially unknown type with the current pyright 1.1.388
    YamlValue = _TypeAliasType(
        "YamlValue",
        Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]],
    )
    YamlValueView = _TypeAliasType(
        "YamlValueView",
        Union[
            YamlLeafValue,
            Sequence["YamlValueView"],
            Mapping[YamlKey, "YamlValueView"],
        ],
    )

# a bioimageio.yaml document: mutable dict vs read-only mapping view
BioimageioYamlContent = Dict[str, YamlValue]
BioimageioYamlContentView = Mapping[str, YamlValueView]
BioimageioYamlSource = Union[
    PermissiveFileSource, ZipFile, BioimageioYamlContent, BioimageioYamlContentView
]

461 

462 

@overload
def deepcopy_yaml_value(value: BioimageioYamlContentView) -> BioimageioYamlContent: ...


@overload
def deepcopy_yaml_value(value: YamlValueView) -> YamlValue: ...


def deepcopy_yaml_value(
    value: Union[BioimageioYamlContentView, YamlValueView],
) -> Union[BioimageioYamlContent, YamlValue]:
    """Deep-copy a YAML value, converting mappings to dicts and sequences to lists."""
    # strings are sequences, so they must be handled before the Sequence branch
    if isinstance(value, str):
        return value

    if isinstance(value, collections.abc.Mapping):
        return {k: deepcopy_yaml_value(v) for k, v in value.items()}

    if isinstance(value, collections.abc.Sequence):
        return [deepcopy_yaml_value(item) for item in value]

    return value

482 

483 

def is_yaml_leaf_value(value: Any) -> TypeGuard[YamlLeafValue]:
    """Return True if `value` is a YAML scalar (bool, date, datetime, int, float, str, or None)."""
    leaf_types = (bool, _date, _datetime, int, float, str, type(None))
    return isinstance(value, leaf_types)

486 

487 

def is_yaml_list(value: Any) -> TypeGuard[List[YamlValue]]:
    """Return True if `value` is a list whose items are all YAML values."""
    if not is_list(value):
        return False

    return all(is_yaml_value(item) for item in value)

490 

491 

def is_yaml_sequence(value: Any) -> TypeGuard[List[YamlValueView]]:
    """Return True if `value` is a sequence whose items are all YAML values."""
    if not is_sequence(value):
        return False

    return all(is_yaml_value(item) for item in value)

494 

495 

def is_yaml_dict(value: Any) -> TypeGuard[BioimageioYamlContent]:
    """Return True if `value` is a dict with str keys and YAML values."""
    if not is_dict(value):
        return False

    return all(
        isinstance(key, str) and is_yaml_value(val) for key, val in value.items()
    )

500 

501 

def is_yaml_mapping(value: Any) -> TypeGuard[BioimageioYamlContentView]:
    """Return True if `value` is a mapping with str keys and read-only YAML values."""
    if not is_mapping(value):
        return False

    return all(
        isinstance(key, str) and is_yaml_value_read_only(val)
        for key, val in value.items()
    )

507 

508 

def is_yaml_value(value: Any) -> TypeGuard[YamlValue]:
    """Return True if `value` is a YAML leaf, list, or dict."""
    if is_yaml_leaf_value(value):
        return True

    return is_yaml_list(value) or is_yaml_dict(value)

511 

512 

def is_yaml_value_read_only(value: Any) -> TypeGuard[YamlValueView]:
    """Return True if `value` is a YAML leaf, sequence, or mapping (read-only view)."""
    if is_yaml_leaf_value(value):
        return True

    return is_yaml_sequence(value) or is_yaml_mapping(value)

517 

518 

@dataclass(frozen=True, **SLOTS)
class OpenedBioimageioYaml:
    """the parsed content of a bioimageio YAML file plus its provenance"""

    content: BioimageioYamlContent = field(repr=False)  # parsed YAML content
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]  # where the file was loaded from
    original_source_name: Optional[str]
    original_file_name: FileName
    unparsed_content: str = field(repr=False)  # raw YAML text as read

526 

527 

@dataclass(frozen=True, **SLOTS)
class LocalFile:
    """a local file path together with its original provenance"""

    path: FilePath
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]
    original_file_name: FileName

533 

534 

@dataclass(frozen=True, **SLOTS)
class FileInZip:
    """a file inside a zip archive together with its original provenance"""

    path: ZipPath
    original_root: Union[RootHttpUrl, ZipFile]
    original_file_name: FileName

540 

541 

class HashKwargs(TypedDict):
    """optional keyword arguments carrying an expected sha256 for a file source"""

    sha256: NotRequired[Optional[Sha256]]

544 

545 

# adapter used by `interprete_file_source` to validate permissive inputs
_file_source_adapter: TypeAdapter[Union[HttpUrl, RelativeFilePath, FilePath]] = (
    TypeAdapter(FileSource)
)

549 

550 

def interprete_file_source(file_source: PermissiveFileSource) -> FileSource:
    """Normalize a permissive file source (str / pydantic url / path) to a strict `FileSource`.

    Raises:
        FileNotFoundError: if `file_source` points to a directory instead of a file.
    """
    if isinstance(file_source, Path):
        if file_source.is_dir():
            raise FileNotFoundError(
                f"{file_source} is a directory, but expected a file."
            )
        return file_source

    if isinstance(file_source, HttpUrl):
        return file_source

    if isinstance(file_source, pydantic.AnyUrl):
        # re-validate below as HttpUrl/RelativeFilePath/FilePath via the adapter
        file_source = str(file_source)

    # validate without io checks; a local dir is still rejected explicitly
    with get_validation_context().replace(perform_io_checks=False):
        strict = _file_source_adapter.validate_python(file_source)
        if isinstance(strict, Path) and strict.is_dir():
            raise FileNotFoundError(f"{strict} is a directory, but expected a file.")

    return strict

571 

572 

def extract(
    source: Union[FilePath, ZipFile, ZipPath],
    folder: Optional[DirectoryPath] = None,
    overwrite: bool = False,
) -> DirectoryPath:
    """Extract a zip archive (or a single member for a `ZipPath`) to `folder`.

    If `folder` is None, a sibling folder of the archive with an added
    ".unzip" suffix is used (or a temporary directory if the archive has no
    file name). Returns the folder containing the extracted content.
    """
    extract_member = None
    if isinstance(source, ZipPath):
        # remember the member to extract, then continue with the enclosing archive
        extract_member = source.at
        source = source.root

    if isinstance(source, ZipFile):
        zip_context = nullcontext(source)  # do not close a caller-provided ZipFile
        if folder is None:
            if source.filename is None:
                # archive without a backing file: extract to a temp dir
                folder = Path(mkdtemp())
            else:
                zip_path = Path(source.filename)
                folder = zip_path.with_suffix(zip_path.suffix + ".unzip")
    else:
        zip_context = ZipFile(source, "r")
        if folder is None:
            folder = source.with_suffix(source.suffix + ".unzip")

    if overwrite and folder.exists():
        warnings.warn(f"Overwriting existing unzipped archive at {folder}")

    with zip_context as f:
        if extract_member is not None:
            # single-member extraction
            extracted_file_path = folder / extract_member
            if extracted_file_path.exists() and not overwrite:
                warnings.warn(f"Found unzipped {extracted_file_path}.")
            else:
                _ = f.extract(extract_member, folder)

            return folder

        elif overwrite or not folder.exists():
            f.extractall(folder)
            return folder

        # folder already exists and overwrite is False:
        # verify it contains everything the archive would provide
        # NOTE(review): folder.glob("*") is non-recursive while f.filelist
        # includes nested member paths — confirm intended for flat archives
        found_content = {p.relative_to(folder).as_posix() for p in folder.glob("*")}
        expected_content = {info.filename for info in f.filelist}
        if expected_missing := expected_content - found_content:
            # pick a fresh folder name by incrementing a trailing counter
            parts = folder.name.split("_")
            nr, *suffixes = parts[-1].split(".")
            if nr.isdecimal():
                nr = str(int(nr) + 1)
            else:
                nr = f"1.{nr}"

            parts[-1] = ".".join([nr, *suffixes])
            out_path_new = folder.with_name("_".join(parts))
            warnings.warn(
                f"Unzipped archive at {folder} is missing expected files"
                + f" {expected_missing}."
                + f" Unzipping to {out_path_new} instead to avoid overwriting."
            )
            return extract(f, out_path_new, overwrite=overwrite)
        else:
            warnings.warn(
                f"Found unzipped archive with all expected files at {folder}."
            )
            return folder

636 

637 

def get_reader(
    source: Union[PermissiveFileSource, FileDescr, ZipPath],
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Open a file `source` (download if needed)"""
    if isinstance(source, FileDescr):
        # an explicitly passed sha256 takes precedence over the descriptor's
        if "sha256" not in kwargs:
            kwargs["sha256"] = source.sha256

        source = source.source
    elif isinstance(source, str):
        source = interprete_file_source(source)

    # resolve to an absolute location
    if isinstance(source, RelativeFilePath):
        source = source.absolute()
    elif isinstance(source, pydantic.AnyUrl):
        with get_validation_context().replace(perform_io_checks=False):
            source = HttpUrl(source)

    if isinstance(source, HttpUrl):
        return _open_url(source, progressbar=progressbar, **kwargs)

    if isinstance(source, ZipPath):
        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open(mode="rb")
        assert not isinstance(f, TextIOWrapper)  # mode="rb" yields a binary stream
        root = source.root
    elif isinstance(source, Path):
        if source.is_dir():
            raise FileNotFoundError(f"{source} is a directory, not a file")

        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open("rb")
        root = source.parent
    else:
        assert_never(source)

    expected_sha = kwargs.get("sha256")
    if expected_sha is None:
        sha = None  # no expectation given: skip hashing
    else:
        sha = get_sha256(f)
        _ = f.seek(0)  # rewind after hashing so callers read from the start
        if sha != expected_sha:
            raise ValueError(
                f"SHA256 mismatch for {source}. Expected {expected_sha}, got {sha}."
            )

    return BytesReader(
        f,
        sha256=sha,
        suffix=source.suffix,
        original_file_name=source.name,
        original_root=root,
        is_zipfile=None,
    )

700 

701 

# public alias: `download(...)` behaves exactly like `get_reader(...)`
download = get_reader

703 

704 

def _open_url(
    source: HttpUrl,
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None],
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Fetch `source` through the download cache and wrap the result in a `BytesReader`."""
    cache = (
        # bypass the disk cache entirely when the validation context disables it
        NoopCache[RootHttpUrl](url_hasher=UrlDigest.from_str)
        if get_validation_context().disable_cache
        else settings.disk_cache
    )
    sha = kwargs.get("sha256")
    # without a known hash a cached copy cannot be verified -> always refetch;
    # otherwise pass the expected content digest to the cache
    force_refetch = True if sha is None else ContentDigest.parse(hexdigest=sha)
    # derive a pseudo file path for suffix/name extraction; fall back to the
    # sha or a hash of the url when the url has no path component
    source_path = PurePosixPath(
        source.path
        or sha
        or hashlib.sha256(str(source).encode(encoding="utf-8")).hexdigest()
    )

    reader = cache.fetch(
        source,
        fetcher=partial(_fetch_url, progressbar=progressbar),
        force_refetch=force_refetch,
    )
    return BytesReader(
        reader,
        suffix=source_path.suffix,
        sha256=sha,
        original_file_name=source_path.name,
        original_root=source.parent,
        is_zipfile=None,
    )

737 

738 

def _fetch_url(
    source: RootHttpUrl,
    *,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None],
):
    """Download `source` via http(s) and return an iterator over the response's byte chunks.

    Raises:
        NotImplementedError: for url schemes other than http/https.
    """
    if source.scheme not in ("http", "https"):
        raise NotImplementedError(source.scheme)

    if progressbar is None:
        # chose progressbar option from validation context
        progressbar = get_validation_context().progressbar

    if progressbar is None:
        # default to no progressbar in CI environments
        progressbar = not settings.CI

    if callable(progressbar):
        progressbar = progressbar()  # factory -> instance

    if isinstance(progressbar, bool) and progressbar:
        # create a default tqdm bar
        progressbar = tqdm(
            ncols=79,
            ascii=bool(sys.platform == "win32"),
            unit="B",
            unit_scale=True,
            leave=True,
        )

    if progressbar is not False:
        progressbar.set_description(f"Downloading {extract_file_name(source)}")

    headers: Dict[str, str] = {}
    if settings.user_agent is not None:
        headers["User-Agent"] = settings.user_agent
    elif settings.CI:
        headers["User-Agent"] = "ci"

    r = httpx.get(
        str(source),
        follow_redirects=True,
        headers=headers,
        timeout=settings.http_timeout,
    )
    _ = r.raise_for_status()

    # set progressbar.total
    total = r.headers.get("content-length")
    if total is not None and not isinstance(total, int):
        try:
            total = int(total)
        except Exception:
            total = None

    if progressbar is not False:
        if total is None:
            progressbar.total = 0
        else:
            progressbar.total = total

    def iter_content():
        for chunk in r.iter_bytes(chunk_size=4096):
            yield chunk
            if progressbar is not False:
                _ = progressbar.update(len(chunk))

        # Make sure the progress bar gets filled even if the actual number
        # is chunks is smaller than expected. This happens when streaming
        # text files that are compressed by the server when sending (gzip).
        # Binary files don't experience this.
        # (adapted from pooch.HttpDownloader)
        if progressbar is not False:
            progressbar.reset()
            if total is not None:
                _ = progressbar.update(total)

            progressbar.close()

    return iter_content()

817 

818 

def extract_file_name(
    src: Union[
        pydantic.HttpUrl, RootHttpUrl, PurePath, RelativeFilePath, ZipPath, FileDescr
    ],
) -> FileName:
    """Derive a file name from a url, path, zip member, or file descriptor."""
    if isinstance(src, FileDescr):
        src = src.source

    if isinstance(src, ZipPath):
        # fall back to the archive's file name (or a generic default) for the zip root
        return src.name or src.root.filename or "bioimageio.zip"
    elif isinstance(src, RelativeFilePath):
        return src.path.name
    elif isinstance(src, PurePath):
        return src.name
    else:
        url = urlparse(str(src))
        if (
            url.scheme == "https"
            and url.hostname == "zenodo.org"
            and url.path.startswith("/api/records/")
            and url.path.endswith("/content")
        ):
            # zenodo api urls end in ".../<file name>/content"
            return url.path.split("/")[-2]
        else:
            return url.path.split("/")[-1]

844 

845 

def extract_file_descrs(data: YamlValueView):
    """Collect all `FileDescr`-like entries (mappings with "source" and "sha256") found in `data`."""
    collected: List[FileDescr] = []
    # disable io checks/warnings: we only want structural matches here
    with get_validation_context().replace(perform_io_checks=False, log_warnings=False):
        _extract_file_descrs_impl(data, collected)

    return collected

852 

853 

def _extract_file_descrs_impl(data: YamlValueView, collected: List[FileDescr]):
    """Recursively walk `data`, appending every valid file description to `collected`."""
    if isinstance(data, collections.abc.Mapping):
        if "source" in data and "sha256" in data:
            try:
                fd = FileDescr.model_validate(
                    dict(source=data["source"], sha256=data["sha256"])
                )
            except Exception:
                pass  # not a valid FileDescr; keep walking
            else:
                collected.append(fd)

        for v in data.values():
            _extract_file_descrs_impl(v, collected)
    elif not isinstance(data, str) and isinstance(data, collections.abc.Sequence):
        # strings are sequences, but must not be recursed into
        for v in data:
            _extract_file_descrs_impl(v, collected)

871 

872 

# TypeVar covering anything accepted as a (described) file source
F = TypeVar("F", bound=Union[FileSource, FileDescr])

874 

875 

def validate_suffix(
    value: F, suffix: Union[str, Sequence[str]], case_sensitive: bool
) -> F:
    """check final suffix

    Returns `value` unchanged if its final suffix is among `suffix`;
    raises `ValueError` otherwise.
    """
    if isinstance(suffix, str):
        suffixes = [suffix]
    else:
        suffixes = suffix

    assert len(suffixes) > 0, "no suffix given"
    assert all(suff.startswith(".") for suff in suffixes), (
        "expected suffixes to start with '.'"
    )
    o_value = value  # the original input is returned on success
    if isinstance(value, FileDescr):
        strict = value.source
    else:
        strict = interprete_file_source(value)

    if isinstance(strict, (HttpUrl, AnyUrl)):
        if strict.path is None or "." not in (path := strict.path):
            actual_suffixes = []
        else:
            if (
                strict.host == "zenodo.org"
                and path.startswith("/api/records/")
                and path.endswith("/content")
            ):
                # Zenodo API URLs have a "/content" suffix that should be ignored
                path = path[: -len("/content")]

            actual_suffixes = [f".{path.split('.')[-1]}"]

    elif isinstance(strict, PurePath):
        actual_suffixes = strict.suffixes
    elif isinstance(strict, RelativeFilePath):
        actual_suffixes = strict.path.suffixes
    else:
        assert_never(strict)

    if actual_suffixes:
        actual_suffix = actual_suffixes[-1]
    else:
        actual_suffix = "no suffix"

    if (
        case_sensitive
        and actual_suffix not in suffixes
        or not case_sensitive
        and actual_suffix.lower() not in [s.lower() for s in suffixes]
    ):
        if len(suffixes) == 1:
            raise ValueError(f"Expected suffix {suffixes[0]}, but got {actual_suffix}")
        else:
            raise ValueError(
                f"Expected a suffix from {suffixes}, but got {actual_suffix}"
            )

    return o_value

935 

936 

def populate_cache(sources: Sequence[Union[FileDescr, LightHttpFileDescr]]):
    """Download (and thereby cache) every remote source in `sources` that has a known sha256.

    Local paths and duplicate urls are skipped.
    """
    unique: Set[str] = set()
    for src in sources:
        if src.sha256 is None:
            continue  # not caching without known SHA

        if isinstance(src.source, (HttpUrl, pydantic.AnyUrl)):
            url = str(src.source)
        elif isinstance(src.source, RelativeFilePath):
            # a relative path may resolve to a url (when the RDF root is remote)
            if isinstance(absolute := src.source.absolute(), HttpUrl):
                url = str(absolute)
            else:
                continue  # not caching local paths
        elif isinstance(src.source, Path):
            continue  # not caching local paths
        else:
            assert_never(src.source)

        if url in unique:
            continue  # skip duplicate URLs

        unique.add(url)
        _ = src.download()