Coverage for src/bioimageio/spec/_internal/io.py: 79%
447 statements
coverage.py v7.12.0, created at 2025-12-08 13:04 +0000

from __future__ import annotations

import collections.abc
import hashlib
import sys
import warnings
import zipfile
from abc import abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from datetime import date as _date
from datetime import datetime as _datetime
from functools import partial
from io import TextIOWrapper
from pathlib import Path, PurePath, PurePosixPath
from tempfile import mkdtemp
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypedDict,
    TypeVar,
    Union,
    overload,
)
from urllib.parse import urlparse, urlsplit, urlunsplit
from zipfile import ZipFile

import httpx
import pydantic
from genericache import NoopCache
from genericache.digest import ContentDigest, UrlDigest
from pydantic import (
    AnyUrl,
    DirectoryPath,
    Field,
    GetCoreSchemaHandler,
    PrivateAttr,
    RootModel,
    TypeAdapter,
    model_serializer,
    model_validator,
)
from pydantic_core import core_schema
from tqdm import tqdm
from typing_extensions import (
    Annotated,
    LiteralString,
    NotRequired,
    Self,
    TypeGuard,
    Unpack,
    assert_never,
)
from typing_extensions import TypeAliasType as _TypeAliasType

from ._settings import settings
from .io_basics import (
    ALL_BIOIMAGEIO_YAML_NAMES,
    ALTERNATIVE_BIOIMAGEIO_YAML_NAMES,
    BIOIMAGEIO_YAML,
    AbsoluteDirectory,
    AbsoluteFilePath,
    BytesReader,
    FileName,
    FilePath,
    Sha256,
    ZipPath,
    get_sha256,
)
from .node import Node
from .progress import Progressbar
from .root_url import RootHttpUrl
from .type_guards import is_dict, is_list, is_mapping, is_sequence
from .url import HttpUrl
from .utils import SLOTS
from .validation_context import get_validation_context

AbsolutePathT = TypeVar(
    "AbsolutePathT",
    bound=Union[HttpUrl, AbsoluteDirectory, AbsoluteFilePath, ZipPath],
)


class LightHttpFileDescr(Node):
    """http source with sha256 value (minimal validation)"""

    source: pydantic.HttpUrl
    """file source"""

    sha256: Sha256
    """SHA256 checksum of the source file"""

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ) -> BytesReader:
        """open the file source (download if needed)"""
        return get_reader(self.source, sha256=self.sha256, progressbar=progressbar)

    download = get_reader
    """alias for get_reader() method"""


class RelativePathBase(RootModel[PurePath], Generic[AbsolutePathT], frozen=True):
    _absolute: AbsolutePathT = PrivateAttr()

    @property
    def path(self) -> PurePath:
        return self.root

    def absolute(  # method not property analog to `pathlib.Path.absolute()`
        self,
    ) -> AbsolutePathT:
        """get the absolute path/url

        (resolved at time of initialization with the root of the ValidationContext)
        """
        return self._absolute

    def model_post_init(self, __context: Any) -> None:
        """set `_absolute` property with validation context at creation time. @private"""
        if self.root.is_absolute():
            raise ValueError(f"{self.root} is an absolute path.")

        if self.root.parts and self.root.parts[0] in ("http:", "https:"):
            raise ValueError(f"{self.root} looks like an http url.")

        self._absolute = (  # pyright: ignore[reportAttributeAccessIssue]
            self.get_absolute(get_validation_context().root)
        )
        super().model_post_init(__context)

    def __str__(self) -> str:
        return self.root.as_posix()

    def __repr__(self) -> str:
        return f"RelativePath('{self}')"

    @model_serializer()
    def format(self) -> str:
        return str(self)

    @abstractmethod
    def get_absolute(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> AbsolutePathT: ...

    def _get_absolute_impl(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> Union[Path, HttpUrl, ZipPath]:
        if isinstance(root, Path):
            return (root / self.root).absolute()

        rel_path = self.root.as_posix().strip("/")
        if isinstance(root, ZipFile):
            return ZipPath(root, rel_path)

        parsed = urlsplit(str(root))
        path = list(parsed.path.strip("/").split("/"))
        if (
            parsed.netloc == "zenodo.org"
            and parsed.path.startswith("/api/records/")
            and parsed.path.endswith("/content")
        ):
            path.insert(-1, rel_path)
        else:
            path.append(rel_path)

        return HttpUrl(
            urlunsplit(
                (
                    parsed.scheme,
                    parsed.netloc,
                    "/".join(path),
                    parsed.query,
                    parsed.fragment,
                )
            )
        )

    @classmethod
    def _validate(cls, value: Union[PurePath, str]):
        if isinstance(value, str) and (
            value.startswith("https://") or value.startswith("http://")
        ):
            raise ValueError(f"{value} looks like a URL, not a relative path")

        return cls(PurePath(value))


class RelativeFilePath(
    RelativePathBase[Union[AbsoluteFilePath, HttpUrl, ZipPath]], frozen=True
):
    """A path relative to the `rdf.yaml` file (also when the RDF source is a URL)."""

    def model_post_init(self, __context: Any) -> None:
        """add validation @private"""
        if not self.root.parts:  # an empty path can only be a directory
            raise ValueError(f"{self.root} is not a valid file path.")

        super().model_post_init(__context)

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteFilePath | HttpUrl | ZipPath":
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and (context := get_validation_context()).perform_io_checks
            and str(self.root) not in context.known_files
            and not absolute.is_file()
        ):
            raise ValueError(f"{absolute} does not point to an existing file")

        return absolute
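
# Sketch of how a relative path resolves against the validation context root.
# Passing `root=` to the context's `replace` is assumed here;
# `perform_io_checks=False` skips the file existence check:
#
#     with get_validation_context().replace(
#         root=HttpUrl("https://example.com/rdf/"), perform_io_checks=False
#     ):
#         rel = RelativeFilePath(PurePath("weights.pt"))
#         print(rel.absolute())  # -> https://example.com/rdf/weights.pt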


class RelativeDirectory(
    RelativePathBase[Union[AbsoluteDirectory, HttpUrl, ZipPath]], frozen=True
):
    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteDirectory | HttpUrl | ZipPath":
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and get_validation_context().perform_io_checks
            and not absolute.is_dir()
        ):
            raise ValueError(f"{absolute} does not point to an existing directory")

        return absolute


FileSource = Annotated[
    Union[HttpUrl, RelativeFilePath, FilePath],
    Field(union_mode="left_to_right"),
]
PermissiveFileSource = Union[FileSource, str, pydantic.HttpUrl]


class FileDescr(Node):
    """A file description"""

    source: FileSource
    """File source"""

    sha256: Optional[Sha256] = None
    """SHA256 hash value of the **source** file."""

    @model_validator(mode="after")
    def _validate_sha256(self) -> Self:
        if get_validation_context().perform_io_checks:
            self.validate_sha256()

        return self

    def validate_sha256(self, force_recompute: bool = False) -> None:
        """validate the sha256 hash value of the **source** file"""
        context = get_validation_context()
        src_str = str(self.source)
        if not force_recompute and src_str in context.known_files:
            actual_sha = context.known_files[src_str]
        else:
            reader = get_reader(self.source, sha256=self.sha256)
            if force_recompute:
                actual_sha = get_sha256(reader)
            else:
                actual_sha = reader.sha256

            context.known_files[src_str] = actual_sha

        if actual_sha is None:
            return
        elif self.sha256 == actual_sha:
            pass
        elif self.sha256 is None or context.update_hashes:
            self.sha256 = actual_sha
        elif self.sha256 != actual_sha:
            raise ValueError(
                f"Sha256 mismatch for {self.source}. Expected {self.sha256}, got "
                + f"{actual_sha}. Update expected `sha256` or point to the matching "
                + "file."
            )

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ):
        """open the file source (download if needed)"""
        return get_reader(self.source, progressbar=progressbar, sha256=self.sha256)

    def download(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ):
        """alias for `.get_reader`"""
        return get_reader(self.source, progressbar=progressbar, sha256=self.sha256)
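
# Hedged sketch: with io checks enabled (the default), a missing `sha256` is
# computed and stored during model validation; the file path is a placeholder:
#
#     descr = FileDescr.model_validate({"source": "weights.pt"})
#     print(descr.sha256)  # populated from the file contents
#     descr.validate_sha256(force_recompute=True)  # re-hash explicitly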


path_or_url_adapter: "TypeAdapter[Union[FilePath, DirectoryPath, HttpUrl]]" = (
    TypeAdapter(Union[FilePath, DirectoryPath, HttpUrl])
)


@dataclass(frozen=True, **SLOTS)
class WithSuffix:
    suffix: Union[LiteralString, Tuple[LiteralString, ...]]
    case_sensitive: bool

    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ):
        if not self.suffix:
            raise ValueError("suffix may not be empty")

        schema = handler(source)
        return core_schema.no_info_after_validator_function(
            self.validate,
            schema,
        )

    def validate(
        self, value: Union[FileSource, FileDescr]
    ) -> Union[FileSource, FileDescr]:
        return validate_suffix(value, self.suffix, case_sensitive=self.case_sensitive)
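
# WithSuffix is intended as `Annotated` metadata so pydantic enforces a file
# suffix at validation time. A sketch (the model and field name are made up):
#
#     class TorchscriptWeights(Node):
#         source: Annotated[FileSource, WithSuffix(".pt", case_sensitive=True)]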


def wo_special_file_name(src: F) -> F:
    if has_valid_bioimageio_yaml_name(src):
        raise ValueError(
            f"'{src}' not allowed here as its filename is reserved to identify"
            + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
        )

    return src


def has_valid_bioimageio_yaml_name(src: Union[FileSource, FileDescr]) -> bool:
    return is_valid_bioimageio_yaml_name(extract_file_name(src))


def is_valid_bioimageio_yaml_name(file_name: FileName) -> bool:
    for bioimageio_name in ALL_BIOIMAGEIO_YAML_NAMES:
        if file_name == bioimageio_name or file_name.endswith("." + bioimageio_name):
            return True

    return False


def identify_bioimageio_yaml_file_name(file_names: Iterable[FileName]) -> FileName:
    file_names = sorted(file_names)
    for bioimageio_name in ALL_BIOIMAGEIO_YAML_NAMES:
        for file_name in file_names:
            if file_name == bioimageio_name or file_name.endswith(
                "." + bioimageio_name
            ):
                return file_name

    raise ValueError(
        f"No {BIOIMAGEIO_YAML} found in {file_names}. (Looking for '{BIOIMAGEIO_YAML}'"
        + " or any of the alternative file names:"
        + f" {ALTERNATIVE_BIOIMAGEIO_YAML_NAMES}, or any file with an extension of"
        + f" those, e.g. 'anything.{BIOIMAGEIO_YAML}')."
    )
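
# Sketch: picking the RDF file among archive members, assuming
# `BIOIMAGEIO_YAML == "bioimageio.yaml"` as defined in `io_basics`:
#
#     identify_bioimageio_yaml_file_name(["weights.pt", "bioimageio.yaml"])
#     # -> "bioimageio.yaml"
#     identify_bioimageio_yaml_file_name(["model.bioimageio.yaml"])
#     # -> "model.bioimageio.yaml" (matched via the ".<name>" extension rule)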


def find_bioimageio_yaml_file_name(path: Union[Path, ZipFile]) -> FileName:
    if isinstance(path, ZipFile):
        file_names = path.namelist()
    elif path.is_file():
        if not zipfile.is_zipfile(path):
            return path.name

        with ZipFile(path, "r") as f:
            file_names = f.namelist()
    else:
        file_names = [p.name for p in path.glob("*")]

    return identify_bioimageio_yaml_file_name(
        file_names
    )  # TODO: try/except with better error message for dir


def ensure_has_valid_bioimageio_yaml_name(src: FileSource) -> FileSource:
    if not has_valid_bioimageio_yaml_name(src):
        raise ValueError(
            f"'{src}' does not have a valid filename to identify"
            + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
        )

    return src


def ensure_is_valid_bioimageio_yaml_name(file_name: FileName) -> FileName:
    if not is_valid_bioimageio_yaml_name(file_name):
        raise ValueError(
            f"'{file_name}' is not a valid filename to identify"
            + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
        )

    return file_name


# types as loaded from YAML 1.2 (with ruyaml)
YamlLeafValue = Union[
    bool, _date, _datetime, int, float, str, None
]  # note: order relevant for deserializing
YamlKey = Union[  # YAML Arrays are cast to tuples if used as key in mappings
    YamlLeafValue, Tuple[YamlLeafValue, ...]  # (nesting is not allowed though)
]
if TYPE_CHECKING:
    YamlValue = Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]]
    YamlValueView = Union[
        YamlLeafValue, Sequence["YamlValueView"], Mapping[YamlKey, "YamlValueView"]
    ]
else:
    # for pydantic validation we need to use `TypeAliasType`,
    # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
    # however this results in a partially unknown type with the current pyright 1.1.388
    YamlValue = _TypeAliasType(
        "YamlValue",
        Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]],
    )
    YamlValueView = _TypeAliasType(
        "YamlValueView",
        Union[
            YamlLeafValue,
            Sequence["YamlValueView"],
            Mapping[YamlKey, "YamlValueView"],
        ],
    )

BioimageioYamlContent = Dict[str, YamlValue]
BioimageioYamlContentView = Mapping[str, YamlValueView]
BioimageioYamlSource = Union[
    PermissiveFileSource, ZipFile, BioimageioYamlContent, BioimageioYamlContentView
]


@overload
def deepcopy_yaml_value(value: BioimageioYamlContentView) -> BioimageioYamlContent: ...


@overload
def deepcopy_yaml_value(value: YamlValueView) -> YamlValue: ...


def deepcopy_yaml_value(
    value: Union[BioimageioYamlContentView, YamlValueView],
) -> Union[BioimageioYamlContent, YamlValue]:
    if isinstance(value, str):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {key: deepcopy_yaml_value(val) for key, val in value.items()}
    elif isinstance(value, collections.abc.Sequence):
        return [deepcopy_yaml_value(val) for val in value]
    else:
        return value
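
# Sketch: deep-copying a read-only YAML view into mutable content. The `str`
# check comes first on purpose, because `str` is itself a `Sequence`:
#
#     view: BioimageioYamlContentView = {"axes": ("x", "y"), "name": "t"}
#     deepcopy_yaml_value(view)
#     # -> {"axes": ["x", "y"], "name": "t"}  (sequences become lists)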


def is_yaml_leaf_value(value: Any) -> TypeGuard[YamlLeafValue]:
    return isinstance(value, (bool, _date, _datetime, int, float, str, type(None)))


def is_yaml_list(value: Any) -> TypeGuard[List[YamlValue]]:
    return is_list(value) and all(is_yaml_value(item) for item in value)


def is_yaml_sequence(value: Any) -> TypeGuard[List[YamlValueView]]:
    return is_sequence(value) and all(is_yaml_value(item) for item in value)


def is_yaml_dict(value: Any) -> TypeGuard[BioimageioYamlContent]:
    return is_dict(value) and all(
        isinstance(key, str) and is_yaml_value(val) for key, val in value.items()
    )


def is_yaml_mapping(value: Any) -> TypeGuard[BioimageioYamlContentView]:
    return is_mapping(value) and all(
        isinstance(key, str) and is_yaml_value_read_only(val)
        for key, val in value.items()
    )


def is_yaml_value(value: Any) -> TypeGuard[YamlValue]:
    return is_yaml_leaf_value(value) or is_yaml_list(value) or is_yaml_dict(value)


def is_yaml_value_read_only(value: Any) -> TypeGuard[YamlValueView]:
    return (
        is_yaml_leaf_value(value) or is_yaml_sequence(value) or is_yaml_mapping(value)
    )


@dataclass(frozen=True, **SLOTS)
class OpenedBioimageioYaml:
    content: BioimageioYamlContent = field(repr=False)
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]
    original_source_name: Optional[str]
    original_file_name: FileName
    unparsed_content: str = field(repr=False)


@dataclass(frozen=True, **SLOTS)
class LocalFile:
    path: FilePath
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]
    original_file_name: FileName


@dataclass(frozen=True, **SLOTS)
class FileInZip:
    path: ZipPath
    original_root: Union[RootHttpUrl, ZipFile]
    original_file_name: FileName


class HashKwargs(TypedDict):
    sha256: NotRequired[Optional[Sha256]]


_file_source_adapter: TypeAdapter[Union[HttpUrl, RelativeFilePath, FilePath]] = (
    TypeAdapter(FileSource)
)


def interprete_file_source(file_source: PermissiveFileSource) -> FileSource:
    if isinstance(file_source, Path):
        if file_source.is_dir():
            raise FileNotFoundError(
                f"{file_source} is a directory, but expected a file."
            )
        return file_source

    if isinstance(file_source, HttpUrl):
        return file_source

    if isinstance(file_source, pydantic.AnyUrl):
        file_source = str(file_source)

    with get_validation_context().replace(perform_io_checks=False):
        strict = _file_source_adapter.validate_python(file_source)
        if isinstance(strict, Path) and strict.is_dir():
            raise FileNotFoundError(f"{strict} is a directory, but expected a file.")

    return strict
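
# Sketch of the permissive-to-strict coercion (all inputs are placeholders):
#
#     interprete_file_source("https://example.com/rdf.yaml")  # -> HttpUrl
#     interprete_file_source("weights.pt")             # -> RelativeFilePath
#     interprete_file_source(Path("/tmp/weights.pt"))  # -> Path, returned as-is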


def extract(
    source: Union[FilePath, ZipFile, ZipPath],
    folder: Optional[DirectoryPath] = None,
    overwrite: bool = False,
) -> DirectoryPath:
    extract_member = None
    if isinstance(source, ZipPath):
        extract_member = source.at
        source = source.root

    if isinstance(source, ZipFile):
        zip_context = nullcontext(source)
        if folder is None:
            if source.filename is None:
                folder = Path(mkdtemp())
            else:
                zip_path = Path(source.filename)
                folder = zip_path.with_suffix(zip_path.suffix + ".unzip")
    else:
        zip_context = ZipFile(source, "r")
        if folder is None:
            folder = source.with_suffix(source.suffix + ".unzip")

    if overwrite and folder.exists():
        warnings.warn(f"Overwriting existing unzipped archive at {folder}")

    with zip_context as f:
        if extract_member is not None:
            extracted_file_path = folder / extract_member
            if extracted_file_path.exists() and not overwrite:
                warnings.warn(f"Found unzipped {extracted_file_path}.")
            else:
                _ = f.extract(extract_member, folder)

            return folder

        elif overwrite or not folder.exists():
            f.extractall(folder)
            return folder

        found_content = {p.relative_to(folder).as_posix() for p in folder.glob("*")}
        expected_content = {info.filename for info in f.filelist}
        if expected_missing := expected_content - found_content:
            parts = folder.name.split("_")
            nr, *suffixes = parts[-1].split(".")
            if nr.isdecimal():
                nr = str(int(nr) + 1)
            else:
                nr = f"1.{nr}"

            parts[-1] = ".".join([nr, *suffixes])
            out_path_new = folder.with_name("_".join(parts))
            warnings.warn(
                f"Unzipped archive at {folder} is missing expected files"
                + f" {expected_missing}."
                + f" Unzipping to {out_path_new} instead to avoid overwriting."
            )
            return extract(f, out_path_new, overwrite=overwrite)
        else:
            warnings.warn(
                f"Found unzipped archive with all expected files at {folder}."
            )
            return folder
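
# Sketch: by default an archive "model.zip" (hypothetical) is extracted to a
# sibling folder "model.zip.unzip"; a repeated call detects the complete
# extraction and returns the existing folder instead of re-extracting:
#
#     folder = extract(Path("model.zip"))
#     folder = extract(Path("model.zip"))  # warns and reuses model.zip.unzip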


def get_reader(
    source: Union[PermissiveFileSource, FileDescr, ZipPath],
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Open a file `source` (download if needed)"""
    if isinstance(source, FileDescr):
        if "sha256" not in kwargs:
            kwargs["sha256"] = source.sha256

        source = source.source
    elif isinstance(source, str):
        source = interprete_file_source(source)

    if isinstance(source, RelativeFilePath):
        source = source.absolute()
    elif isinstance(source, pydantic.AnyUrl):
        with get_validation_context().replace(perform_io_checks=False):
            source = HttpUrl(source)

    if isinstance(source, HttpUrl):
        return _open_url(source, progressbar=progressbar, **kwargs)

    if isinstance(source, ZipPath):
        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open(mode="rb")
        assert not isinstance(f, TextIOWrapper)
        root = source.root
    elif isinstance(source, Path):
        if source.is_dir():
            raise FileNotFoundError(f"{source} is a directory, not a file")

        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open("rb")
        root = source.parent
    else:
        assert_never(source)

    expected_sha = kwargs.get("sha256")
    if expected_sha is None:
        sha = None
    else:
        sha = get_sha256(f)
        _ = f.seek(0)
        if sha != expected_sha:
            raise ValueError(
                f"SHA256 mismatch for {source}. Expected {expected_sha}, got {sha}."
            )

    return BytesReader(
        f,
        sha256=sha,
        suffix=source.suffix,
        original_file_name=source.name,
        original_root=root,
        is_zipfile=None,
    )
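
# Sketch: one entry point for URLs, paths, strings, FileDescr and ZipPath
# (the URL is a placeholder; remote files go through the disk cache unless the
# validation context disables it):
#
#     reader = get_reader("https://example.com/weights.pt", progressbar=False)
#     blob = reader.read()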


download = get_reader


def _open_url(
    source: HttpUrl,
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None],
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    cache = (
        NoopCache[RootHttpUrl](url_hasher=UrlDigest.from_str)
        if get_validation_context().disable_cache
        else settings.disk_cache
    )
    sha = kwargs.get("sha256")
    digest = False if sha is None else ContentDigest.parse(hexdigest=sha)
    source_path = PurePosixPath(
        source.path
        or sha
        or hashlib.sha256(str(source).encode(encoding="utf-8")).hexdigest()
    )

    reader = cache.fetch(
        source,
        fetcher=partial(_fetch_url, progressbar=progressbar),
        force_refetch=digest,
    )
    return BytesReader(
        reader,
        suffix=source_path.suffix,
        sha256=sha,
        original_file_name=source_path.name,
        original_root=source.parent,
        is_zipfile=None,
    )


def _fetch_url(
    source: RootHttpUrl,
    *,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None],
):
    if source.scheme not in ("http", "https"):
        raise NotImplementedError(source.scheme)

    if progressbar is None:
        # choose progressbar option from validation context
        progressbar = get_validation_context().progressbar

    if progressbar is None:
        # default to no progressbar in CI environments
        progressbar = not settings.CI

    if callable(progressbar):
        progressbar = progressbar()

    if isinstance(progressbar, bool) and progressbar:
        progressbar = tqdm(
            ncols=79,
            ascii=bool(sys.platform == "win32"),
            unit="B",
            unit_scale=True,
            leave=True,
        )

    if progressbar is not False:
        progressbar.set_description(f"Downloading {extract_file_name(source)}")

    headers: Dict[str, str] = {}
    if settings.user_agent is not None:
        headers["User-Agent"] = settings.user_agent
    elif settings.CI:
        headers["User-Agent"] = "ci"

    r = httpx.get(
        str(source),
        follow_redirects=True,
        headers=headers,
        timeout=settings.http_timeout,
    )
    _ = r.raise_for_status()

    # set progressbar.total
    total = r.headers.get("content-length")
    if total is not None and not isinstance(total, int):
        try:
            total = int(total)
        except Exception:
            total = None

    if progressbar is not False:
        if total is None:
            progressbar.total = 0
        else:
            progressbar.total = total

    def iter_content():
        for chunk in r.iter_bytes(chunk_size=4096):
            yield chunk
            if progressbar is not False:
                _ = progressbar.update(len(chunk))

        # Make sure the progress bar gets filled even if the actual number
        # of chunks is smaller than expected. This happens when streaming
        # text files that are compressed by the server when sending (gzip).
        # Binary files don't experience this.
        # (adapted from pooch.HttpDownloader)
        if progressbar is not False:
            progressbar.reset()
            if total is not None:
                _ = progressbar.update(total)

            progressbar.close()

    return iter_content()


def extract_file_name(
    src: Union[
        pydantic.HttpUrl, RootHttpUrl, PurePath, RelativeFilePath, ZipPath, FileDescr
    ],
) -> FileName:
    if isinstance(src, FileDescr):
        src = src.source

    if isinstance(src, ZipPath):
        return src.name or src.root.filename or "bioimageio.zip"
    elif isinstance(src, RelativeFilePath):
        return src.path.name
    elif isinstance(src, PurePath):
        return src.name
    else:
        url = urlparse(str(src))
        if (
            url.scheme == "https"
            and url.hostname == "zenodo.org"
            and url.path.startswith("/api/records/")
            and url.path.endswith("/content")
        ):
            return url.path.split("/")[-2]
        else:
            return url.path.split("/")[-1]
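
# Sketch: the Zenodo API special case keeps the real file name instead of the
# trailing "content" path segment (URLs are placeholders):
#
#     extract_file_name(
#         HttpUrl("https://zenodo.org/api/records/123/files/rdf.yaml/content")
#     )  # -> "rdf.yaml"
#     extract_file_name(PurePath("weights/model.pt"))  # -> "model.pt"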


def extract_file_descrs(data: YamlValueView):
    collected: List[FileDescr] = []
    with get_validation_context().replace(perform_io_checks=False, log_warnings=False):
        _extract_file_descrs_impl(data, collected)

    return collected


def _extract_file_descrs_impl(data: YamlValueView, collected: List[FileDescr]):
    if isinstance(data, collections.abc.Mapping):
        if "source" in data and "sha256" in data:
            try:
                fd = FileDescr.model_validate(
                    dict(source=data["source"], sha256=data["sha256"])
                )
            except Exception:
                pass
            else:
                collected.append(fd)

        for v in data.values():
            _extract_file_descrs_impl(v, collected)
    elif not isinstance(data, str) and isinstance(data, collections.abc.Sequence):
        for v in data:
            _extract_file_descrs_impl(v, collected)


F = TypeVar("F", bound=Union[FileSource, FileDescr])


def validate_suffix(
    value: F, suffix: Union[str, Sequence[str]], case_sensitive: bool
) -> F:
    """check final suffix"""
    if isinstance(suffix, str):
        suffixes = [suffix]
    else:
        suffixes = suffix

    assert len(suffixes) > 0, "no suffix given"
    assert all(suff.startswith(".") for suff in suffixes), (
        "expected suffixes to start with '.'"
    )
    o_value = value
    if isinstance(value, FileDescr):
        strict = value.source
    else:
        strict = interprete_file_source(value)

    if isinstance(strict, (HttpUrl, AnyUrl)):
        if strict.path is None or "." not in (path := strict.path):
            actual_suffixes = []
        else:
            if (
                strict.host == "zenodo.org"
                and path.startswith("/api/records/")
                and path.endswith("/content")
            ):
                # Zenodo API URLs have a "/content" suffix that should be ignored
                path = path[: -len("/content")]

            actual_suffixes = [f".{path.split('.')[-1]}"]
    elif isinstance(strict, PurePath):
        actual_suffixes = strict.suffixes
    elif isinstance(strict, RelativeFilePath):
        actual_suffixes = strict.path.suffixes
    else:
        assert_never(strict)

    if actual_suffixes:
        actual_suffix = actual_suffixes[-1]
    else:
        actual_suffix = "no suffix"

    if (
        case_sensitive
        and actual_suffix not in suffixes
        or not case_sensitive
        and actual_suffix.lower() not in [s.lower() for s in suffixes]
    ):
        if len(suffixes) == 1:
            raise ValueError(f"Expected suffix {suffixes[0]}, but got {actual_suffix}")
        else:
            raise ValueError(
                f"Expected a suffix from {suffixes}, but got {actual_suffix}"
            )

    return o_value
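
# Sketch of suffix enforcement (file names are placeholders):
#
#     validate_suffix(Path("weights.pt"), ".pt", case_sensitive=True)  # ok
#     validate_suffix("weights.PT", ".pt", case_sensitive=False)       # ok
#     validate_suffix("weights.onnx", ".pt", case_sensitive=True)      # ValueError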


def populate_cache(sources: Sequence[Union[FileDescr, LightHttpFileDescr]]):
    unique: Set[str] = set()
    for src in sources:
        if src.sha256 is None:
            continue  # not caching without known SHA

        if isinstance(src.source, (HttpUrl, pydantic.AnyUrl)):
            url = str(src.source)
        elif isinstance(src.source, RelativeFilePath):
            if isinstance(absolute := src.source.absolute(), HttpUrl):
                url = str(absolute)
            else:
                continue  # not caching local paths
        elif isinstance(src.source, Path):
            continue  # not caching local paths
        else:
            assert_never(src.source)

        if url in unique:
            continue  # skip duplicate URLs

        unique.add(url)
        _ = src.download()
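
# Sketch: pre-fetching every remote file referenced by parsed RDF content
# (`content` is assumed to be an already-loaded BioimageioYamlContent mapping):
#
#     descrs = extract_file_descrs(content)
#     populate_cache(descrs)  # downloads each unique URL with a known sha256 once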