Coverage for src / bioimageio / core / cli.py: 81%
418 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1"""bioimageio CLI
3Note: Some docstrings use a hair space ' '
4 to place the added '(default: ...)' on a new line.
5"""
7import json
8import shutil
9import subprocess
10import sys
11from abc import ABC
12from argparse import RawTextHelpFormatter
13from difflib import SequenceMatcher
14from functools import cached_property, partial
15from io import StringIO
16from pathlib import Path
17from pprint import pformat, pprint
18from typing import (
19 Annotated,
20 Any,
21 Dict,
22 Iterable,
23 List,
24 Literal,
25 Mapping,
26 Optional,
27 Sequence,
28 Set,
29 Tuple,
30 Type,
31 Union,
32)
34import numpy as np
35import rich.markdown
36from loguru import logger
37from pydantic import (
38 AliasChoices,
39 BaseModel,
40 Field,
41 PlainSerializer,
42 WithJsonSchema,
43 model_validator,
44)
45from pydantic_settings import (
46 BaseSettings,
47 CliApp,
48 CliPositionalArg,
49 CliSettingsSource,
50 CliSubCommand,
51 JsonConfigSettingsSource,
52 PydanticBaseSettingsSource,
53 SettingsConfigDict,
54 YamlConfigSettingsSource,
55)
56from tqdm import tqdm
57from typing_extensions import assert_never
59import bioimageio.spec
60from bioimageio.core import __version__
61from bioimageio.spec import (
62 AnyModelDescr,
63 InvalidDescr,
64 ResourceDescr,
65 load_description,
66 save_bioimageio_package,
67 save_bioimageio_package_as_folder,
68 save_bioimageio_yaml_only,
69 settings,
70 update_format,
71 update_hashes,
72)
73from bioimageio.spec._internal.io import is_yaml_value
74from bioimageio.spec._internal.io_utils import open_bioimageio_yaml
75from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty
76from bioimageio.spec.dataset import DatasetDescr
77from bioimageio.spec.model import ModelDescr, v0_4, v0_5
78from bioimageio.spec.notebook import NotebookDescr
79from bioimageio.spec.utils import (
80 empty_cache,
81 ensure_description_is_model,
82 get_reader,
83 write_yaml,
84)
86from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
87from .common import MemberId, SampleId, SupportedWeightsFormat
88from .digest_spec import get_member_ids, load_sample_for_model
89from .io import load_stat, save_sample, save_stat
90from .prediction import create_prediction_pipeline
91from .proc_setup import (
92 Measure,
93 MeasureValue,
94 StatsCalculator,
95 get_required_dataset_measures,
96)
97from .sample import Sample
98from .stat_measures import Stat
99from .utils import compare
100from .weight_converters._add_weights import add_weights
# Accepted spellings of the weight format option:
# singular/plural stem combined with a dash or an underscore.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    *(
        f"{stem}{sep}format"
        for sep in ("-", "_")
        for stem in ("weight", "weights")
    )
)
class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Shared base of the CLI (sub)command models in this module.
    # It only carries the pydantic class configuration and declares no fields.
    ...
class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Shared base of the argument mixins (`WithSummaryLogging`, `WithSource`).
    # It only carries the pydantic class configuration and declares no fields.
    ...
class WithSummaryLogging(ArgMixin):
    # Mixin adding the `summary` argument and a helper to log a validation summary.

    summary: List[Union[Literal["display"], Path]] = Field(
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
        default_factory=lambda: ["display"],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        """Log the validation summary of `descr` to all configured `summary` targets."""
        validation_summary = descr.validation_summary
        _ = validation_summary.log(self.summary)
class WithSource(ArgMixin):
    # Mixin adding the positional `source` argument and lazy access to the
    # resource description loaded from it.

    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        """The description loaded from `source` (loaded once, then cached)."""
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        descr = self.descr
        if isinstance(descr, InvalidDescr):
            # An invalid description may lack `id` and/or `name`; look each one up
            # with a default so a missing `name` cannot mask an existing `id`,
            # and fall back to the given source if neither is usable.
            invalid_id = getattr(descr, "id", None) or getattr(descr, "name", None)
            return str(invalid_id or self.source)

        nickname = None
        if (
            isinstance(descr.config, v0_5.Config)
            and (bio_config := descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or descr.id or descr.name)
class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        default=settings.perform_io_checks,
        alias="perform-io-checks",
    )
    """Wether or not to perform validations that requires downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        """The description loaded from `source`, honoring `perform_io_checks`."""
        io_checks = self.perform_io_checks
        return load_description(self.source, perform_io_checks=io_checks)

    def cli_cmd(self):
        """Log the validation summary and exit with 0 on success, 1 otherwise."""
        descr = self.descr
        self.log(descr)
        if descr.validation_summary.status in ("valid-format", "passed"):
            exit_code = 0
        else:
            exit_code = 1

        sys.exit(exit_code)
class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        default="all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = Field(
        default=None,
        validation_alias=AliasChoices("devices", "device"),
    )
    """Device(s) to use"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        default="currently-active",
        alias="runtime-env",
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
      weights description.
    - A path to a conda environment YAML.
      Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(default=None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    # NOTE(review): `stop_early` is declared here but not forwarded to `test`
    # in `cli_cmd` below -- confirm whether that is intended.
    stop_early: bool = Field(
        default=False,
        alias="stop-early",
        validation_alias=AliasChoices("stop-early", "x"),
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        default="discover",
        alias="format-version",
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        """Run `test` on the loaded description and exit with its return value."""
        exit_code = test(
            self.descr,
            weight_format=self.weight_format,
            devices=self.devices,
            summary=self.summary,
            runtime_env=self.runtime_env,
            determinism=self.determinism,
            format_version=self.format_version,
            working_dir=self.working_dir,
        )
        sys.exit(exit_code)
class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        default="all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        """Package the loaded description and exit with the value `package` returns.

        Raises:
            ValueError: If the loaded description is invalid
                (its validation summary is logged first).
        """
        descr = self.descr
        if isinstance(descr, InvalidDescr):
            self.log(descr)
            raise ValueError(f"Invalid {descr.type} description.")

        exit_code = package(descr, self.path, weight_format=self.weight_format)
        sys.exit(exit_code)
def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Stat:
    """Load or compute the dataset measures required by `model_descr`.

    Args:
        model_descr: Model description whose required dataset measures are determined.
        dataset: Samples to compute the measures from (only iterated if
            `stats_path` does not exist).
        dataset_length: Number of samples in `dataset` (used for the progress bar).
        stats_path: File to load precomputed measures from, or to save newly
            computed measures to.

    Returns:
        An empty mapping if no dataset measures are required, otherwise the
        loaded or newly computed measures.

    Raises:
        ValueError: If `stats_path` exists but lacks a required measure.
    """
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat: Dict[Measure, MeasureValue] = dict(stats_calc.finalize())
    save_stat(stat, stats_path)
    return stat
class UpdateCmdBase(CmdBase, WithSource, ABC):
    # Base for commands that produce an updated description of `source`;
    # subclasses provide the updated description via `updated`.

    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highligthing.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly be set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        """The updated description; must be overridden by subclasses."""
        raise NotImplementedError

    def cli_cmd(self):
        """Serialize the updated description, then emit the diff and the output.

        Raises:
            ValueError: If `output` is a non-YAML path and the updated
                description is invalid.
        """
        # Import explicitly instead of relying on `rich.console` being
        # reachable only as a side effect of `import rich.markdown`.
        from rich.console import Console
        from rich.markdown import Markdown

        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)
        stream = StringIO()

        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format=(
                "html"
                if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                else "unified"
            ),
        )

        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            console = Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(Markdown(diff_md))

        if isinstance(self.output, Path):
            if self.output.suffix in (".yaml", ".yml"):
                # metadata only
                _ = self.output.write_text(updated_yaml, encoding="utf-8")
                logger.info(f"written updated description to {self.output}")
            elif isinstance(self.updated, InvalidDescr):
                raise ValueError(
                    f"Cannot save invalid description package to {self.output}."
                    + " To save the metadata only, choose an output with a '.yaml' extension."
                )
            elif not self.output.suffix:
                # no suffix: save as a folder
                _ = save_bioimageio_package_as_folder(
                    self.updated, output_path=self.output
                )
            else:
                _ = save_bioimageio_package(self.updated, output_path=self.output)
        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            Console().print(Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()
class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(default=True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        default=settings.perform_io_checks,
        alias="perform-io-checks",
    )
    """Wether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        """The description of `source` as returned by `update_format`."""
        update_kwargs = dict(
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )
        return update_format(self.source, **update_kwargs)
class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        """The description of `source` as returned by `update_hashes`."""
        source = self.source
        return update_hashes(source)
class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"],
        validation_alias=AliasChoices("inputs", "input"),
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
     - (n_samples,) or (n_samples,1) for models expecting a single input tensor
     - (n_samples,) containing the substring '{input_id}', or
     - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
     Aavailable formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.


    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = Field(
        "outputs_{model_id}/{output_id}/{sample_id}.tif",
        validation_alias=AliasChoices("outputs", "output"),
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: Union[bool, int] = False
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blockize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.
    - If `False`, inputs are processed as a whole without blocking.

    """

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("precomputed_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
    """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input


    """

    def _example(self):
        """Set up an example folder and run a preview and a prediction on it.

        Copies the model's example inputs into `{descr_id}_example`, writes a
        CLI config file next to them and runs `bioimageio predict` twice in a
        subprocess (preview, then actual prediction).

        Raises:
            ValueError: If the model description specifies no example inputs.
            subprocess.CalledProcessError: If one of the subprocess calls fails.
        """
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "precomputed_statistics.json"
        stats = (example_path / stats_file).as_posix()
        # NOTE(review): the config stores `stats` relative to the example folder
        # while `inputs`/`outputs` include the example folder itself -- confirm
        # that running from within the example folder resolves them as intended.
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            # `escape=True` yields a variant quoted for copy-pasting into a shell
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                f"--blockwise={self.blockwise}",
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # A config file in the current working directory would be picked up by
        # the subprocess calls below; move it out of the way and restore it after.
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Sucessfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current workind directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Predict all samples and save each output sample to its output path.

        Raises:
            RuntimeError: If prediction fails while not processing blockwise
                (chained to the original error, hinting at `--blockwise`).
        """
        try:
            for out_sample, out_path in self._yield_predictions(self.blockwise):
                save_sample(out_path, out_sample)
        except Exception as e:
            if not self.blockwise:
                raise RuntimeError(
                    f"Prediction failed ({e}).\nConsider using blockwise processing, "
                    + "e.g. with `--blockwise=10` to process inputs in blocks."
                ) from e
            # bare raise keeps the original traceback untouched
            raise

    def _yield_predictions(self, blockwise: Union[bool, int]):
        """Yield `(output sample, output paths)` for each input sample.

        Args:
            blockwise: Falsy to predict whole samples; `True` or an integer to
                predict with blocking (an integer is passed on as `ns`).

        Raises:
            ValueError: For invalid or ambiguous input/output path specifications.
            FileExistsError: If an output path exists and `overwrite` is not set.

        Note:
            With `example` set, only `_example` is run; with `preview` set, the
            planned inputs/outputs are printed. Nothing is yielded in either case.
        """
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                # Check for surplus paths before zipping: `zip` would silently
                # drop them and a later length check could never trigger.
                if len(ipt) > len(maximum_input_ids):
                    raise ValueError(
                        f"[input sample #{i}] Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {tuple(ipt)}"
                    )

                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                # Surplus patterns would be silently dropped by `zip` below.
                if len(self.outputs) > len(output_ids):
                    raise ValueError(
                        f"Expected {len(output_ids)} outputs {output_ids}, got {self.outputs}"
                    )

                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]
            # check for distinctness and correct number within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                # plain (non-f) string: single braces print the placeholder as is
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + "Make sure to include '{sample_id}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
            devices=self.devices,
        )

        if blockwise:
            predict_method = partial(
                pp.predict_sample_with_blocking,
                ns=None if isinstance(blockwise, bool) else blockwise,
            )
        else:
            predict_method = pp.predict_sample_without_blocking

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            # Decide on the `blockwise` argument (not `self.blockwise`) so that
            # callers requesting whole-sample prediction explicitly get the
            # validation hint as well.
            if not blockwise and not isinstance(
                pp.model_description, v0_4.ModelDescr
            ):
                try:
                    _ = pp.model_description.validate_input_tensors(
                        sample_in.as_arrays()
                    )
                except Exception as e:
                    logger.warning(
                        "Input sample '{}' failed validation for whole-sample prediction: {}\n"
                        + "Consider using blockwise processing, e.g. with `--blockwise=10` to process inputs in blocks.",
                        sample_in.id,
                        e,
                    )

            yield (predict_method(sample_in), sp_out)
class PredictBlockArtifactsCmd(PredictCmd):
    """Command to inspect block artifacts by subtracting the combined, blockwise predictions from a whole sample prediction.

    Note:
        - This command intentionally uses a small blocksize (default: 1) to create block artifacts for testing purposes.
        - Typical sources of block artifacts include:
            - Described halo is smaller than the model's receptive field
            - Normalization layers inside the network cannot aggregate statistics over the whole sample.
    """

    blockwise: Union[Literal[True], int] = 1
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blockize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.

    Defaults to a small blocksize to intentionally create block artifacts for testing purposes.
    """

    def cli_cmd(self):
        """Save the difference of whole-sample and blockwise predictions per sample.

        Also logs an error for every statistic of the whole-sample output that is
        missing from, or differs in, the blockwise output.
        """
        for (out_sample, out_path), (out_sample_blockwise, _) in zip(
            self._yield_predictions(False), self._yield_predictions(self.blockwise)
        ):
            diff_sample = self._subtract_samples(out_sample, out_sample_blockwise)
            for k, v_a in out_sample.stat.items():
                v_b = out_sample_blockwise.stat.get(k)
                if v_b is None:
                    logger.error(
                        "measure '{}' not found in blockwise prediction statistics", k
                    )
                elif np.any(np.not_equal(v_a, v_b)):
                    # report only actual differences; `np.any` also covers
                    # non-scalar measure values
                    logger.error(
                        "measure '{}' has different values (whole sample!=blockwise): {}!={}",
                        k,
                        v_a,
                        v_b,
                    )

            save_sample(out_path, diff_sample)

    @staticmethod
    def _subtract_samples(a: Sample, b: Sample) -> Sample:
        """Return a sample holding `a - b` per member, with id and stat of `a`."""
        return Sample(
            members={t: a.members[t] - b.members[t] for t in a.members},
            id=a.id,
            stat=a.stat,
        )
class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(
        default=None, alias="source-format"
    )
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(
        default=None, alias="target-format"
    )
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        """Add weights to the model and log the resulting validation summary.

        Raises:
            TypeError: If the model description is a `v0_4.ModelDescr`.
        """
        model_descr = ensure_description_is_model(self.descr)
        if isinstance(model_descr, v0_4.ModelDescr):
            raise TypeError(
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )

        conversion_kwargs = dict(
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        self.log(add_weights(model_descr, **conversion_kwargs))
class EmptyCache(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        """Delegate to `empty_cache`."""
        _ = empty_cache()
# Names of the optional CLI config files
# (used as `json_file`/`yaml_file` settings sources of the CLI).
JSON_FILE, YAML_FILE = "bioimageio-cli.json", "bioimageio-cli.yaml"
class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(json_file=JSON_FILE, yaml_file=YAML_FILE)

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    """Check a resource's metadata format"""

    test: CliSubCommand[TestCmd]
    """Test a bioimageio resource (beyond meta data formatting)"""

    package: CliSubCommand[PackageCmd]
    """Package a resource"""

    predict: CliSubCommand[PredictCmd]
    """Predict with a model resource"""

    predict_block_artifacts: CliSubCommand[PredictBlockArtifactsCmd] = Field(
        alias="predict-block-artifacts"
    )
    """Save the difference between predicting blowise and whole sample to check for block artifacts."""

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources: CLI, init arguments, YAML and JSON config file.

        The environment, dotenv and file secret sources passed in are not used.
        """
        cli_source: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        logger.info("starting CLI with arguments:\n{}", pformat(sys.argv))

        yaml_source = YamlConfigSettingsSource(settings_cls)
        json_source = JsonConfigSettingsSource(settings_cls)
        return (cli_source, init_settings, yaml_source, json_source)

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        """Log the raw CLI input (entries that are None are left out)."""
        given = {key: value for key, value in data.items() if value is not None}
        logger.info("loaded CLI input:\n{}", pformat(given))
        return data

    def cli_cmd(self) -> None:
        """Log the parsed command (without None entries) and run the subcommand."""
        dumped = self.model_dump()
        selected = {key: value for key, value in dumped.items() if value is not None}
        logger.info("executing CLI command:\n{}", pformat(selected))
        _ = CliApp.run_subcommand(self)
# Extend the CLI help text (the class docstring) with version information.
# NOTE(review): the column alignment inside this text could not be recovered
# from the extracted source -- confirm against the original file.
assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
  model RDF {ModelDescr.implemented_format_version}
  dataset RDF {DatasetDescr.implemented_format_version}
  notebook RDF {NotebookDescr.implemented_format_version}

"""
1033def _get_sample_ids(
1034 input_paths: Sequence[Mapping[MemberId, Path]],
1035) -> Sequence[SampleId]:
1036 """Get sample ids for given input paths, based on the common path per sample.
1038 Falls back to sample01, samle02, etc..."""
1040 matcher = SequenceMatcher()
1042 def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
1043 """extract a common sequence from multiple sequences
1044 (order sensitive; strips whitespace and slashes)
1045 """
1046 common = seqs[0]
1048 for seq in seqs[1:]:
1049 if not seq:
1050 continue
1051 matcher.set_seqs(common, seq)
1052 i, _, size = matcher.find_longest_match()
1053 common = common[i : i + size]
1055 if isinstance(common, str):
1056 common = common.strip().strip("/")
1057 else:
1058 common = [cs for c in common if (cs := c.strip().strip("/"))]
1060 if not common:
1061 raise ValueError(f"failed to find common sequence for {seqs}")
1063 return common
1065 def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
1066 """get a shorter sequence whose entries are still unique
1067 (order sensitive, not minimal sequence)
1068 """
1069 min_seq_len = min(len(s) for s in seqs)
1070 # cut from the start
1071 for start in range(min_seq_len - 1, -1, -1):
1072 shortened = [s[start:] for s in seqs]
1073 if len(set(shortened)) == len(seqs):
1074 min_seq_len -= start
1075 break
1076 else:
1077 seen: Set[Sequence[str]] = set()
1078 dupes = [s for s in seqs if s in seen or seen.add(s)]
1079 raise ValueError(f"Found duplicate entries {dupes}")
1081 # cut from the end
1082 for end in range(min_seq_len - 1, 1, -1):
1083 shortened = [s[:end] for s in shortened]
1084 if len(set(shortened)) == len(seqs):
1085 break
1087 return shortened
1089 full_tensor_ids = [
1090 sorted(
1091 p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
1092 )
1093 for input_sample_paths in input_paths
1094 ]
1095 try:
1096 long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
1097 sample_ids = get_shorter_diff(long_sample_ids)
1098 except ValueError as e:
1099 raise ValueError(f"failed to extract sample ids: {e}")
1101 return sample_ids