Coverage for bioimageio/core/cli.py: 80%
272 statements
coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1"""bioimageio CLI
3Note: Some docstrings use a hair space ' '
4 to place the added '(default: ...)' on a new line.
5"""

import json
import shutil
import subprocess
import sys
from argparse import RawTextHelpFormatter
from difflib import SequenceMatcher
from functools import cached_property
from pathlib import Path
from pprint import pformat, pprint
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

from loguru import logger
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import (
    BaseSettings,
    CliPositionalArg,
    CliSettingsSource,
    CliSubCommand,
    JsonConfigSettingsSource,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    YamlConfigSettingsSource,
)
from ruyaml import YAML
from tqdm import tqdm
from typing_extensions import assert_never

from bioimageio.spec import AnyModelDescr, InvalidDescr, load_description
from bioimageio.spec._internal.io_basics import ZipPath
from bioimageio.spec._internal.types import NotEmpty
from bioimageio.spec.dataset import DatasetDescr
from bioimageio.spec.model import ModelDescr, v0_4, v0_5
from bioimageio.spec.notebook import NotebookDescr
from bioimageio.spec.utils import download, ensure_description_is_model

from .commands import (
    WeightFormatArgAll,
    WeightFormatArgAny,
    package,
    test,
    validate_format,
)
from .common import MemberId, SampleId
from .digest_spec import get_member_ids, load_sample_for_model
from .io import load_dataset_stat, save_dataset_stat, save_sample
from .prediction import create_prediction_pipeline
from .proc_setup import (
    DatasetMeasure,
    Measure,
    MeasureValue,
    StatsCalculator,
    get_required_dataset_measures,
)
from .sample import Sample
from .stat_measures import Stat
from .utils import VERSION

yaml = YAML(typ="safe")


class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    pass


class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    pass


class WithSource(ArgMixin):
    source: CliPositionalArg[str]
    """Url/path to a bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))
        else:
            return str(
                (
                    (bio_config := self.descr.config.get("bioimageio", {}))
                    and isinstance(bio_config, dict)
                    and bio_config.get("nickname")
                )
                or self.descr.id
                or self.descr.name
            )


class ValidateFormatCmd(CmdBase, WithSource):
    """validate the metadata format of a bioimageio resource."""

    def run(self):
        sys.exit(validate_format(self.descr))
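
# Illustrative invocation (the resource id 'affable-shark' is only an example;
# any URL/path to a bioimageio.yaml/rdf.yaml works as well):
#   $ bioimageio validate-format affable-shark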


class TestCmd(CmdBase, WithSource):
    """Test a bioimageio resource (beyond metadata formatting)"""

    weight_format: WeightFormatArgAll = "all"
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[Union[str, Sequence[str]]] = None
    """Device(s) to use for testing"""

    decimal: int = 4
    """Precision for numerical comparisons"""

    def run(self):
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                decimal=self.decimal,
            )
        )
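
# Illustrative invocation (resource id is an example only):
#   $ bioimageio test affable-shark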


class PackageCmd(CmdBase, WithSource):
    """save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = "all"
    """The weight format to include in the package (for model descriptions only)."""

    def run(self):
        if isinstance(self.descr, InvalidDescr):
            self.descr.validation_summary.display()
            raise ValueError("resource description is invalid")

        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )
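
# Illustrative invocation (resource id and target path are examples; a '.zip'
# suffix produces a zipped package, any other path an unzipped folder):
#   $ bioimageio package affable-shark affable-shark.zip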


def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info(f"loading precomputed dataset measures from {stats_path}")
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat
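
# Note (illustrative): for a model that requires dataset-wide measures, the first
# `bioimageio predict` run iterates all samples once to compute them and writes
# them to the stats path (default 'dataset_statistics.json'); subsequent runs take
# the branch above and load that file instead of recomputing.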


class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[Sequence[Union[str, NotEmpty[Tuple[str, ...]]]]] = (
        "{input_id}/001.tif",
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
     - (n_samples,) or (n_samples,1) for models expecting a single input tensor
     - (n_samples,) containing the substring '{input_id}', or
     - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process samples 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\"a_raw.tif\",\"a_mask.tif\"],[\"b_raw.tif\",\"b_mask.tif\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.

    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: bool = False
    """process inputs blockwise"""

    stats: Path = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
    """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = "any"
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input

    """

    def _example(self):
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [ipt.sample_tensor or ipt.test_tensor for ipt in model_descr.inputs]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        for t, src in zip(input_ids, example_inputs):
            local = download(src).path
            dst = Path(f"{example_path}/{t}/001{''.join(local.suffixes)}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            if isinstance(local, Path):
                shutil.copy(local, dst)
            elif isinstance(local, ZipPath):
                _ = local.root.extract(local.at, path=dst)
            else:
                assert_never(local)

        inputs = [tuple(inputs001)]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        yaml.dump(
            dict(
                inputs=inputs,
                outputs=output_pattern,
                stats=stats_file,
                blockwise=self.blockwise,
            ),
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Successfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path/YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current working directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def run(self):
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Tuple[str, ...]]) -> Tuple[str, ...]:
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )
        predict_method = (
            pp.predict_sample_with_blocking
            if self.blockwise
            else pp.predict_sample_without_blocking
        )

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            sample_out = predict_method(sample_in)
            save_sample(sp_out, sample_out)


JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"


class Bioimageio(
    BaseSettings,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    cli_implicit_flags=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    "Check a resource's metadata format"

    test: CliSubCommand[TestCmd]
    "Test a bioimageio resource (beyond metadata formatting)"

    package: CliSubCommand[PackageCmd]
    "Package a resource"

    predict: CliSubCommand[PredictCmd]
    "Predict with a model resource"

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def run(self):
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        cmd = self.validate_format or self.test or self.package or self.predict
        assert cmd is not None
        cmd.run()


assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {VERSION}
  bioimageio.spec {VERSION}

spec format versions:
  model RDF {ModelDescr.implemented_format_version}
  dataset RDF {DatasetDescr.implemented_format_version}
  notebook RDF {NotebookDescr.implemented_format_version}

"""


def _get_sample_ids(
    input_paths: Sequence[Mapping[MemberId, Path]],
) -> Sequence[SampleId]:
    """Get sample ids for given input paths, based on the common path per sample.

    Falls back to sample01, sample02, etc..."""

    matcher = SequenceMatcher()

    def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
        """extract a common sequence from multiple sequences
        (order sensitive; strips whitespace and slashes)
        """
        common = seqs[0]

        for seq in seqs[1:]:
            if not seq:
                continue
            matcher.set_seqs(common, seq)
            i, _, size = matcher.find_longest_match()
            common = common[i : i + size]

        if isinstance(common, str):
            common = common.strip().strip("/")
        else:
            common = [cs for c in common if (cs := c.strip().strip("/"))]

        if not common:
            raise ValueError(f"failed to find common sequence for {seqs}")

        return common
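
    # Illustrative: for one sample's tensor paths resolved to 'data/a_mask' and
    # 'data/a_raw' (hypothetical), the longest common substring 'data/a_' is kept
    # as the long sample id candidate before it is shortened below.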

    def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
        """get a shorter sequence whose entries are still unique
        (order sensitive, not minimal sequence)
        """
        min_seq_len = min(len(s) for s in seqs)
        # cut from the start
        for start in range(min_seq_len - 1, -1, -1):
            shortened = [s[start:] for s in seqs]
            if len(set(shortened)) == len(seqs):
                min_seq_len -= start
                break
        else:
            seen: Set[Sequence[str]] = set()
            dupes = [s for s in seqs if s in seen or seen.add(s)]
            raise ValueError(f"Found duplicate entries {dupes}")

        # cut from the end
        for end in range(min_seq_len - 1, 1, -1):
            shortened = [s[:end] for s in shortened]
            if len(set(shortened)) == len(seqs):
                break

        return shortened

    full_tensor_ids = [
        sorted(
            p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
        )
        for input_sample_paths in input_paths
    ]
    try:
        long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
        sample_ids = get_shorter_diff(long_sample_ids)
    except ValueError as e:
        raise ValueError(f"failed to extract sample ids: {e}")

    return sample_ids