Coverage for src / bioimageio / core / cli.py: 84%
377 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-18 11:12 +0000
1"""bioimageio CLI
3Note: Some docstrings use a hair space ' '
4 to place the added '(default: ...)' on a new line.
5"""
7import json
8import shutil
9import subprocess
10import sys
11from abc import ABC
12from argparse import RawTextHelpFormatter
13from difflib import SequenceMatcher
14from functools import cached_property
15from io import StringIO
16from pathlib import Path
17from pprint import pformat, pprint
18from typing import (
19 Annotated,
20 Any,
21 Dict,
22 Iterable,
23 List,
24 Literal,
25 Mapping,
26 Optional,
27 Sequence,
28 Set,
29 Tuple,
30 Type,
31 Union,
32)
34import rich.markdown
35from loguru import logger
36from pydantic import (
37 AliasChoices,
38 BaseModel,
39 Field,
40 PlainSerializer,
41 WithJsonSchema,
42 model_validator,
43)
44from pydantic_settings import (
45 BaseSettings,
46 CliApp,
47 CliPositionalArg,
48 CliSettingsSource,
49 CliSubCommand,
50 JsonConfigSettingsSource,
51 PydanticBaseSettingsSource,
52 SettingsConfigDict,
53 YamlConfigSettingsSource,
54)
55from tqdm import tqdm
56from typing_extensions import assert_never
58import bioimageio.spec
59from bioimageio.core import __version__
60from bioimageio.spec import (
61 AnyModelDescr,
62 InvalidDescr,
63 ResourceDescr,
64 load_description,
65 save_bioimageio_yaml_only,
66 settings,
67 update_format,
68 update_hashes,
69)
70from bioimageio.spec._internal.io import is_yaml_value
71from bioimageio.spec._internal.io_utils import open_bioimageio_yaml
72from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty
73from bioimageio.spec.dataset import DatasetDescr
74from bioimageio.spec.model import ModelDescr, v0_4, v0_5
75from bioimageio.spec.notebook import NotebookDescr
76from bioimageio.spec.utils import (
77 empty_cache,
78 ensure_description_is_model,
79 get_reader,
80 write_yaml,
81)
83from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
84from .common import MemberId, SampleId, SupportedWeightsFormat
85from .digest_spec import get_member_ids, load_sample_for_model
86from .io import load_dataset_stat, save_dataset_stat, save_sample
87from .prediction import create_prediction_pipeline
88from .proc_setup import (
89 DatasetMeasure,
90 Measure,
91 MeasureValue,
92 StatsCalculator,
93 get_required_dataset_measures,
94)
95from .sample import Sample
96from .stat_measures import Stat
97from .utils import compare
98from .weight_converters._add_weights import add_weights
# Accepted spellings of the weight(s) format CLI option: dashed and
# underscored, singular and plural variants all map to the same field.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    "weight-format",
    "weights-format",
    "weight_format",
    "weights_format",
)
class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    """Base class for CLI (sub)commands."""

    pass
class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    """Base class for reusable groups of CLI arguments."""

    pass
class WithSummaryLogging(ArgMixin):
    """Mixin adding a `summary` argument and a helper to log validation summaries."""

    summary: List[Union[Literal["display"], Path]] = Field(
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        """Render and/or save `descr`'s validation summary as configured by `self.summary`."""
        _ = descr.validation_summary.log(self.summary)
class WithSource(ArgMixin):
    """Mixin adding the positional `source` argument and lazy description loading."""

    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        # loaded once on first access and cached for the lifetime of the command
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # fall back to `name` if an invalid description has no `id`
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))

        # prefer a nickname stored in the (extra) bioimageio config, if any
        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or self.descr.id or self.descr.name)
class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to perform validations that require downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        # overrides WithSource.descr to honor `perform_io_checks`
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        """Log the validation summary and exit 0 on success, 1 otherwise."""
        self.log(self.descr)
        sys.exit(
            0
            if self.descr.validation_summary.status in ("valid-format", "passed")
            else 1
        )
class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = None
    """Device(s) to use for testing"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
      weights description.
    - A path to a conda environment YAML.
    Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""
    # NOTE(review): `stop_early` is declared but never forwarded to `test()` in
    # `cli_cmd` below — confirm whether it should be passed on.

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        """Run the resource tests and exit with the resulting status code."""
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                summary=self.summary,
                runtime_env=self.runtime_env,
                determinism=self.determinism,
                format_version=self.format_version,
                working_dir=self.working_dir,
            )
        )
class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        """Package the resource; exits with `package()`'s return code.

        Raises:
            ValueError: if the source does not load as a valid description.
        """
        if isinstance(self.descr, InvalidDescr):
            # log the summary so the user sees *why* packaging is refused
            self.log(self.descr)
            raise ValueError(f"Invalid {self.descr.type} description.")

        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )
def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    """Get the dataset measures required by `model_descr`.

    Loads precomputed measures from `stats_path` if it exists, otherwise
    computes them over `dataset` and saves them to `stats_path`.

    Args:
        model_descr: model description whose required dataset measures are needed.
        dataset: iterable of samples to compute measures over (only consumed
            if `stats_path` does not exist).
        dataset_length: number of samples in `dataset` (for the progress bar).
        stats_path: where precomputed dataset statistics are read from/written to.

    Returns:
        Mapping of required dataset measures to their values
        (empty if the model requires none).

    Raises:
        ValueError: if `stats_path` exists but lacks a required measure.
    """
    # (fixed: this call was previously duplicated)
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat
class UpdateCmdBase(CmdBase, WithSource, ABC):
    """Abstract base for commands that rewrite a bioimageio.yaml description."""

    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highlighting.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly been set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        # subclasses provide the actual update logic
        raise NotImplementedError

    def cli_cmd(self):
        """Serialize the updated description, emit a diff, and write/display the result."""
        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)
        stream = StringIO()

        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format=(
                # an `.html` diff target selects the standalone HTML diff format
                "html"
                if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                else "unified"
            ),
        )

        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            # NOTE(review): `rich.console` is reached via the top-level
            # `import rich.markdown`; consider importing it explicitly.
            console = rich.console.Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(rich.markdown.Markdown(diff_md))

        if isinstance(self.output, Path):
            _ = self.output.write_text(updated_yaml, encoding="utf-8")
            logger.info(f"written updated description to {self.output}")
        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()
class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        # computed lazily and cached; the base class `cli_cmd` accesses it repeatedly
        return update_format(
            self.source,
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )
class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        # recompute file hashes for the description at `self.source`
        return update_hashes(self.source)
class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"]
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.


    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)


    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: bool = False
    """process inputs blockwise"""

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist,
    but the model requires statistical dataset measures)
    """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input


    """

    def _example(self):
        """Download example inputs, write a CLI config, and run an example prediction.

        Raises:
            ValueError: if the model description specifies no example inputs.
        """
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        # prefer sample tensors; fall back to test tensors (v0.5 only)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # download each example tensor to <example_path>/<tensor id>/001<suffix>
        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            # build the `bioimageio predict` argv; `escape` quotes args for display
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # a local CLI config file would override our arguments; move it aside
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        # NOTE(review): runtime string below contains typos ("Sucessfully",
        # "workind"); left untouched here as it is program output.
        print(
            "🎉 Sucessfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current workind directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Expand input/output path patterns and run (or preview) the prediction."""
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # required inputs only (v0.5 inputs may be optional)
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            """expand a single input pattern/row into one path per input tensor"""
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            """expand the output pattern(s) into one path per output tensor and sample"""
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            # check for distinctness and correct number within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + f"Make sure to include '{{sample_id}}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            # show the planned input/output mapping and stop before any computation
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            """lazily load one sample per set of input paths"""
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )
        predict_method = (
            pp.predict_sample_with_blocking
            if self.blockwise
            else pp.predict_sample_without_blocking
        )

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            sample_out = predict_method(sample_in)
            save_sample(sp_out, sample_out)
class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        """Convert/add weights and log the updated description's validation summary.

        Raises:
            TypeError: for v0.4 model descriptions, which are not supported.
        """
        model_descr = ensure_description_is_model(self.descr)
        if isinstance(model_descr, v0_4.ModelDescr):
            raise TypeError(
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )
        updated_model_descr = add_weights(
            model_descr,
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        self.log(updated_model_descr)
class EmptyCache(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        # delegates to bioimageio.spec.utils.empty_cache
        empty_cache()
# Config file names the `Bioimageio` settings class reads additional CLI
# arguments from (see `SettingsConfigDict(json_file=..., yaml_file=...)`).
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"
class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    "Check a resource's metadata format"

    test: CliSubCommand[TestCmd]
    "Test a bioimageio resource (beyond meta data formatting)"

    package: CliSubCommand[PackageCmd]
    "Package a resource"

    predict: CliSubCommand[PredictCmd]
    "Predict with a model resource"

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Combine settings sources: CLI args, init kwargs, then YAML/JSON config files.

        NOTE(review): per pydantic-settings convention the returned order is
        highest-priority first — confirm CLI args are meant to win over the
        config files.
        """
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        # log the raw input before validation
        # (assumes `data` is a mapping — TODO confirm for all input sources)
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def cli_cmd(self) -> None:
        """Log the resolved command and dispatch to the selected subcommand."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)
# Append library and format version information to the CLI help text.
assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""
916def _get_sample_ids(
917 input_paths: Sequence[Mapping[MemberId, Path]],
918) -> Sequence[SampleId]:
919 """Get sample ids for given input paths, based on the common path per sample.
921 Falls back to sample01, samle02, etc..."""
923 matcher = SequenceMatcher()
925 def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
926 """extract a common sequence from multiple sequences
927 (order sensitive; strips whitespace and slashes)
928 """
929 common = seqs[0]
931 for seq in seqs[1:]:
932 if not seq:
933 continue
934 matcher.set_seqs(common, seq)
935 i, _, size = matcher.find_longest_match()
936 common = common[i : i + size]
938 if isinstance(common, str):
939 common = common.strip().strip("/")
940 else:
941 common = [cs for c in common if (cs := c.strip().strip("/"))]
943 if not common:
944 raise ValueError(f"failed to find common sequence for {seqs}")
946 return common
948 def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
949 """get a shorter sequence whose entries are still unique
950 (order sensitive, not minimal sequence)
951 """
952 min_seq_len = min(len(s) for s in seqs)
953 # cut from the start
954 for start in range(min_seq_len - 1, -1, -1):
955 shortened = [s[start:] for s in seqs]
956 if len(set(shortened)) == len(seqs):
957 min_seq_len -= start
958 break
959 else:
960 seen: Set[Sequence[str]] = set()
961 dupes = [s for s in seqs if s in seen or seen.add(s)]
962 raise ValueError(f"Found duplicate entries {dupes}")
964 # cut from the end
965 for end in range(min_seq_len - 1, 1, -1):
966 shortened = [s[:end] for s in shortened]
967 if len(set(shortened)) == len(seqs):
968 break
970 return shortened
972 full_tensor_ids = [
973 sorted(
974 p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
975 )
976 for input_sample_paths in input_paths
977 ]
978 try:
979 long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
980 sample_ids = get_shorter_diff(long_sample_ids)
981 except ValueError as e:
982 raise ValueError(f"failed to extract sample ids: {e}")
984 return sample_ids