Coverage for src / bioimageio / spec / _hf.py: 98%
40 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-23 10:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-23 10:51 +0000
1import os
2import tempfile
3import warnings
4from contextlib import nullcontext
5from functools import cache
6from pathlib import Path
7from typing import Optional, Union
9from loguru import logger
11from bioimageio.spec import save_bioimageio_package_as_folder
12from bioimageio.spec._internal.validation_context import get_validation_context
13from bioimageio.spec.model.v0_5 import ModelDescr
15from ._hf_card import create_huggingface_model_card
16from ._version import VERSION
@cache
def get_huggingface_api():  # pragma: no cover
    """Return the shared Hugging Face Hub API client.

    The `huggingface_hub` import is deferred until the first call, and the
    resulting client is memoized by `functools.cache`, so repeated calls
    hand back the same `HfApi` instance.
    """
    import huggingface_hub

    api = huggingface_hub.HfApi(
        library_name="bioimageio.spec",
        library_version=VERSION,
    )
    return api
def push_to_hub(
    descr: ModelDescr,
    username_or_org: str,
    *,
    prep_dir: Optional[Union[os.PathLike[str], str]] = None,
    prep_only_no_upload: bool = False,
    create_pr: Optional[bool] = None,
):
    """Push the model package described by `descr` to the Hugging Face Hub.

    Note:
        - Uses `descr.id` as the repository name under the provided `username_or_org`.
        - If `descr.version` is set, the model package is uploaded to the 'main' branch
          and tagged with the version.
        - If `descr.version` is `None`, the model package is uploaded to the 'draft' branch.

    Args:
        descr: The model description to be pushed to the Hugging Face Hub.
        username_or_org: The Hugging Face username or organization under which the model package will be uploaded.
            The model ID from `descr.id` will be used as the repository name.
        prep_dir: Optional path to an empty directory where the model package will be prepared before uploading.
            If omitted, a temporary directory is used and removed again when this function returns
            (so combine `prep_only_no_upload=True` with an explicit `prep_dir` to inspect the result).
        prep_only_no_upload: If `True`, only prepare the model package in `prep_dir` without uploading it
            to the Hugging Face Hub.
        create_pr: If `False` commit directly to the 'main'/'draft' branch.
            If `True`, create a pull request targeting 'main'/'draft'.
            Defaults to `True` if uploading to a model description with version (to the main branch),
            and `False` if uploading a model description without version (to the 'draft' branch).

    Raises:
        ValueError: If `descr.id` is `None`, or if `prep_dir` exists and is not empty.

    Examples:
        Upload a model description as a new version directly to the main branch
        (`descr.id` and `descr.version` must be set):

        >>> descr = ModelDescr(id="my-model-id", version="1.0", ...) # doctest: +SKIP
        >>> push_to_hub(descr, "my_hf_username", create_pr=False) # doctest: +SKIP

        Upload a model description as a draft to the 'draft' branch
        (`descr.id` must be set; `descr.version` must be `None`):

        >>> descr = ModelDescr(id="my-model-id", version=None, ...) # doctest: +SKIP
        >>> push_to_hub(descr, "my_hf_username") # doctest: +SKIP

        See what would be uploaded without actually uploading:

        >>> push_to_hub(..., prep_dir="empty_local_folder", prep_only_no_upload=True) # doctest: +SKIP

    """
    if descr.id is None:
        raise ValueError("descr.id must be set to push to Hugging Face Hub.")

    repo_id = f"{username_or_org}/{descr.id}"

    if prep_dir is None:
        # No folder given: prepare in a throwaway directory that is cleaned up
        # when the `with` block below exits.
        ctxt = tempfile.TemporaryDirectory(suffix="_" + repo_id.replace("/", "_"))
    else:
        prep_path = Path(prep_dir)
        if prep_path.exists() and any(prep_path.iterdir()):
            # TODO: implement resuming an interrupted upload from a non-empty
            # `prep_dir` (and reject that combination with `prep_only_no_upload`).
            raise ValueError("Provided `prep_dir` is not empty.")

        # A user-provided folder is left in place after the call.
        ctxt = nullcontext(prep_path)

    with ctxt as pdir:
        _push_to_hub_impl(
            descr,
            repo_id=repo_id,
            prep_dir=Path(pdir),
            prep_only=prep_only_no_upload,
            create_pr=create_pr,
        )
def _push_to_hub_impl(
    descr: ModelDescr,
    *,
    repo_id: str,
    prep_dir: Path,
    prep_only: bool,
    create_pr: Optional[bool],
):
    """Prepare the upload folder for `descr` and, unless `prep_only`, push it.

    Preparation writes into `prep_dir`:
      - `README.md` with the generated Hugging Face model card,
      - the files referenced by the model card (expected below `images/`),
      - the bioimageio package below `package/`.

    Upload (skipped if `prep_only`) creates the repository if needed, commits
    the folder to the 'draft' branch (no `descr.version`) or with
    `revision=None` (versioned), and tags versioned uploads with
    `str(descr.version)`.

    Args:
        descr: The model description to package and upload.
        repo_id: Target repository as `<username_or_org>/<model id>`.
        prep_dir: Folder to prepare the upload in; created if missing.
        prep_only: If `True`, stop after preparing the folder.
        create_pr: Whether to open a pull request instead of committing
            directly. `None` selects the default: a pull request for
            versioned uploads, a direct commit for drafts.
    """
    readme, referenced_files = create_huggingface_model_card(descr, repo_id=repo_id)
    referenced_files_subfolders = {"images"}
    # Internal invariant: the remote cleanup below (`delete_patterns`) only
    # covers these subfolders, so the model card must not reference others.
    assert not (
        unexpected := [
            rf
            for rf in referenced_files
            if not any(rf.startswith(f"{sf}/") for sf in referenced_files_subfolders)
        ]
    ), f"unexpected folder of referenced files: {unexpected}"

    logger.info(f"Preparing model for upload at {prep_dir}.")
    prep_dir.mkdir(parents=True, exist_ok=True)
    _ = (prep_dir / "README.md").write_text(readme, encoding="utf-8")
    for img_name, img_data in referenced_files.items():
        image_path = prep_dir / img_name
        image_path.parent.mkdir(parents=True, exist_ok=True)
        _ = image_path.write_bytes(img_data)

    with get_validation_context().replace(file_name="bioimageio.yaml"):
        _ = save_bioimageio_package_as_folder(descr, output_path=prep_dir / "package")

    logger.info(f"Prepared model for upload at {prep_dir}")

    commit_message = f"Upload {descr.version or 'draft'} with bioimageio.spec {VERSION}"
    commit_description = (
        f"Version comment: {descr.version_comment}" if descr.version_comment else None
    )

    if not prep_only:  # pragma: no cover
        logger.info(f"Pushing model '{descr.id}' to Hugging Face Hub")

        api = get_huggingface_api()
        repo_url = api.create_repo(repo_id=repo_id, exist_ok=True, repo_type="model")
        logger.info(f"Created repository at {repo_url}")

        existing_refs = api.list_repo_refs(
            repo_id=repo_id, repo_type="model", include_pull_requests=True
        )
        # Branches and tags are checked separately: a *tag* named 'draft' must
        # not suppress creating the 'draft' *branch* we commit to, and a
        # *branch* named like the version is not an existing version tag.
        has_draft_branch = any(ref.name == "draft" for ref in existing_refs.branches)
        has_tag = descr.version is not None and any(
            ref.name == str(descr.version) for ref in existing_refs.tags
        )

        if descr.version is None:
            revision = "draft"
            if not has_draft_branch:
                api.create_branch(repo_id=repo_id, branch="draft", repo_type="model")
        else:
            revision = None

        if create_pr is None:
            # default to creating a PR if committing to main branch,
            # commit directly to 'draft' branch
            create_pr = revision is None

        commit_info = api.upload_folder(
            repo_id=repo_id,
            revision=revision,
            folder_path=prep_dir,
            # Clear previously uploaded images/package files so that files no
            # longer part of this upload do not linger in the repository.
            delete_patterns=[f"{sf}/*" for sf in referenced_files_subfolders]
            + ["package/*"],
            commit_message=commit_message,
            commit_description=commit_description,
            create_pr=create_pr,
        )
        logger.info(f"Created commit {commit_info.commit_url}")
        if descr.version is not None:
            if has_tag:
                warnings.warn(f"Moving existing version tag {descr.version}.")

            # NOTE(review): `exist_ok=True` presumably only suppresses the
            # "tag already exists" error and leaves the existing tag where it
            # is rather than moving it to the new commit -- confirm against the
            # huggingface_hub documentation and delete the tag first if a real
            # move is intended.
            api.create_tag(
                repo_id=repo_id,
                tag=str(descr.version),
                revision=commit_info.oid,
                tag_message=descr.version_comment,
                exist_ok=True,
            )