Coverage for src/bioimageio/spec/_internal/io.py: 79%
447 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-07 08:37 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-07 08:37 +0000
1from __future__ import annotations
3import collections.abc
4import hashlib
5import sys
6import warnings
7import zipfile
8from abc import abstractmethod
9from contextlib import nullcontext
10from dataclasses import dataclass, field
11from datetime import date as _date
12from datetime import datetime as _datetime
13from functools import partial
14from io import TextIOWrapper
15from pathlib import Path, PurePath, PurePosixPath
16from tempfile import mkdtemp
17from typing import (
18 TYPE_CHECKING,
19 Any,
20 Callable,
21 Dict,
22 Generic,
23 Iterable,
24 List,
25 Mapping,
26 Optional,
27 Sequence,
28 Set,
29 Tuple,
30 Type,
31 TypedDict,
32 TypeVar,
33 Union,
34 overload,
35)
36from urllib.parse import urlparse, urlsplit, urlunsplit
37from zipfile import ZipFile
39import httpx
40import pydantic
41from genericache import NoopCache
42from genericache.digest import ContentDigest, UrlDigest
43from pydantic import (
44 AnyUrl,
45 DirectoryPath,
46 Field,
47 GetCoreSchemaHandler,
48 PrivateAttr,
49 RootModel,
50 TypeAdapter,
51 model_serializer,
52 model_validator,
53)
54from pydantic_core import core_schema
55from tqdm import tqdm
56from typing_extensions import (
57 Annotated,
58 LiteralString,
59 NotRequired,
60 Self,
61 TypeGuard,
62 Unpack,
63 assert_never,
64)
65from typing_extensions import TypeAliasType as _TypeAliasType
67from ._settings import settings
68from .io_basics import (
69 ALL_BIOIMAGEIO_YAML_NAMES,
70 ALTERNATIVE_BIOIMAGEIO_YAML_NAMES,
71 BIOIMAGEIO_YAML,
72 AbsoluteDirectory,
73 AbsoluteFilePath,
74 BytesReader,
75 FileName,
76 FilePath,
77 Sha256,
78 ZipPath,
79 get_sha256,
80)
81from .node import Node
82from .progress import Progressbar
83from .root_url import RootHttpUrl
84from .type_guards import is_dict, is_list, is_mapping, is_sequence
85from .url import HttpUrl
86from .utils import SLOTS
87from .validation_context import get_validation_context
# Type variable for the resolved ("absolute") counterpart of a relative path:
# a URL, an absolute local directory/file path, or a path inside a zip archive.
AbsolutePathT = TypeVar(
    "AbsolutePathT",
    bound=Union[HttpUrl, AbsoluteDirectory, AbsoluteFilePath, ZipPath],
)
class LightHttpFileDescr(Node):
    """http source with sha256 value (minimal validation)

    Uses plain `pydantic.HttpUrl` for its source (in contrast to `FileDescr`,
    which accepts any `FileSource` and performs io checks on validation).
    """

    source: pydantic.HttpUrl
    """file source"""

    sha256: Sha256
    """SHA256 checksum of the source file"""

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ) -> BytesReader:
        """open the file source (download if needed)

        Args:
            progressbar: a progressbar instance, a factory for one, a bool to
                enable/disable one, or None to defer to the validation context.
        """
        return get_reader(self.source, sha256=self.sha256, progressbar=progressbar)

    download = get_reader
    """alias for get_reader() method"""
class RelativePathBase(RootModel[PurePath], Generic[AbsolutePathT], frozen=True):
    """Base class for relative paths that are resolved against the
    `ValidationContext.root` (a directory, URL, or zip file) at creation time."""

    # resolved counterpart of `root`, set once in `model_post_init`
    _absolute: AbsolutePathT = PrivateAttr()

    @property
    def path(self) -> PurePath:
        """the (relative) path this model wraps"""
        return self.root

    def absolute(  # method not property analog to `pathlib.Path.absolute()`
        self,
    ) -> AbsolutePathT:
        """get the absolute path/url

        (resolved at time of initialization with the root of the ValidationContext)
        """
        return self._absolute

    def model_post_init(self, __context: Any) -> None:
        """set `_absolute` property with validation context at creation time. @private"""
        if self.root.is_absolute():
            raise ValueError(f"{self.root} is an absolute path.")

        # guard against URLs that were parsed as paths ("http:/..." etc.)
        if self.root.parts and self.root.parts[0] in ("http:", "https:"):
            raise ValueError(f"{self.root} looks like an http url.")

        self._absolute = (  # pyright: ignore[reportAttributeAccessIssue]
            self.get_absolute(get_validation_context().root)
        )
        super().model_post_init(__context)

    def __str__(self) -> str:
        return self.root.as_posix()

    def __repr__(self) -> str:
        # NOTE(review): always says 'RelativePath' regardless of subclass
        return f"RelativePath('{self}')"

    @model_serializer()
    def format(self) -> str:
        # serialize as the POSIX string form of the relative path
        return str(self)

    @abstractmethod
    def get_absolute(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> AbsolutePathT: ...

    def _get_absolute_impl(
        self, root: Union[RootHttpUrl, AbsoluteDirectory, pydantic.AnyUrl, ZipFile]
    ) -> Union[Path, HttpUrl, ZipPath]:
        """resolve `self.root` against a directory, zip file, or URL root"""
        if isinstance(root, Path):
            return (root / self.root).absolute()

        rel_path = self.root.as_posix().strip("/")
        if isinstance(root, ZipFile):
            return ZipPath(root, rel_path)

        parsed = urlsplit(str(root))
        path = list(parsed.path.strip("/").split("/"))
        if (
            parsed.netloc == "zenodo.org"
            and parsed.path.startswith("/api/records/")
            and parsed.path.endswith("/content")
        ):
            # zenodo API record URLs end in "/content"; insert the relative
            # path before that trailing segment
            path.insert(-1, rel_path)
        else:
            path.append(rel_path)

        return HttpUrl(
            urlunsplit(
                (
                    parsed.scheme,
                    parsed.netloc,
                    "/".join(path),
                    parsed.query,
                    parsed.fragment,
                )
            )
        )

    @classmethod
    def _validate(cls, value: Union[PurePath, str]):
        """reject URL strings early; otherwise wrap *value* as a PurePath"""
        if isinstance(value, str) and (
            value.startswith("https://") or value.startswith("http://")
        ):
            raise ValueError(f"{value} looks like a URL, not a relative path")

        return cls(PurePath(value))
class RelativeFilePath(
    RelativePathBase[Union[AbsoluteFilePath, HttpUrl, ZipPath]], frozen=True
):
    """A path relative to the `rdf.yaml` file (also if the RDF source is a URL)."""

    def model_post_init(self, __context: Any) -> None:
        """add validation @private"""
        if not self.root.parts:  # an empty path can only be a directory
            raise ValueError(f"{self.root} is not a valid file path.")

        super().model_post_init(__context)

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteFilePath | HttpUrl | ZipPath":
        """Resolve against *root*.

        Raises:
            ValueError: if io checks are enabled, the resolved local path does
                not exist, and the path is not already in `context.known_files`.
        """
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and (context := get_validation_context()).perform_io_checks
            and str(self.root) not in context.known_files
            and not absolute.is_file()
        ):
            raise ValueError(f"{absolute} does not point to an existing file")

        return absolute
class RelativeDirectory(
    RelativePathBase[Union[AbsoluteDirectory, HttpUrl, ZipPath]], frozen=True
):
    """A directory path relative to the validation context root."""

    def get_absolute(
        self, root: "RootHttpUrl | Path | AnyUrl | ZipFile"
    ) -> "AbsoluteDirectory | HttpUrl | ZipPath":
        """Resolve against *root*.

        Raises:
            ValueError: if io checks are enabled and the resolved local path
                is not an existing directory.
        """
        absolute = self._get_absolute_impl(root)
        if (
            isinstance(absolute, Path)
            and get_validation_context().perform_io_checks
            and not absolute.is_dir()
        ):
            raise ValueError(f"{absolute} does not point to an existing directory")

        return absolute
# strict file source: URL, path relative to the RDF, or absolute file path;
# union members are tried left to right
FileSource = Annotated[
    Union[HttpUrl, RelativeFilePath, FilePath],
    Field(union_mode="left_to_right"),
]
# permissive variant that additionally accepts plain strings and pydantic URLs
PermissiveFileSource = Union[FileSource, str, pydantic.HttpUrl]
class FileDescr(Node):
    """A file description"""

    source: FileSource
    """File source"""

    sha256: Optional[Sha256] = None
    """SHA256 hash value of the **source** file."""

    @model_validator(mode="after")
    def _validate_sha256(self) -> Self:
        # only touch the file system / network when io checks are enabled
        if get_validation_context().perform_io_checks:
            self.validate_sha256()

        return self

    def validate_sha256(self, force_recompute: bool = False) -> None:
        """validate the sha256 hash value of the **source** file

        Args:
            force_recompute: recompute the hash even if the source is already
                listed in `context.known_files`.

        Raises:
            ValueError: if a hash was given, `context.update_hashes` is off,
                and the actual hash differs.
        """
        context = get_validation_context()
        src_str = str(self.source)
        if not force_recompute and src_str in context.known_files:
            # reuse a hash determined earlier within this validation context
            actual_sha = context.known_files[src_str]
        else:
            reader = get_reader(self.source, sha256=self.sha256)
            if force_recompute:
                actual_sha = get_sha256(reader)
            else:
                actual_sha = reader.sha256

            context.known_files[src_str] = actual_sha

        if actual_sha is None:
            # hash could not be determined; nothing to compare against
            return
        elif self.sha256 == actual_sha:
            pass
        elif self.sha256 is None or context.update_hashes:
            self.sha256 = actual_sha
        elif self.sha256 != actual_sha:
            raise ValueError(
                f"Sha256 mismatch for {self.source}. Expected {self.sha256}, got "
                + f"{actual_sha}. Update expected `sha256` or point to the matching "
                + "file."
            )

    def get_reader(
        self,
        *,
        progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    ):
        """open the file source (download if needed)"""
        return get_reader(self.source, progressbar=progressbar, sha256=self.sha256)

    download = get_reader
    """alias for get_reader() method"""
# adapter to coerce input into a file path, directory path, or URL
path_or_url_adapter: "TypeAdapter[Union[FilePath, DirectoryPath, HttpUrl]]" = (
    TypeAdapter(Union[FilePath, DirectoryPath, HttpUrl])
)
@dataclass(frozen=True, **SLOTS)
class WithSuffix:
    """pydantic annotation metadata restricting a file source to given suffix(es)"""

    suffix: Union[LiteralString, Tuple[LiteralString, ...]]  # one or more accepted suffixes
    case_sensitive: bool  # whether suffix comparison is case sensitive

    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ):
        """append `self.validate` as an after-validator to the annotated type"""
        if not self.suffix:
            raise ValueError("suffix may not be empty")

        schema = handler(source)
        return core_schema.no_info_after_validator_function(
            self.validate,
            schema,
        )

    def validate(
        self, value: Union[FileSource, FileDescr]
    ) -> Union[FileSource, FileDescr]:
        """validate that *value* carries one of the accepted suffixes"""
        return validate_suffix(value, self.suffix, case_sensitive=self.case_sensitive)
def wo_special_file_name(src: F) -> F:
    """Pass *src* through unless its filename is a reserved bioimageio.yaml name."""
    if not has_valid_bioimageio_yaml_name(src):
        return src

    raise ValueError(
        f"'{src}' not allowed here as its filename is reserved to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )
def has_valid_bioimageio_yaml_name(src: Union[FileSource, FileDescr]) -> bool:
    """Check whether the filename of *src* qualifies as a bioimageio.yaml name."""
    file_name = extract_file_name(src)
    return is_valid_bioimageio_yaml_name(file_name)
def is_valid_bioimageio_yaml_name(file_name: FileName) -> bool:
    """True if *file_name* equals, or ends with, any recognized bioimageio.yaml name."""
    return any(
        file_name == name or file_name.endswith("." + name)
        for name in ALL_BIOIMAGEIO_YAML_NAMES
    )
def identify_bioimageio_yaml_file_name(file_names: Iterable[FileName]) -> FileName:
    """Return the first file name (in sorted order) that matches a bioimageio.yaml name.

    Candidate names are checked against `ALL_BIOIMAGEIO_YAML_NAMES` in order,
    accepting exact matches and names ending in ``.<bioimageio_name>``.

    Raises:
        ValueError: if no file name matches.
    """
    file_names = sorted(file_names)
    for bioimageio_name in ALL_BIOIMAGEIO_YAML_NAMES:
        for file_name in file_names:
            if file_name == bioimageio_name or file_name.endswith(
                "." + bioimageio_name
            ):
                return file_name

    # fixed error message: removed duplicated "or" and typo "alterntive"
    raise ValueError(
        f"No {BIOIMAGEIO_YAML} found in {file_names}. (Looking for '{BIOIMAGEIO_YAML}'"
        + " or any of the alternative file names:"
        + f" {ALTERNATIVE_BIOIMAGEIO_YAML_NAMES}, or any file with an extension of"
        + f" those, e.g. 'anything.{BIOIMAGEIO_YAML}')."
    )
def find_bioimageio_yaml_file_name(path: Union[Path, ZipFile]) -> FileName:
    """Find the bioimageio.yaml file name within a directory, zip archive, or file."""
    if isinstance(path, ZipFile):
        candidates = path.namelist()
    elif path.is_file():
        if not zipfile.is_zipfile(path):
            # a plain file: its own name is the only candidate
            return path.name

        with ZipFile(path, "r") as zf:
            candidates = zf.namelist()
    else:
        candidates = [p.name for p in path.glob("*")]

    # TODO: try/except with better error message for dir
    return identify_bioimageio_yaml_file_name(candidates)
def ensure_has_valid_bioimageio_yaml_name(src: FileSource) -> FileSource:
    """Return *src* unchanged if its filename is a valid bioimageio.yaml name."""
    if has_valid_bioimageio_yaml_name(src):
        return src

    raise ValueError(
        f"'{src}' does not have a valid filename to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )
def ensure_is_valid_bioimageio_yaml_name(file_name: FileName) -> FileName:
    """Return *file_name* unchanged if it is a valid bioimageio.yaml name."""
    if is_valid_bioimageio_yaml_name(file_name):
        return file_name

    raise ValueError(
        f"'{file_name}' is not a valid filename to identify"
        + f" '{BIOIMAGEIO_YAML}' (or equivalent) files."
    )
# types as loaded from YAML 1.2 (with ruyaml)
YamlLeafValue = Union[
    bool, _date, _datetime, int, float, str, None
]  # note: order relevant for deserializing
YamlKey = Union[  # YAML Arrays are cast to tuples if used as key in mappings
    YamlLeafValue, Tuple[YamlLeafValue, ...]  # (nesting is not allowed though)
]
if TYPE_CHECKING:
    # plain recursive unions suffice for static type checking
    YamlValue = Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]]
    YamlValueView = Union[
        YamlLeafValue, Sequence["YamlValueView"], Mapping[YamlKey, "YamlValueView"]
    ]
else:
    # for pydantic validation we need to use `TypeAliasType`,
    # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
    # however this results in a partially unknown type with the current pyright 1.1.388
    YamlValue = _TypeAliasType(
        "YamlValue",
        Union[YamlLeafValue, List["YamlValue"], Dict[YamlKey, "YamlValue"]],
    )
    YamlValueView = _TypeAliasType(
        "YamlValueView",
        Union[
            YamlLeafValue,
            Sequence["YamlValueView"],
            Mapping[YamlKey, "YamlValueView"],
        ],
    )

# a (mutable) deserialized bioimageio.yaml document
BioimageioYamlContent = Dict[str, YamlValue]
# a read-only view of a bioimageio.yaml document
BioimageioYamlContentView = Mapping[str, YamlValueView]
BioimageioYamlSource = Union[
    PermissiveFileSource, ZipFile, BioimageioYamlContent, BioimageioYamlContentView
]
@overload
def deepcopy_yaml_value(value: BioimageioYamlContentView) -> BioimageioYamlContent: ...


@overload
def deepcopy_yaml_value(value: YamlValueView) -> YamlValue: ...


def deepcopy_yaml_value(
    value: Union[BioimageioYamlContentView, YamlValueView],
) -> Union[BioimageioYamlContent, YamlValue]:
    """Recursively copy a YAML value.

    Mappings become dicts and (non-str) sequences become lists;
    strings and other leaf values are returned as-is.
    """
    if isinstance(value, str):
        # strings are sequences, but YAML leaves; do not iterate them
        return value

    if isinstance(value, collections.abc.Mapping):
        return {k: deepcopy_yaml_value(v) for k, v in value.items()}

    if isinstance(value, collections.abc.Sequence):
        return [deepcopy_yaml_value(v) for v in value]

    return value
def is_yaml_leaf_value(value: Any) -> TypeGuard[YamlLeafValue]:
    """True if *value* is a YAML scalar (bool, date, datetime, int, float, str, or None)."""
    leaf_types = (bool, _date, _datetime, int, float, str, type(None))
    return isinstance(value, leaf_types)
def is_yaml_list(value: Any) -> TypeGuard[List[YamlValue]]:
    """True if *value* is a list whose items are all YAML values."""
    if not is_list(value):
        return False

    return all(is_yaml_value(item) for item in value)
def is_yaml_sequence(value: Any) -> TypeGuard[List[YamlValueView]]:
    """True if *value* is a sequence whose items are all read-only YAML values.

    Items are checked with `is_yaml_value_read_only` (consistent with
    `is_yaml_mapping`) so that read-only mappings/sequences nest correctly;
    checking with the mutable `is_yaml_value` would wrongly reject e.g. a
    sequence containing a non-dict Mapping.
    """
    return is_sequence(value) and all(
        is_yaml_value_read_only(item) for item in value
    )
def is_yaml_dict(value: Any) -> TypeGuard[BioimageioYamlContent]:
    """True if *value* is a dict with str keys and YAML values."""
    if not is_dict(value):
        return False

    return all(isinstance(k, str) and is_yaml_value(v) for k, v in value.items())
def is_yaml_mapping(value: Any) -> TypeGuard[BioimageioYamlContentView]:
    """True if *value* is a mapping with str keys and read-only YAML values."""
    if not is_mapping(value):
        return False

    return all(
        isinstance(k, str) and is_yaml_value_read_only(v) for k, v in value.items()
    )
def is_yaml_value(value: Any) -> TypeGuard[YamlValue]:
    """True if *value* is a (mutable) YAML value: a leaf, a list, or a dict."""
    checks = (is_yaml_leaf_value, is_yaml_list, is_yaml_dict)
    return any(check(value) for check in checks)
def is_yaml_value_read_only(value: Any) -> TypeGuard[YamlValueView]:
    """True if *value* is a read-only YAML value: a leaf, a sequence, or a mapping."""
    checks = (is_yaml_leaf_value, is_yaml_sequence, is_yaml_mapping)
    return any(check(value) for check in checks)
@dataclass(frozen=True, **SLOTS)
class OpenedBioimageioYaml:
    """Parsed bioimageio.yaml content together with its provenance."""

    content: BioimageioYamlContent = field(repr=False)  # parsed YAML mapping
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]  # root to resolve relative paths against
    original_source_name: Optional[str]  # presumably the user-given source string -- TODO confirm
    original_file_name: FileName  # name of the YAML file itself
    unparsed_content: str = field(repr=False)  # raw YAML text
@dataclass(frozen=True, **SLOTS)
class LocalFile:
    """A local file together with its provenance."""

    path: FilePath  # local file path
    original_root: Union[AbsoluteDirectory, RootHttpUrl, ZipFile]  # where the file came from
    original_file_name: FileName  # file name at the original location
@dataclass(frozen=True, **SLOTS)
class FileInZip:
    """A file inside a zip archive together with its provenance."""

    path: ZipPath  # path within the zip archive
    original_root: Union[RootHttpUrl, ZipFile]  # where the archive came from
    original_file_name: FileName  # file name at the original location
class HashKwargs(TypedDict):
    """Optional hash keyword arguments accepted by `get_reader`/`download`."""

    sha256: NotRequired[Optional[Sha256]]  # expected SHA256 of the file content
# adapter used by `interprete_file_source` to parse permissive inputs strictly
_file_source_adapter: TypeAdapter[Union[HttpUrl, RelativeFilePath, FilePath]] = (
    TypeAdapter(FileSource)
)
def interprete_file_source(file_source: PermissiveFileSource) -> FileSource:
    """Interpret a permissive file source as a strict `FileSource`.

    Raises:
        FileNotFoundError: if *file_source* resolves to an existing directory.
    """
    if isinstance(file_source, Path):
        if file_source.is_dir():
            raise FileNotFoundError(
                f"{file_source} is a directory, but expected a file."
            )
        return file_source

    if isinstance(file_source, HttpUrl):
        return file_source

    if isinstance(file_source, pydantic.AnyUrl):
        # normalize plain pydantic URLs to str for re-validation below
        file_source = str(file_source)

    # validate without io checks; the directory case is checked explicitly below
    with get_validation_context().replace(perform_io_checks=False):
        strict = _file_source_adapter.validate_python(file_source)
        if isinstance(strict, Path) and strict.is_dir():
            raise FileNotFoundError(f"{strict} is a directory, but expected a file.")

    return strict
def extract(
    source: Union[FilePath, ZipFile, ZipPath],
    folder: Optional[DirectoryPath] = None,
    overwrite: bool = False,
) -> DirectoryPath:
    """Extract a zip archive (or a single member of one) and return the target folder.

    If *folder* is None it is derived from the archive location
    (``<archive>.<suffix>.unzip``) or, if the archive has no file name,
    a fresh temporary directory is used.
    """
    extract_member = None
    if isinstance(source, ZipPath):
        # a single file inside a zip: remember the member, extract from its archive
        extract_member = source.at
        source = source.root

    if isinstance(source, ZipFile):
        # do not close a caller-provided ZipFile
        zip_context = nullcontext(source)
        if folder is None:
            if source.filename is None:
                folder = Path(mkdtemp())
            else:
                zip_path = Path(source.filename)
                folder = zip_path.with_suffix(zip_path.suffix + ".unzip")
    else:
        zip_context = ZipFile(source, "r")
        if folder is None:
            folder = source.with_suffix(source.suffix + ".unzip")

    if overwrite and folder.exists():
        warnings.warn(f"Overwriting existing unzipped archive at {folder}")

    with zip_context as f:
        if extract_member is not None:
            # single-member extraction
            extracted_file_path = folder / extract_member
            if extracted_file_path.exists() and not overwrite:
                warnings.warn(f"Found unzipped {extracted_file_path}.")
            else:
                _ = f.extract(extract_member, folder)

            return folder

        elif overwrite or not folder.exists():
            f.extractall(folder)
            return folder

        # folder exists and we may not overwrite: verify its content
        found_content = {p.relative_to(folder).as_posix() for p in folder.glob("*")}
        expected_content = {info.filename for info in f.filelist}
        if expected_missing := expected_content - found_content:
            # extract to a sibling folder with an incremented counter instead
            parts = folder.name.split("_")
            nr, *suffixes = parts[-1].split(".")
            if nr.isdecimal():
                nr = str(int(nr) + 1)
            else:
                nr = f"1.{nr}"

            parts[-1] = ".".join([nr, *suffixes])
            out_path_new = folder.with_name("_".join(parts))
            warnings.warn(
                f"Unzipped archive at {folder} is missing expected files"
                + f" {expected_missing}."
                + f" Unzipping to {out_path_new} instead to avoid overwriting."
            )
            return extract(f, out_path_new, overwrite=overwrite)
        else:
            warnings.warn(
                f"Found unzipped archive with all expected files at {folder}."
            )
            return folder
def get_reader(
    source: Union[PermissiveFileSource, FileDescr, ZipPath],
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None] = None,
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Open a file `source` (download if needed)

    Args:
        source: file source to open (URL, path, zip member, or file description).
        progressbar: progressbar (factory) or flag forwarded to the downloader.
        **kwargs: optional expected `sha256` for verification.

    Raises:
        FileNotFoundError: if a local/zip source does not exist or is a directory.
        ValueError: if a given `sha256` does not match the file content.
    """
    if isinstance(source, FileDescr):
        # an explicitly passed sha256 takes precedence over the description's
        if "sha256" not in kwargs:
            kwargs["sha256"] = source.sha256

        source = source.source
    elif isinstance(source, str):
        source = interprete_file_source(source)

    if isinstance(source, RelativeFilePath):
        source = source.absolute()
    elif isinstance(source, pydantic.AnyUrl):
        with get_validation_context().replace(perform_io_checks=False):
            source = HttpUrl(source)

    if isinstance(source, HttpUrl):
        return _open_url(source, progressbar=progressbar, **kwargs)

    if isinstance(source, ZipPath):
        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open(mode="rb")
        assert not isinstance(f, TextIOWrapper)
        root = source.root
    elif isinstance(source, Path):
        if source.is_dir():
            raise FileNotFoundError(f"{source} is a directory, not a file")

        if not source.exists():
            raise FileNotFoundError(source)

        f = source.open("rb")
        root = source.parent
    else:
        assert_never(source)

    # verify the hash eagerly for local/zip sources (URLs are handled in _open_url)
    expected_sha = kwargs.get("sha256")
    if expected_sha is None:
        sha = None
    else:
        sha = get_sha256(f)
        _ = f.seek(0)  # rewind after hashing so the reader starts at the beginning
        if sha != expected_sha:
            raise ValueError(
                f"SHA256 mismatch for {source}. Expected {expected_sha}, got {sha}."
            )

    return BytesReader(
        f,
        sha256=sha,
        suffix=source.suffix,
        original_file_name=source.name,
        original_root=root,
        is_zipfile=None,
    )
# public alias for `get_reader`
download = get_reader
def _open_url(
    source: HttpUrl,
    /,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None],
    **kwargs: Unpack[HashKwargs],
) -> BytesReader:
    """Fetch *source* through the (disk) cache and wrap it in a `BytesReader`."""
    cache = (
        NoopCache[RootHttpUrl](url_hasher=UrlDigest.from_str)
        if get_validation_context().disable_cache
        else settings.disk_cache
    )
    sha = kwargs.get("sha256")
    # NOTE(review): `digest` is passed as `force_refetch` below; presumably
    # genericache re-fetches when the cached content digest differs -- confirm.
    digest = False if sha is None else ContentDigest.parse(hexdigest=sha)
    # derive a file-name-like path: URL path, else the sha, else a hash of the URL
    source_path = PurePosixPath(
        source.path
        or sha
        or hashlib.sha256(str(source).encode(encoding="utf-8")).hexdigest()
    )

    reader = cache.fetch(
        source,
        fetcher=partial(_fetch_url, progressbar=progressbar),
        force_refetch=digest,
    )
    return BytesReader(
        reader,
        suffix=source_path.suffix,
        sha256=sha,
        original_file_name=source_path.name,
        original_root=source.parent,
        is_zipfile=None,
    )
def _fetch_url(
    source: RootHttpUrl,
    *,
    progressbar: Union[Progressbar, Callable[[], Progressbar], bool, None],
):
    """Stream *source* via HTTP(S), returning an iterator of byte chunks.

    Used as the fetcher callback for the cache in `_open_url`.

    Raises:
        NotImplementedError: for non-http(s) URL schemes.
    """
    if source.scheme not in ("http", "https"):
        raise NotImplementedError(source.scheme)

    if progressbar is None:
        # choose progressbar option from validation context
        progressbar = get_validation_context().progressbar

    if progressbar is None:
        # default to no progressbar in CI environments
        progressbar = not settings.CI

    if callable(progressbar):
        # a progressbar factory was given
        progressbar = progressbar()

    if isinstance(progressbar, bool) and progressbar:
        progressbar = tqdm(
            ncols=79,
            ascii=bool(sys.platform == "win32"),
            unit="B",
            unit_scale=True,
            leave=True,
        )

    if progressbar is not False:
        progressbar.set_description(f"Downloading {extract_file_name(source)}")

    headers: Dict[str, str] = {}
    if settings.user_agent is not None:
        headers["User-Agent"] = settings.user_agent
    elif settings.CI:
        headers["User-Agent"] = "ci"

    r = httpx.get(str(source), follow_redirects=True, headers=headers)
    _ = r.raise_for_status()

    # set progressbar.total from the content-length header (str or absent)
    total = r.headers.get("content-length")
    if total is not None and not isinstance(total, int):
        try:
            total = int(total)
        except Exception:
            total = None

    if progressbar is not False:
        if total is None:
            progressbar.total = 0
        else:
            progressbar.total = total

    def iter_content():
        for chunk in r.iter_bytes(chunk_size=4096):
            yield chunk
            if progressbar is not False:
                _ = progressbar.update(len(chunk))

        # Make sure the progress bar gets filled even if the actual number
        # of chunks is smaller than expected. This happens when streaming
        # text files that are compressed by the server when sending (gzip).
        # Binary files don't experience this.
        # (adapted from pooch.HttpDownloader)
        if progressbar is not False:
            progressbar.reset()
            if total is not None:
                _ = progressbar.update(total)

            progressbar.close()

    return iter_content()
def extract_file_name(
    src: Union[
        pydantic.HttpUrl, RootHttpUrl, PurePath, RelativeFilePath, ZipPath, FileDescr
    ],
) -> FileName:
    """Return the file name of *src* (for URLs: the last path segment)."""
    if isinstance(src, FileDescr):
        src = src.source

    if isinstance(src, ZipPath):
        return src.name or src.root.filename or "bioimageio.zip"

    if isinstance(src, RelativeFilePath):
        return src.path.name

    if isinstance(src, PurePath):
        return src.name

    url = urlparse(str(src))
    is_zenodo_record_content = (
        url.scheme == "https"
        and url.hostname == "zenodo.org"
        and url.path.startswith("/api/records/")
        and url.path.endswith("/content")
    )
    segments = url.path.split("/")
    # zenodo API record URLs end in "/content"; the file name is one segment up
    return segments[-2] if is_zenodo_record_content else segments[-1]
def extract_file_descrs(data: YamlValueView):
    """Collect all `FileDescr` nodes (mappings with 'source' and 'sha256') in *data*."""
    result: List[FileDescr] = []
    quiet_context = get_validation_context().replace(
        perform_io_checks=False, log_warnings=False
    )
    with quiet_context:
        _extract_file_descrs_impl(data, result)

    return result
def _extract_file_descrs_impl(data: YamlValueView, collected: List[FileDescr]):
    """Recursive helper for `extract_file_descrs`; appends matches to *collected*."""
    if isinstance(data, str):
        return  # strings are sequences, but YAML leaves

    if isinstance(data, collections.abc.Mapping):
        if "source" in data and "sha256" in data:
            candidate = dict(source=data["source"], sha256=data["sha256"])
            try:
                collected.append(FileDescr.model_validate(candidate))
            except Exception:
                pass  # not a valid FileDescr -> ignore

        for child in data.values():
            _extract_file_descrs_impl(child, collected)
    elif isinstance(data, collections.abc.Sequence):
        for child in data:
            _extract_file_descrs_impl(child, collected)
# Type variable for file-source-like arguments passed through validators unchanged.
F = TypeVar("F", bound=Union[FileSource, FileDescr])
def validate_suffix(
    value: F, suffix: Union[str, Sequence[str]], case_sensitive: bool
) -> F:
    """check final suffix

    Returns *value* unchanged if its final suffix is among *suffix*.

    Raises:
        ValueError: if the suffix does not match.
    """
    if isinstance(suffix, str):
        suffixes = [suffix]
    else:
        suffixes = suffix

    assert len(suffixes) > 0, "no suffix given"
    assert all(suff.startswith(".") for suff in suffixes), (
        "expected suffixes to start with '.'"
    )
    o_value = value
    if isinstance(value, FileDescr):
        strict = value.source
    else:
        strict = interprete_file_source(value)

    if isinstance(strict, (HttpUrl, AnyUrl)):
        if strict.path is None or "." not in (path := strict.path):
            actual_suffixes = []
        else:
            if (
                strict.host == "zenodo.org"
                and path.startswith("/api/records/")
                and path.endswith("/content")
            ):
                # Zenodo API URLs have a "/content" suffix that should be ignored
                path = path[: -len("/content")]

            # for URLs only the substring after the last '.' counts as suffix
            actual_suffixes = [f".{path.split('.')[-1]}"]

    elif isinstance(strict, PurePath):
        actual_suffixes = strict.suffixes
    elif isinstance(strict, RelativeFilePath):
        actual_suffixes = strict.path.suffixes
    else:
        assert_never(strict)

    if actual_suffixes:
        actual_suffix = actual_suffixes[-1]
    else:
        actual_suffix = "no suffix"

    if (
        case_sensitive
        and actual_suffix not in suffixes
        or not case_sensitive
        and actual_suffix.lower() not in [s.lower() for s in suffixes]
    ):
        if len(suffixes) == 1:
            raise ValueError(f"Expected suffix {suffixes[0]}, but got {actual_suffix}")
        else:
            raise ValueError(
                f"Expected a suffix from {suffixes}, but got {actual_suffix}"
            )

    return o_value
def populate_cache(sources: Sequence[Union[FileDescr, LightHttpFileDescr]]):
    """Download every remote source that has a known SHA256, once per unique URL."""
    seen_urls: Set[str] = set()
    for descr in sources:
        if descr.sha256 is None:
            continue  # not caching without known SHA

        source = descr.source
        if isinstance(source, (HttpUrl, pydantic.AnyUrl)):
            url = str(source)
        elif isinstance(source, RelativeFilePath):
            absolute = source.absolute()
            if not isinstance(absolute, HttpUrl):
                continue  # not caching local paths

            url = str(absolute)
        elif isinstance(source, Path):
            continue  # not caching local paths
        else:
            assert_never(source)

        if url not in seen_urls:
            seen_urls.add(url)
            _ = descr.download()