Coverage for src/bioimageio/core/cli.py: 80%
445 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1"""bioimageio CLI
3Note: Some docstrings use a hair space ' '
4 to place the added '(default: ...)' on a new line.
5"""
7import json
8import shutil
9import subprocess
10import sys
11from abc import ABC
12from argparse import RawTextHelpFormatter
13from difflib import SequenceMatcher
14from functools import cached_property, partial
15from io import StringIO
16from pathlib import Path
17from pprint import pformat, pprint
18from typing import (
19 Annotated,
20 Any,
21 Dict,
22 Iterable,
23 List,
24 Literal,
25 Mapping,
26 Optional,
27 Sequence,
28 Set,
29 Tuple,
30 Type,
31 Union,
32)
34import numpy as np
35import rich.markdown
36from loguru import logger
37from pydantic import (
38 AliasChoices,
39 BaseModel,
40 Field,
41 PlainSerializer,
42 WithJsonSchema,
43 model_validator,
44)
45from pydantic_settings import (
46 BaseSettings,
47 CliApp,
48 CliPositionalArg,
49 CliSettingsSource,
50 CliSubCommand,
51 JsonConfigSettingsSource,
52 PydanticBaseSettingsSource,
53 SettingsConfigDict,
54 YamlConfigSettingsSource,
55)
56from tqdm import tqdm
57from typing_extensions import assert_never
59import bioimageio.spec
60from bioimageio.core import __version__
61from bioimageio.spec import (
62 AnyModelDescr,
63 InvalidDescr,
64 ResourceDescr,
65 load_description,
66 save_bioimageio_package,
67 save_bioimageio_package_as_folder,
68 save_bioimageio_yaml_only,
69 settings,
70 update_format,
71 update_hashes,
72)
73from bioimageio.spec._internal.io import is_yaml_value
74from bioimageio.spec._internal.io_utils import open_bioimageio_yaml
75from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty
76from bioimageio.spec.dataset import DatasetDescr
77from bioimageio.spec.model import ModelDescr, v0_4, v0_5
78from bioimageio.spec.notebook import NotebookDescr
79from bioimageio.spec.utils import (
80 empty_cache,
81 ensure_description_is_model,
82 get_reader,
83 write_yaml,
84)
86from ._prediction_pipeline import (
87 create_prediction_pipeline,
88 create_remote_prediction_pipeline,
89)
90from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
91from .common import MemberId, SampleId, SupportedWeightsFormat
92from .digest_spec import get_member_ids, load_sample_for_model
93from .io import load_stat, save_sample, save_stat
94from .proc_setup import get_required_dataset_measures
95from .remote_backends import create_remote_model_adapter
96from .sample import Sample
97from .stat_calculators import StatsCalculator
98from .stat_measures import Measure, MeasureValue, Stat
99from .utils import compare
100from .weight_converters._add_weights import add_weights
# Accepted spellings of the weight format option (hyphen/underscore,
# "weight"/"weights"); used as `validation_alias` by the commands below.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    "weight-format",
    "weights-format",
    "weight_format",
    "weights_format",
)
# Common base of all CLI (sub)command models: field docstrings are used as field
# descriptions (`use_attribute_docstrings`) and boolean fields are exposed as
# implicit flags (pydantic-settings `cli_implicit_flags`).
class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    pass
# Base of reusable argument groups mixed into commands (e.g. `WithSource`);
# configured like `CmdBase` so mixed-in fields behave the same way.
class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    pass
# Mixin adding a `summary` option that controls where validation summaries go.
class WithSummaryLogging(ArgMixin):
    summary: List[Union[Literal["display"], Path]] = Field(
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]) -> None:
        """Log `descr`'s validation summary to every target in `self.summary`."""
        # the return value of `ValidationSummary.log` is not needed here
        _ = descr.validation_summary.log(self.summary)
class WithSource(ArgMixin):
    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        """The description loaded from `source` (loaded once, then cached)."""
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # An invalid description may lack `id` and/or `name` entirely.
            # Look them up lazily (an eager `getattr(..., "name")` default would
            # raise even if `id` exists) and fall back to the given source.
            invalid_id = getattr(self.descr, "id", None)
            if invalid_id is None:
                invalid_id = getattr(self.descr, "name", None)

            return str(self.source if invalid_id is None else invalid_id)

        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or self.descr.id or self.descr.name)
class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to perform validations that require downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        # overrides `WithSource.descr` to honor `perform_io_checks`
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        # log the summary, then exit with 0 for a valid format and 1 otherwise
        self.log(self.descr)
        sys.exit(
            0
            if self.descr.validation_summary.status in ("valid-format", "passed")
            else 1
        )
class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
        weights description.
    - A path to a conda environment YAML.
      Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        # Exit with the status returned by `test`.
        # `stop_early` is forwarded as well; previously the flag was parsed but ignored.
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                summary=self.summary,
                runtime_env=self.runtime_env,
                determinism=self.determinism,
                format_version=self.format_version,
                working_dir=self.working_dir,
                stop_early=self.stop_early,
            )
        )
class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        # An invalid description cannot be packaged: show why, then abort.
        resource = self.descr
        if isinstance(resource, InvalidDescr):
            self.log(resource)
            raise ValueError(f"Invalid {resource.type} description.")

        # exit with the status code reported by `package`
        exit_code = package(resource, self.path, weight_format=self.weight_format)
        sys.exit(exit_code)
def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Stat:
    """Get the dataset statistics required by `model_descr`.

    Args:
        model_descr: Model whose required dataset measures are determined.
        dataset: Samples to compute the measures over (only consumed if needed).
        dataset_length: Number of samples in `dataset` (for the progress bar).
        stats_path: File to load precomputed measures from, or to save
            newly computed measures to.

    Returns:
        An empty dict if the model requires no dataset measures; the loaded
        statistics if `stats_path` exists; otherwise the freshly computed
        statistics (which are also saved to `stats_path`).

    Raises:
        ValueError: If `stats_path` exists but lacks a required measure.
    """
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        loaded_stat = load_stat(stats_path)
        for m in req_dataset_meas:
            if m not in loaded_stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return loaded_stat

    stats_calc = StatsCalculator(req_dataset_meas)
    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat: Dict[Measure, MeasureValue] = dict(stats_calc.finalize().items())
    save_stat(stat, stats_path)
    return stat
class UpdateCmdBase(CmdBase, WithSource, ABC):
    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highlighting.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly been set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        """The updated description (to be provided by subclasses)."""
        raise NotImplementedError

    def cli_cmd(self):
        """Serialize `self.updated`, optionally diff it against the source and
        write/display the result according to `self.output`."""
        # explicit import; the module level only imports `rich.markdown`
        from rich.console import Console

        stream = StringIO()
        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        # only read the original yaml and compute a diff if one is requested
        # (`Path` instances are always truthy)
        if self.diff:
            original_yaml = open_bioimageio_yaml(self.source).unparsed_content
            assert isinstance(original_yaml, str)
            diff = compare(
                original_yaml.split("\n"),
                updated_yaml.split("\n"),
                diff_format=(
                    "html"
                    if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                    else "unified"
                ),
            )
            if isinstance(self.diff, Path):
                _ = self.diff.write_text(diff, encoding="utf-8")
            else:
                diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
                Console().print(rich.markdown.Markdown(diff_md))

        if isinstance(self.output, Path):
            if self.output.suffix in (".yaml", ".yml"):
                _ = self.output.write_text(updated_yaml, encoding="utf-8")
                logger.info("written updated description to {}", self.output)
            elif isinstance(self.updated, InvalidDescr):
                raise ValueError(
                    f"Cannot save invalid description package to {self.output}."
                    + " To save the metadata only, choose an output with a '.yaml' extension."
                )
            elif not self.output.suffix:
                # no suffix: save as an unzipped package folder
                _ = save_bioimageio_package_as_folder(
                    self.updated, output_path=self.output
                )
            else:
                _ = save_bioimageio_package(self.updated, output_path=self.output)

        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()
class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        # the format-updated description consumed by `UpdateCmdBase.cli_cmd`
        return update_format(
            self.source,
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )
class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    # The description with recomputed file hashes; written/displayed by
    # `UpdateCmdBase.cli_cmd`.
    @cached_property
    def updated(self):
        return update_hashes(self.source)
class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"],
        validation_alias=AliasChoices("inputs", "input"),
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
     - (n_samples,) or (n_samples,1) for models expecting a single input tensor
     - (n_samples,) containing the substring '{input_id}', or
     - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
     Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.

     
    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = Field(
        "outputs_{model_id}/{output_id}/{sample_id}.tif",
        validation_alias=AliasChoices("outputs", "output"),
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)
     
    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: Union[bool, int] = False
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blocksize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.
    - If `False`, inputs are processed as a whole without blocking.
     
    """

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("precomputed_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
    """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    server: Optional[str] = None
    """The URL or Hugging Face space name of a running bioimageio (gradio) server instance to use as a remote backend for prediction."""

    pre_post_processing_location: Literal["local", "remote"] = Field(
        "local", alias="pre-post-processing-location"
    )
    """Where to run preprocessing/postprocessing operations when using `--server`.

    - `local`: Run preprocessing/postprocessing locally and only model inference on the server.
    - `remote`: Run preprocessing/postprocessing on the server as well.

     
    """

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input

     
    """

    def _example(self):
        """Set up and run an example prediction with the model's example inputs.

        Copies the example inputs into a `<descr_id>_example` folder, writes a
        `bioimageio-cli.yaml` (with paths relative to that folder) next to them,
        and runs `bioimageio predict` twice in a subprocess from the current
        working directory: first as a preview, then for real.
        A `bioimageio-cli.yaml` in the current working directory is temporarily
        removed (and restored afterwards) so it does not interfere.

        Raises:
            ValueError: If the model describes no example inputs.
        """
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []  # relative to the current working directory
        local_inputs001: List[str] = []  # relative to `example_path`
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            local_dst = Path(str(t)) / f"001{reader.suffix}"
            dst = example_path / local_dst
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            local_inputs001.append(local_dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        # The CLI config file is meant to be used from within `example_path`
        # (see the final message), so all of its paths are relative to it.
        # `overwrite` allows re-running it after the example run below.
        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "precomputed_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=[local_inputs001],
            outputs="outputs/{output_id}/{sample_id}.tif",
            stats=stats_file,
            blockwise=self.blockwise,
            overwrite=True,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            """Build the `bioimageio predict` argv for the example (cwd-relative paths)."""
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                f"--blockwise={self.blockwise}",
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Successfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current working directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Predict all input samples and save the output samples."""
        try:
            for out_sample, out_path in self._yield_predictions(self.blockwise):
                save_sample(out_path, out_sample)
        except Exception as e:
            if not self.blockwise:
                raise RuntimeError(
                    f"Prediction failed ({e}).\nConsider using blockwise processing, "
                    + "e.g. with `--blockwise=10` to process inputs in blocks."
                ) from e
            raise e

    def _yield_predictions(self, blockwise: Union[bool, int]):
        """Yield `(predicted sample, output paths)` for every input sample.

        Args:
            blockwise: Falsy for whole-sample prediction, `True` for blockwise
                prediction with the default blocksize, or an int blocksize
                parameter `n`.

        If `self.example` is set, runs the example instead and yields nothing.
        If `self.preview` is set, prints the planned inputs/outputs and yields nothing.

        Raises:
            ValueError: For inconsistent input/output path specifications.
            FileExistsError: If an output exists and `self.overwrite` is not set.
        """
        if self.example:
            self._example()
            return

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            """Expand the `i`-th input sample specification to one path per input."""
            if isinstance(ipt, str):
                # a single pattern is expanded for every model input
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                # check before zipping, as `zip` silently drops surplus paths
                if len(ipt) > len(maximum_input_ids):
                    raise ValueError(
                        f"[input sample #{i}] Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {tuple(ipt)}"
                    )

                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                n: Union[int, str]
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs() -> List[Tuple[Path, ...]]:
            """Expand the output pattern(s) to one path per output tensor and sample."""
            if isinstance(self.outputs, str):
                pattern = self.outputs
                outputs = [
                    tuple(
                        Path(
                            pattern.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                # check before zipping, as `zip` silently drops surplus patterns
                if len(self.outputs) != len(output_ids):
                    raise ValueError(
                        f"Expected {len(output_ids)} outputs {output_ids}, got {self.outputs}"
                    )

                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            # check for distinctness within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + "Make sure to include '{sample_id}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            """Lazily load the input samples with the given statistics."""
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        if self.server is not None and self.pre_post_processing_location == "remote":
            pp = create_remote_prediction_pipeline(model_descr, server=self.server)
        else:
            if self.server is None:
                model_adapter = None
            else:
                assert self.pre_post_processing_location == "local"
                model_adapter = create_remote_model_adapter(
                    model_descr, server=self.server
                )

            pp = create_prediction_pipeline(
                model_descr,
                weight_format=None
                if self.weight_format == "any"
                else self.weight_format,
                devices=self.devices,
                model_adapter=model_adapter,
            )

        if blockwise:
            predict_method = partial(
                pp.predict_sample_with_blocking,
                ns=None if isinstance(blockwise, bool) else blockwise,
            )
        else:
            predict_method = pp.predict_sample_without_blocking

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            # Only whole-sample prediction needs the full input to be valid.
            # Use the `blockwise` argument (not `self.blockwise`) so callers
            # requesting whole-sample prediction get the check as well.
            if not blockwise and not isinstance(
                pp.model_description, v0_4.ModelDescr
            ):
                try:
                    _ = pp.model_description.validate_input_tensors(
                        sample_in.as_arrays()
                    )
                except Exception as e:
                    logger.warning(
                        "Input sample '{}' failed validation for whole-sample prediction: {}\n"
                        + "Consider using blockwise processing, e.g. with `--blockwise=10` to process inputs in blocks.",
                        sample_in.id,
                        e,
                    )

            yield (predict_method(sample_in), sp_out)
class PredictBlockArtifactsCmd(PredictCmd):
    """Command to inspect block artifacts by subtracting the combined, blockwise predictions from a whole sample prediction.

    Note:
        - This command intentionally uses a small blocksize (default: 1) to create block artifacts for testing purposes.
        - Typical sources of block artifacts include:
            - Described halo is smaller than the model's receptive field
            - Normalization layers inside the network cannot aggregate statistics over the whole sample.
    """

    blockwise: Union[Literal[True], int] = 1
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blocksize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.

    Defaults to a small blocksize to intentionally create block artifacts for testing purposes.
    """

    def cli_cmd(self):
        """Save `whole-sample prediction - blockwise prediction` for each sample and
        report sample statistics that are missing or differ between both runs."""
        for (out_sample, out_path), (out_sample_blockwise, _) in zip(
            self._yield_predictions(False), self._yield_predictions(self.blockwise)
        ):
            diff_sample = self._subtract_samples(out_sample, out_sample_blockwise)
            for k, v_a in out_sample.stat.items():
                v_b = out_sample_blockwise.stat.get(k)
                if v_b is None:
                    logger.error(
                        "measure '{}' not found in blockwise prediction statistics", k
                    )
                # `array_equal` gives a single bool for scalars and arrays alike
                # (`not np.not_equal(...)` was inverted and ambiguous for arrays)
                elif not np.array_equal(np.asarray(v_a), np.asarray(v_b)):
                    logger.error(
                        "measure '{}' has different values (whole sample!=blockwise): {}!={}",
                        k,
                        v_a,
                        v_b,
                    )

            save_sample(out_path, diff_sample)

    @staticmethod
    def _subtract_samples(a: Sample, b: Sample) -> Sample:
        """Return a sample with members `a - b` (per member id of `a`), keeping `a`'s id and stat."""
        return Sample(
            members={t: a.members[t] - b.members[t] for t in a.members},
            id=a.id,
            stat=a.stat,
        )
class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        # Weight conversion is only implemented for the current model format.
        model = ensure_description_is_model(self.descr)
        if isinstance(model, v0_4.ModelDescr):
            message = (
                f"model format {model.format_version} not supported."
                + " Please update the model first."
            )
            raise TypeError(message)

        converted = add_weights(
            model,
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        # report the validation summary of the model with added weights
        self.log(converted)
class EmptyCacheCmd(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self) -> None:
        # delegates to `bioimageio.spec.utils.empty_cache`
        empty_cache()
class ServerCmd(CmdBase):
    """Start a server to connect to with remote model adapters or remote prediction pipelines."""

    backend: Literal["gradio"] = "gradio"
    """The remote backend to use."""

    port: Optional[int] = None
    """The port to start the server on. If not given, a free port will be used."""

    def cli_cmd(self) -> None:
        # Guard clause: "gradio" is the only backend implemented so far.
        chosen_backend = self.backend
        if chosen_backend != "gradio":
            assert_never(chosen_backend)

        # The server implementation lives behind an optional extra.
        try:
            from .remote_backends.gradio.server import main
        except ImportError as e:
            raise ImportError(
                f"{chosen_backend.capitalize()} is not installed. Please install the '{chosen_backend}-server' extra to use this command,"
                + f" e.g. with `pip install bioimageio.core[{chosen_backend}-server]`."
            ) from e

        # `main` runs the server and returns its URL once it has shut down
        served_at = main(port=self.port)
        logger.info("{} server shutdown at {}", chosen_backend.capitalize(), served_at)
# CLI config files in the current working directory from which `Bioimageio`
# reads arguments (see its `model_config` and `settings_customise_sources`).
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"
class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    """Check a resource's metadata format"""

    test: CliSubCommand[TestCmd]
    """Test a bioimageio resource (beyond meta data formatting)"""

    package: CliSubCommand[PackageCmd]
    """Package a resource"""

    predict: CliSubCommand[PredictCmd]
    """Predict with a model resource"""

    predict_block_artifacts: CliSubCommand[PredictBlockArtifactsCmd] = Field(
        alias="predict-block-artifacts"
    )
    """Save the difference between predicting blockwise and whole sample to check for block artifacts."""

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCacheCmd] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    server: CliSubCommand[ServerCmd]
    """Start a server to connect to with remote model adapters or remote prediction pipelines."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Read arguments from the command line, init kwargs and the CLI
        config files (YAML, JSON), in this order of priority.

        Environment variables, dotenv and secret files are not used.
        """
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        """Log the raw (not yet validated) CLI input and pass it on unchanged."""
        # A "before" validator may receive any input type, not only a mapping.
        if isinstance(data, Mapping):
            logger.info(
                "loaded CLI input:\n{}",
                pformat({k: v for k, v in data.items() if v is not None}),
            )
        else:
            logger.info("loaded CLI input:\n{}", pformat(data))

        return data

    def cli_cmd(self) -> None:
        """Log the parsed command and run the selected subcommand."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)
# Append library and spec format versions to the CLI help text.
# Under `python -OO` docstrings are stripped (`__doc__` is None) and asserts are
# removed, so `assert` + `__doc__ += ...` would raise a TypeError at import time.
Bioimageio.__doc__ = (Bioimageio.__doc__ or "") + f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""
def _get_sample_ids(
    input_paths: Sequence[Mapping[MemberId, Path]],
) -> Sequence[SampleId]:
    """Get sample ids for given input paths, based on the common path per sample.

    For each sample the resolved input paths (without suffix) are reduced to
    their longest common substring. These per-sample strings are then trimmed
    (from the start, and possibly from the end) to shorter ids that are still unique.

    Args:
        input_paths: One mapping of member id to input file path per sample.

    Returns:
        One sample id per entry of `input_paths`, in the same order;
        an empty list if `input_paths` is empty.

    Raises:
        ValueError: If a sample has no input paths, no common path part can be
            found for a sample, or the per-sample ids are not unique.
    """
    if not input_paths:
        return []

    # `autojunk` treats characters that are frequent in long (>= 200 items)
    # sequences, such as '/' in long absolute paths, as junk, which makes
    # `find_longest_match` miss the actual longest common substring.
    matcher = SequenceMatcher(autojunk=False)

    def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
        """extract a common sequence from multiple sequences
        (order sensitive; strips whitespace and slashes)
        """
        if not seqs:
            raise ValueError("found a sample without any input paths")

        common = seqs[0]

        for seq in seqs[1:]:
            if not seq:
                continue
            matcher.set_seqs(common, seq)
            i, _, size = matcher.find_longest_match()
            common = common[i : i + size]

        if isinstance(common, str):
            common = common.strip().strip("/")
        else:
            common = [cs for c in common if (cs := c.strip().strip("/"))]

        if not common:
            raise ValueError(f"failed to find common sequence for {seqs}")

        return common

    def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
        """get a shorter sequence whose entries are still unique
        (order sensitive, not minimal sequence)
        """
        min_seq_len = min(len(s) for s in seqs)
        # cut from the start: shortest suffixes (same start index) that are unique
        for start in range(min_seq_len - 1, -1, -1):
            shortened = [s[start:] for s in seqs]
            if len(set(shortened)) == len(seqs):
                min_seq_len -= start
                break
        else:
            seen: Set[Sequence[str]] = set()
            dupes = [s for s in seqs if s in seen or seen.add(s)]
            raise ValueError(f"Found duplicate entries {dupes}")

        # cut from the end: only commit a cut if the entries remain unique,
        # so that duplicate sample ids are never returned
        for end in range(min_seq_len - 1, 1, -1):
            candidate = [s[:end] for s in shortened]
            if len(set(candidate)) == len(seqs):
                shortened = candidate
                break

        return shortened

    full_tensor_ids = [
        sorted(
            p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
        )
        for input_sample_paths in input_paths
    ]
    try:
        long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
        sample_ids = get_shorter_diff(long_sample_ids)
    except ValueError as e:
        raise ValueError(f"failed to extract sample ids: {e}") from e

    return sample_ids