Coverage for src / bioimageio / core / cli.py: 82%
400 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1"""bioimageio CLI
3Note: Some docstrings use a hair space ' '
4 to place the added '(default: ...)' on a new line.
5"""
7import json
8import shutil
9import subprocess
10import sys
11from abc import ABC
12from argparse import RawTextHelpFormatter
13from difflib import SequenceMatcher
14from functools import cached_property, partial
15from io import StringIO
16from pathlib import Path
17from pprint import pformat, pprint
18from typing import (
19 Annotated,
20 Any,
21 Dict,
22 Iterable,
23 List,
24 Literal,
25 Mapping,
26 Optional,
27 Sequence,
28 Set,
29 Tuple,
30 Type,
31 Union,
32)
34import numpy as np
35import rich.markdown
36from loguru import logger
37from pydantic import (
38 AliasChoices,
39 BaseModel,
40 Field,
41 PlainSerializer,
42 WithJsonSchema,
43 model_validator,
44)
45from pydantic_settings import (
46 BaseSettings,
47 CliApp,
48 CliPositionalArg,
49 CliSettingsSource,
50 CliSubCommand,
51 JsonConfigSettingsSource,
52 PydanticBaseSettingsSource,
53 SettingsConfigDict,
54 YamlConfigSettingsSource,
55)
56from tqdm import tqdm
57from typing_extensions import assert_never
59import bioimageio.spec
60from bioimageio.core import __version__
61from bioimageio.spec import (
62 AnyModelDescr,
63 InvalidDescr,
64 ResourceDescr,
65 load_description,
66 save_bioimageio_yaml_only,
67 settings,
68 update_format,
69 update_hashes,
70)
71from bioimageio.spec._internal.io import is_yaml_value
72from bioimageio.spec._internal.io_utils import open_bioimageio_yaml
73from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty
74from bioimageio.spec.dataset import DatasetDescr
75from bioimageio.spec.model import ModelDescr, v0_4, v0_5
76from bioimageio.spec.notebook import NotebookDescr
77from bioimageio.spec.utils import (
78 empty_cache,
79 ensure_description_is_model,
80 get_reader,
81 write_yaml,
82)
84from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
85from .common import MemberId, SampleId, SupportedWeightsFormat
86from .digest_spec import get_member_ids, load_sample_for_model
87from .io import load_dataset_stat, save_dataset_stat, save_sample
88from .prediction import create_prediction_pipeline
89from .proc_setup import (
90 DatasetMeasure,
91 Measure,
92 MeasureValue,
93 StatsCalculator,
94 get_required_dataset_measures,
95)
96from .sample import Sample
97from .stat_measures import Stat
98from .utils import compare
99from .weight_converters._add_weights import add_weights
# Accept all common spellings of the weight format option, so users may write
# `--weight-format`, `--weights-format`, `--weight_format` or `--weights_format`.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    "weight-format",
    "weights-format",
    "weight_format",
    "weights_format",
)
class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    """Base class for all CLI (sub)commands."""

    pass
class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    """Base class for argument mixins shared by several (sub)commands."""

    pass
class WithSummaryLogging(ArgMixin):
    """Mixin adding the `--summary` argument and a helper to log validation summaries."""

    summary: List[Union[Literal["display"], Path]] = Field(
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        """Render and/or save the validation summary of `descr` as requested via `self.summary`."""
        _ = descr.validation_summary.log(self.summary)
class WithSource(ArgMixin):
    """Mixin adding the positional `source` argument and lazy description loading."""

    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        """The resource description loaded (once, cached) from `self.source`."""
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # an invalid description may be missing `id`; fall back to `name`
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))

        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            # legacy resources may carry a human-friendly nickname in extra config
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or self.descr.id or self.descr.name)
class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to perform validations that require downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        """The resource description loaded with the configured IO-check setting."""
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        """Log the validation summary and exit 0 on success, 1 otherwise."""
        self.log(self.descr)
        sys.exit(
            0
            if self.descr.validation_summary.status in ("valid-format", "passed")
            else 1
        )
class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = None
    """Device(s) to use for testing"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
      weights description.
    - A path to a conda environment YAML.
      Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        """Run the resource tests and exit with the test result as status code."""
        # NOTE(review): `self.stop_early` is not forwarded to `test()` — confirm
        # whether the `--stop-early`/`-x` flag is intentionally unused here.
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                summary=self.summary,
                runtime_env=self.runtime_env,
                determinism=self.determinism,
                format_version=self.format_version,
                working_dir=self.working_dir,
            )
        )
class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        """Package the resource and exit with the packaging result code.

        Raises:
            ValueError: If the resource description is invalid.
        """
        if isinstance(self.descr, InvalidDescr):
            # an invalid description cannot be packaged; show its summary first
            self.log(self.descr)
            raise ValueError(f"Invalid {self.descr.type} description.")

        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )
def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    """Get the dataset measures required by `model_descr`.

    Loads precomputed measures from `stats_path` if that file exists; otherwise
    computes them from `dataset` (of length `dataset_length`, used for the
    progress bar) and saves them to `stats_path`.

    Raises:
        ValueError: If `stats_path` exists but is missing a required measure.
    """
    # fix: the original called get_required_dataset_measures twice in a row;
    # the second (redundant) call was removed.
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat
class UpdateCmdBase(CmdBase, WithSource, ABC):
    """Base class for commands that output an updated bioimageio.yaml."""

    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highlighting.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly been set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        """The updated resource description (provided by subclasses)."""
        raise NotImplementedError

    def cli_cmd(self):
        """Serialize `self.updated`, output a diff against the original YAML
        and write/display the updated bioimageio.yaml as configured."""
        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)
        stream = StringIO()

        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format=(
                "html"
                if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                else "unified"
            ),
        )

        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            # NOTE(review): relies on `import rich.markdown` (top of file) making
            # `rich.console` available as an attribute of the `rich` package —
            # confirm, or import `rich.console` explicitly.
            console = rich.console.Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(rich.markdown.Markdown(diff_md))

        if isinstance(self.output, Path):
            _ = self.output.write_text(updated_yaml, encoding="utf-8")
            logger.info(f"written updated description to {self.output}")
        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()
class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        """The format-updated resource description."""
        return update_format(
            self.source,
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )
class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        """The resource description with recomputed file hashes."""
        return update_hashes(self.source)
class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"]
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.
    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)
    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: Union[bool, int] = False
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blocksize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.
    - If `False`, inputs are processed as a whole without blocking.
    """

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
    """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input
    """

    def _example(self):
        """Download the model's example inputs, write a matching
        `bioimageio-cli.yaml` into a `{model_id}_example` folder and run a
        preview plus an actual example prediction via the CLI (subprocess).

        Raises:
            ValueError: If the model does not specify any example inputs.
        """
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # download each example tensor to <example_path>/<input_id>/001.<suffix>
        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            """assemble the CLI command; quote arguments if `escape`"""
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # a bioimageio-cli.yaml in the cwd would override the example arguments,
        # so move it out of the way for the duration of the subprocess calls
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Sucessfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current workind directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Run predictions and save each output sample to its output path."""
        for out_sample, out_path in self._yield_predictions(self.blockwise):
            save_sample(out_path, out_sample)

    def _yield_predictions(self, blockwise: Union[bool, int]):
        """Yield `(predicted_sample, output_paths)` pairs for each input sample.

        In `--example` mode the example workflow runs instead and nothing is
        yielded; in `--preview` mode the planned inputs/outputs are printed and
        nothing is yielded either.
        """
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # required vs. all input tensor ids (optional inputs only exist in v0.5)
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            """expand input path (pattern) of sample `i` to one path per input tensor"""
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            """expand output path pattern(s) to one path per output tensor and sample"""
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            # check for distinctness and correct number within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                # NOTE(review): plain string, so '{{sample_id}}' renders with doubled
                # braces in the error message — confirm intended.
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + "Make sure to include '{{sample_id}}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )

        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            """lazily load one input sample per set of input paths"""
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        # precompute (or load) dataset statistics required by the model
        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )

        if blockwise:
            # `True` means default blocksize (ns=None); an int sets it explicitly
            predict_method = partial(
                pp.predict_sample_with_blocking,
                ns=None if isinstance(blockwise, bool) else blockwise,
            )
        else:
            predict_method = pp.predict_sample_without_blocking

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            yield (predict_method(sample_in), sp_out)
class PredictBlockArtifactsCmd(PredictCmd):
    """Command to inspect block artifacts by subtracting the combined, blockwise predictions from a whole sample prediction.

    Note:
        - This command intentionally uses a small blocksize (default: 1) to create block artifacts for testing purposes.
        - Typical sources of block artifacts include:
            - Described halo is smaller than the model's receptive field
            - Normalization layers inside the network cannot aggregate statistics over the whole sample.
    """

    blockwise: Union[Literal[True], int] = 1
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blocksize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.

    Defaults to a small blocksize to intentionally create block artifacts for testing purposes.
    """

    def cli_cmd(self):
        """Predict each sample twice (whole-sample and blockwise), log an error
        for any sample statistic that differs between the two runs and save the
        difference of the two predictions."""
        for (out_sample, out_path), (out_sample_blockwise, _) in zip(
            self._yield_predictions(False), self._yield_predictions(self.blockwise)
        ):
            diff_sample = self._subtract_samples(out_sample, out_sample_blockwise)
            for k, v_a in out_sample.stat.items():
                v_b = out_sample_blockwise.stat.get(k)
                if v_b is None:
                    logger.error(
                        "measure '{}' not found in blockwise prediction statistics", k
                    )
                # fix: the original `not np.not_equal(v_a, v_b)` logged the
                # "different values" error when the values were EQUAL (inverted
                # condition) and raised for array-valued measures (ambiguous
                # array truthiness). `np.array_equal` handles both cases.
                elif not np.array_equal(v_a, v_b):
                    logger.error(
                        "measure '{}' has different values (whole sample!=blockwise): {}!={}",
                        k,
                        v_a,
                        v_b,
                    )

            save_sample(out_path, diff_sample)

    @staticmethod
    def _subtract_samples(a: Sample, b: Sample) -> Sample:
        """Element-wise difference `a - b` of two samples with identical members."""
        return Sample(
            members={t: a.members[t] - b.members[t] for t in a.members},
            id=a.id,
            stat=a.stat,
        )
class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        """Convert/add weights and log the updated model's validation summary.

        Raises:
            TypeError: If the model uses the unsupported 0.4 format.
        """
        model_descr = ensure_description_is_model(self.descr)
        if isinstance(model_descr, v0_4.ModelDescr):
            raise TypeError(
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )
        updated_model_descr = add_weights(
            model_descr,
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        self.log(updated_model_descr)
class EmptyCache(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        """Remove all files from the bioimageio cache directory."""
        empty_cache()
# default CLI config file names picked up from the current working directory
# (registered in `Bioimageio.model_config` as json_file/yaml_file sources)
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"
class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    """Check a resource's metadata format"""

    test: CliSubCommand[TestCmd]
    """Test a bioimageio resource (beyond meta data formatting)"""

    package: CliSubCommand[PackageCmd]
    """Package a resource"""

    predict: CliSubCommand[PredictCmd]
    """Predict with a model resource"""

    predict_block_artifacts: CliSubCommand[PredictBlockArtifactsCmd] = Field(
        alias="predict-block-artifacts"
    )
    """Save the difference between predicting blockwise and whole sample to check for block artifacts."""

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Prepend a CLI settings source and append YAML/JSON config file sources
        (so CLI arguments take precedence over config files)."""
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        """Log the raw CLI input before validation."""
        # NOTE(review): assumes `data` is a mapping here — confirm for all entry points
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def cli_cmd(self) -> None:
        """Log the validated CLI input and dispatch to the chosen subcommand."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)
# append version information of core, spec and the implemented RDF formats
# to the CLI root command's help text
assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""
986def _get_sample_ids(
987 input_paths: Sequence[Mapping[MemberId, Path]],
988) -> Sequence[SampleId]:
989 """Get sample ids for given input paths, based on the common path per sample.
991 Falls back to sample01, samle02, etc..."""
993 matcher = SequenceMatcher()
995 def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
996 """extract a common sequence from multiple sequences
997 (order sensitive; strips whitespace and slashes)
998 """
999 common = seqs[0]
1001 for seq in seqs[1:]:
1002 if not seq:
1003 continue
1004 matcher.set_seqs(common, seq)
1005 i, _, size = matcher.find_longest_match()
1006 common = common[i : i + size]
1008 if isinstance(common, str):
1009 common = common.strip().strip("/")
1010 else:
1011 common = [cs for c in common if (cs := c.strip().strip("/"))]
1013 if not common:
1014 raise ValueError(f"failed to find common sequence for {seqs}")
1016 return common
1018 def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
1019 """get a shorter sequence whose entries are still unique
1020 (order sensitive, not minimal sequence)
1021 """
1022 min_seq_len = min(len(s) for s in seqs)
1023 # cut from the start
1024 for start in range(min_seq_len - 1, -1, -1):
1025 shortened = [s[start:] for s in seqs]
1026 if len(set(shortened)) == len(seqs):
1027 min_seq_len -= start
1028 break
1029 else:
1030 seen: Set[Sequence[str]] = set()
1031 dupes = [s for s in seqs if s in seen or seen.add(s)]
1032 raise ValueError(f"Found duplicate entries {dupes}")
1034 # cut from the end
1035 for end in range(min_seq_len - 1, 1, -1):
1036 shortened = [s[:end] for s in shortened]
1037 if len(set(shortened)) == len(seqs):
1038 break
1040 return shortened
1042 full_tensor_ids = [
1043 sorted(
1044 p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
1045 )
1046 for input_sample_paths in input_paths
1047 ]
1048 try:
1049 long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
1050 sample_ids = get_shorter_diff(long_sample_ids)
1051 except ValueError as e:
1052 raise ValueError(f"failed to extract sample ids: {e}")
1054 return sample_ids