Coverage for src / bioimageio / core / cli.py: 84%

377 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 11:12 +0000

1"""bioimageio CLI 

2 

3Note: Some docstrings use a hair space ' ' 

4 to place the added '(default: ...)' on a new line. 

5""" 

6 

7import json 

8import shutil 

9import subprocess 

10import sys 

11from abc import ABC 

12from argparse import RawTextHelpFormatter 

13from difflib import SequenceMatcher 

14from functools import cached_property 

15from io import StringIO 

16from pathlib import Path 

17from pprint import pformat, pprint 

18from typing import ( 

19 Annotated, 

20 Any, 

21 Dict, 

22 Iterable, 

23 List, 

24 Literal, 

25 Mapping, 

26 Optional, 

27 Sequence, 

28 Set, 

29 Tuple, 

30 Type, 

31 Union, 

32) 

33 

34import rich.markdown 

35from loguru import logger 

36from pydantic import ( 

37 AliasChoices, 

38 BaseModel, 

39 Field, 

40 PlainSerializer, 

41 WithJsonSchema, 

42 model_validator, 

43) 

44from pydantic_settings import ( 

45 BaseSettings, 

46 CliApp, 

47 CliPositionalArg, 

48 CliSettingsSource, 

49 CliSubCommand, 

50 JsonConfigSettingsSource, 

51 PydanticBaseSettingsSource, 

52 SettingsConfigDict, 

53 YamlConfigSettingsSource, 

54) 

55from tqdm import tqdm 

56from typing_extensions import assert_never 

57 

58import bioimageio.spec 

59from bioimageio.core import __version__ 

60from bioimageio.spec import ( 

61 AnyModelDescr, 

62 InvalidDescr, 

63 ResourceDescr, 

64 load_description, 

65 save_bioimageio_yaml_only, 

66 settings, 

67 update_format, 

68 update_hashes, 

69) 

70from bioimageio.spec._internal.io import is_yaml_value 

71from bioimageio.spec._internal.io_utils import open_bioimageio_yaml 

72from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty 

73from bioimageio.spec.dataset import DatasetDescr 

74from bioimageio.spec.model import ModelDescr, v0_4, v0_5 

75from bioimageio.spec.notebook import NotebookDescr 

76from bioimageio.spec.utils import ( 

77 empty_cache, 

78 ensure_description_is_model, 

79 get_reader, 

80 write_yaml, 

81) 

82 

83from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test 

84from .common import MemberId, SampleId, SupportedWeightsFormat 

85from .digest_spec import get_member_ids, load_sample_for_model 

86from .io import load_dataset_stat, save_dataset_stat, save_sample 

87from .prediction import create_prediction_pipeline 

88from .proc_setup import ( 

89 DatasetMeasure, 

90 Measure, 

91 MeasureValue, 

92 StatsCalculator, 

93 get_required_dataset_measures, 

94) 

95from .sample import Sample 

96from .stat_measures import Stat 

97from .utils import compare 

98from .weight_converters._add_weights import add_weights 

99 

# Accept all common spellings of the weight-format option so users can write
# either singular/plural and dash/underscore variants on the command line.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    "weight-format",
    "weights-format",
    "weight_format",
    "weights_format",
)

106 

107 

class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Base class for CLI subcommands: attribute docstrings become help text and
    # bool fields become implicit --flag/--no-flag options (pydantic-settings).
    pass

110 

111 

class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Base class for reusable groups of CLI arguments, configured like CmdBase.
    pass

114 

115 

class WithSummaryLogging(ArgMixin):
    # Mixin adding a `--summary` option and a helper to emit the validation summary.
    summary: List[Union[Literal["display"], Path]] = Field(
        # default: render the summary to the terminal only
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        # Forward the configured summary targets to the description's summary logger.
        _ = descr.validation_summary.log(self.summary)

133 

134 

class WithSource(ArgMixin):
    # Mixin adding the positional `source` argument and lazy description loading.
    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        # Loaded once per command invocation; cached_property avoids re-downloading.
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # Invalid descriptions may lack an `id`; fall back to `name`.
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))

        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            # A nickname may be stored as an extra key in config.bioimageio.
            nickname = bio_config.model_extra.get("nickname")

        # Preference order: nickname, then resource id, then resource name.
        return str(nickname or self.descr.id or self.descr.name)

161 

162 

class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to perform validations that require downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        # Overrides WithSource.descr to honor the perform_io_checks option.
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        # Log/display the summary, then exit 0 on success and 1 on failure.
        self.log(self.descr)
        sys.exit(
            0
            if self.descr.validation_summary.status in ("valid-format", "passed")
            else 1
        )

184 

185 

class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = None
    """Device(s) to use for testing"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
    weights description.
    - A path to a conda environment YAML.
    Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""
    # NOTE(review): stop_early is declared here but not forwarded to `test(...)`
    # in cli_cmd below — confirm whether it should be passed through.

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        # Run the test suite and use its return code as the process exit code.
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                summary=self.summary,
                runtime_env=self.runtime_env,
                determinism=self.determinism,
                format_version=self.format_version,
                working_dir=self.working_dir,
            )
        )

245 

246 

class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        # Refuse to package an invalid description; show the summary first so the
        # user can see why validation failed.
        if isinstance(self.descr, InvalidDescr):
            self.log(self.descr)
            raise ValueError(f"Invalid {self.descr.type} description.")

        # `package` returns the process exit code.
        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )

274 

275 

def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    """Load or compute the dataset measures required by `model_descr`.

    Returns an empty mapping if the model requires no dataset measures.
    If `stats_path` exists, precomputed measures are loaded from it and
    validated for completeness; otherwise the measures are computed from
    `dataset` and saved to `stats_path`.

    Args:
        model_descr: model description whose required dataset measures are queried.
        dataset: iterable of samples used to compute statistics (only consumed
            when no precomputed stats are available).
        dataset_length: total number of samples (for the progress bar only).
        stats_path: file to load precomputed measures from / save them to.

    Raises:
        ValueError: if `stats_path` exists but lacks a required measure.
    """
    # fix: the original called get_required_dataset_measures twice; once suffices.
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat

308 

309 

class UpdateCmdBase(CmdBase, WithSource, ABC):
    # Abstract base for update-* commands; subclasses provide `updated`.
    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highligthing.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not been explicitly set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        # Subclasses compute the updated description here (cached per invocation).
        raise NotImplementedError

    def cli_cmd(self):
        # 1) read original YAML, 2) serialize updated description,
        # 3) emit diff (file or terminal), 4) emit updated YAML (file/terminal/stdout).
        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)
        stream = StringIO()

        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            # HTML diff only when writing to a *.html file; unified diff otherwise.
            diff_format=(
                "html"
                if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                else "unified"
            ),
        )

        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            # NOTE(review): relies on `rich.console` being importable via the
            # top-level `import rich.markdown` — confirm rich submodule loading.
            console = rich.console.Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(rich.markdown.Markdown(diff_md))

        if isinstance(self.output, Path):
            _ = self.output.write_text(updated_yaml, encoding="utf-8")
            logger.info(f"written updated description to {self.output}")
        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        # Warn (but do not fail) if the update produced an invalid description.
        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()

379 

380 

class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        # Delegate to bioimageio.spec's format updater.
        return update_format(
            self.source,
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )

404 

405 

class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        # Delegate to bioimageio.spec's hash updater.
        return update_hashes(self.source)

412 

413 

class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"]
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.


    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)


    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: bool = False
    """process inputs blockwise"""

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist,
    but the model requires statistical dataset measures)
     """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input


    """

    def _example(self):
        """Download example inputs, write a CLI config, and run a demo prediction."""
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        # v0.4 models list sample inputs directly; v0.5 models attach a sample
        # (or test) tensor to each input description.
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # Download each example tensor to <example>/<input_id>/001.<suffix>.
        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            # Build the CLI invocation; with escape=True quote/escape arguments
            # so the command can be copy-pasted into a shell.
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # A bioimageio-cli.yaml in the cwd would override our CLI arguments;
        # move it out of the way for the duration of the example run.
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        # fix: corrected typos "Sucessfully" -> "Successfully", "workind" -> "working"
        print(
            "🎉 Successfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current working directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Expand input/output patterns, then run (or preview) prediction."""
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # Required inputs only (v0.5 inputs may be optional).
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            # Substitute {model_id}/{input_id} placeholders and validate that the
            # resulting per-sample paths are distinct and of acceptable arity.
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            # Substitute {model_id}/{output_id}/{sample_id} placeholders and verify
            # output paths are distinct within each sample and across all samples.
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + f"Make sure to include '{{sample_id}}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            # Dry-run: show the planned input/output mapping and return.
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            # Lazily load each sample so stats can be threaded through loading.
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        # Load or compute dataset-wide measures required by the model.
        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )
        predict_method = (
            pp.predict_sample_with_blocking
            if self.blockwise
            else pp.predict_sample_without_blocking
        )

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            sample_out = predict_method(sample_in)
            save_sample(sp_out, sample_out)

772 

773 

class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        model_descr = ensure_description_is_model(self.descr)
        # Weight conversion requires a v0.5 model description.
        if isinstance(model_descr, v0_4.ModelDescr):
            raise TypeError(
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )
        updated_model_descr = add_weights(
            model_descr,
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        # Display/save the validation summary of the updated model.
        self.log(updated_model_descr)

809 

810 

class EmptyCache(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        # Delegates to bioimageio.spec.utils.empty_cache.
        empty_cache()

816 

817 

# Config file names (in the current working directory) from which the CLI
# additionally reads arguments, see Bioimageio.model_config below.
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"

820 

821 

class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    # Additionally read CLI arguments from local config files.
    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    "Check a resource's metadata format"

    test: CliSubCommand[TestCmd]
    "Test a bioimageio resource (beyond meta data formatting)"

    package: CliSubCommand[PackageCmd]
    "Package a resource"

    predict: CliSubCommand[PredictCmd]
    "Predict with a model resource"

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        # Settings source priority: CLI args first, then init kwargs, then the
        # local YAML and JSON config files (env/dotenv/secret sources unused).
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        # Log the merged CLI input (omitting unset/None entries) before validation.
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def cli_cmd(self) -> None:
        # Log the resolved command configuration, then dispatch to the subcommand.
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)

899 

900 

# Append version information to the CLI help text (shown by `bioimageio --help`).
assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""

914 

915 

def _get_sample_ids(
    input_paths: Sequence[Mapping[MemberId, Path]],
) -> Sequence[SampleId]:
    """Get sample ids for given input paths, based on the common path per sample.

    Falls back to sample01, sample02, etc...

    Args:
        input_paths: one mapping of member id -> input file path per sample.

    Raises:
        ValueError: if no common sequence or no distinct ids can be derived.
    """
    # NOTE(review): the documented sample01/sample02 fallback is not visible in
    # this implementation — a failure raises instead; confirm intended behavior.

    matcher = SequenceMatcher()

    def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
        """extract a common sequence from multiple sequences
        (order sensitive; strips whitespace and slashes)
        """
        common = seqs[0]

        # Iteratively narrow `common` to the longest match shared with each seq.
        for seq in seqs[1:]:
            if not seq:
                continue
            matcher.set_seqs(common, seq)
            i, _, size = matcher.find_longest_match()
            common = common[i : i + size]

        if isinstance(common, str):
            common = common.strip().strip("/")
        else:
            common = [cs for c in common if (cs := c.strip().strip("/"))]

        if not common:
            raise ValueError(f"failed to find common sequence for {seqs}")

        return common

    def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
        """get a shorter sequence whose entries are still unique
        (order sensitive, not minimal sequence)
        """
        min_seq_len = min(len(s) for s in seqs)
        # cut from the start
        for start in range(min_seq_len - 1, -1, -1):
            shortened = [s[start:] for s in seqs]
            if len(set(shortened)) == len(seqs):
                min_seq_len -= start
                break
        else:
            # no suffix is unique -> there must be duplicates; report them
            seen: Set[Sequence[str]] = set()
            dupes = [s for s in seqs if s in seen or seen.add(s)]
            raise ValueError(f"Found duplicate entries {dupes}")

        # cut from the end
        for end in range(min_seq_len - 1, 1, -1):
            shortened = [s[:end] for s in shortened]
            if len(set(shortened)) == len(seqs):
                break

        return shortened

    # One sorted list of suffix-less, resolved posix paths per sample.
    full_tensor_ids = [
        sorted(
            p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
        )
        for input_sample_paths in input_paths
    ]
    try:
        long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
        sample_ids = get_shorter_diff(long_sample_ids)
    except ValueError as e:
        # fix: chain the original exception for better diagnostics
        raise ValueError(f"failed to extract sample ids: {e}") from e

    return sample_ids