Coverage for src / bioimageio / core / cli.py: 81%

418 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1"""bioimageio CLI 

2 

3Note: Some docstrings use a hair space ' ' 

4 to place the added '(default: ...)' on a new line. 

5""" 

6 

7import json 

8import shutil 

9import subprocess 

10import sys 

11from abc import ABC 

12from argparse import RawTextHelpFormatter 

13from difflib import SequenceMatcher 

14from functools import cached_property, partial 

15from io import StringIO 

16from pathlib import Path 

17from pprint import pformat, pprint 

18from typing import ( 

19 Annotated, 

20 Any, 

21 Dict, 

22 Iterable, 

23 List, 

24 Literal, 

25 Mapping, 

26 Optional, 

27 Sequence, 

28 Set, 

29 Tuple, 

30 Type, 

31 Union, 

32) 

33 

34import numpy as np 

35import rich.markdown 

36from loguru import logger 

37from pydantic import ( 

38 AliasChoices, 

39 BaseModel, 

40 Field, 

41 PlainSerializer, 

42 WithJsonSchema, 

43 model_validator, 

44) 

45from pydantic_settings import ( 

46 BaseSettings, 

47 CliApp, 

48 CliPositionalArg, 

49 CliSettingsSource, 

50 CliSubCommand, 

51 JsonConfigSettingsSource, 

52 PydanticBaseSettingsSource, 

53 SettingsConfigDict, 

54 YamlConfigSettingsSource, 

55) 

56from tqdm import tqdm 

57from typing_extensions import assert_never 

58 

59import bioimageio.spec 

60from bioimageio.core import __version__ 

61from bioimageio.spec import ( 

62 AnyModelDescr, 

63 InvalidDescr, 

64 ResourceDescr, 

65 load_description, 

66 save_bioimageio_package, 

67 save_bioimageio_package_as_folder, 

68 save_bioimageio_yaml_only, 

69 settings, 

70 update_format, 

71 update_hashes, 

72) 

73from bioimageio.spec._internal.io import is_yaml_value 

74from bioimageio.spec._internal.io_utils import open_bioimageio_yaml 

75from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty 

76from bioimageio.spec.dataset import DatasetDescr 

77from bioimageio.spec.model import ModelDescr, v0_4, v0_5 

78from bioimageio.spec.notebook import NotebookDescr 

79from bioimageio.spec.utils import ( 

80 empty_cache, 

81 ensure_description_is_model, 

82 get_reader, 

83 write_yaml, 

84) 

85 

86from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test 

87from .common import MemberId, SampleId, SupportedWeightsFormat 

88from .digest_spec import get_member_ids, load_sample_for_model 

89from .io import load_stat, save_sample, save_stat 

90from .prediction import create_prediction_pipeline 

91from .proc_setup import ( 

92 Measure, 

93 MeasureValue, 

94 StatsCalculator, 

95 get_required_dataset_measures, 

96) 

97from .sample import Sample 

98from .stat_measures import Stat 

99from .utils import compare 

100from .weight_converters._add_weights import add_weights 

101 

# Spellings accepted for the weight format option:
# singular/plural stem, joined to "format" by a hyphen or an underscore.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    *(
        f"{stem}{sep}format"
        for sep in ("-", "_")
        for stem in ("weight", "weights")
    )
)

108 

109 

# Base class of the CLI command models defined in this module.
# (A `#` comment is used instead of a class docstring on purpose: class
# docstrings of commands below serve as CLI help text.)
class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    ...

112 

113 

# Base class for reusable groups of CLI arguments that are mixed into commands
# (see `WithSummaryLogging` and `WithSource`).
class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    ...

116 

117 

class WithSummaryLogging(ArgMixin):
    summary: List[Union[Literal["display"], Path]] = Field(
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
        default_factory=lambda: ["display"],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        # Hand the configured destinations to the description's validation
        # summary; whatever `log` returns is intentionally discarded.
        validation_summary = descr.validation_summary
        _ = validation_summary.log(self.summary)

135 

136 

class WithSource(ArgMixin):
    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        # Loaded lazily and cached, so repeated access does not reload `source`.
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)

        Resolution order for valid descriptions: nickname (from the
        `bioimageio` config extras), then `id`, then `name`.
        For invalid descriptions: `id`, then `name`, then the raw `source`.
        """
        descr = self.descr
        if isinstance(descr, InvalidDescr):
            # Look up both attributes defensively: an invalid description is
            # not guaranteed to expose `id` or `name`, and the fallback must
            # not raise when the preferred attribute is available.
            return str(
                getattr(descr, "id", None)
                or getattr(descr, "name", None)
                or self.source
            )

        nickname = None
        config = descr.config
        if (
            isinstance(config, v0_5.Config)
            and (bio_config := config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or descr.id or descr.name)

163 

164 

class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Wether or not to perform validations that requires downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        # Overrides the inherited loader to honour `perform_io_checks`.
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        # Report the validation summary, then exit with 0 only if the
        # description has a valid format (or passed entirely), else 1.
        descr = self.descr
        self.log(descr)
        format_ok = descr.validation_summary.status in ("valid-format", "passed")
        exit_code = 0 if format_ok else 1
        sys.exit(exit_code)

186 

187 

class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
      weights description.
    - A path to a conda environment YAML.
      Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        # Run the test command and use its return value as process exit code.
        # NOTE(review): `stop_early` is declared above but is not forwarded to
        # `test` here — confirm whether `test` accepts such an argument.
        exit_code = test(
            self.descr,
            weight_format=self.weight_format,
            devices=self.devices,
            summary=self.summary,
            runtime_env=self.runtime_env,
            determinism=self.determinism,
            format_version=self.format_version,
            working_dir=self.working_dir,
        )
        sys.exit(exit_code)

249 

250 

class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        descr = self.descr
        if isinstance(descr, InvalidDescr):
            # Show why the description is invalid before refusing to package it.
            self.log(descr)
            raise ValueError(f"Invalid {descr.type} description.")

        # The return value of `package` becomes the process exit code.
        exit_code = package(descr, self.path, weight_format=self.weight_format)
        sys.exit(exit_code)

278 

279 

def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Stat:
    """Load or compute the dataset measures required by `model_descr`.

    Args:
        model_descr: Model description that determines the required measures.
        dataset: Samples to compute the measures from (only consumed if the
            measures have to be computed).
        dataset_length: Number of samples in `dataset` (progress bar total).
        stats_path: File to load precomputed measures from if it exists;
            otherwise freshly computed measures are saved to it.

    Returns:
        An empty mapping if the model requires no dataset measures,
        otherwise the loaded or computed measures.

    Raises:
        ValueError: If `stats_path` exists but lacks a required measure.
    """
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    computed: Dict[Measure, MeasureValue] = dict(stats_calc.finalize().items())
    save_stat(computed, stats_path)
    return computed

311 

312 

class UpdateCmdBase(CmdBase, WithSource, ABC):
    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highligthing.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly be set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        # To be provided by concrete update commands.
        raise NotImplementedError

    def cli_cmd(self):
        # Imported explicitly: the module only imports `rich.markdown`, so
        # `rich.console` would otherwise be available only as a side effect.
        import rich.console

        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)

        # Serialize the updated description into memory.
        stream = StringIO()
        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        wants_html_diff = isinstance(self.diff, Path) and self.diff.suffix == ".html"
        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format="html" if wants_html_diff else "unified",
        )

        # Emit the diff: to a file, to the terminal, or not at all.
        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            console = rich.console.Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(rich.markdown.Markdown(diff_md))

        # Emit the updated description.
        if isinstance(self.output, Path):
            if self.output.suffix in (".yaml", ".yml"):
                # Metadata only; allowed even for an invalid description.
                _ = self.output.write_text(updated_yaml, encoding="utf-8")
                logger.info(f"written updated description to {self.output}")
            elif isinstance(self.updated, InvalidDescr):
                raise ValueError(
                    f"Cannot save invalid description package to {self.output}."
                    + " To save the metadata only, choose an output with a '.yaml' extension."
                )

            elif not self.output.suffix:
                _ = save_bioimageio_package_as_folder(
                    self.updated, output_path=self.output
                )
            else:
                _ = save_bioimageio_package(self.updated, output_path=self.output)

        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()

396 

397 

class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Wether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        # Delegate to `update_format`, forwarding the relevant CLI options.
        update_options = dict(
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )
        return update_format(self.source, **update_options)

421 

422 

class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        # Delegate to `update_hashes` for the given source.
        source = self.source
        return update_hashes(source)

429 

430 

class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"],
        validation_alias=AliasChoices("inputs", "input"),
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Aavailable formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.


    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = Field(
        "outputs_{model_id}/{output_id}/{sample_id}.tif",
        validation_alias=AliasChoices("outputs", "output"),
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

     """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: Union[bool, int] = False
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blockize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.
    - If `False`, inputs are processed as a whole without blocking.

     """

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("precomputed_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
     """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input


    """

    def _example(self):
        # Set up an example folder with the model's sample (or test) inputs and
        # a CLI config file, then run `bioimageio predict` on it in
        # subprocesses: first as a preview, then for real.
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # Copy each example input to `<example_path>/<input id>/001<suffix>`.
        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "precomputed_statistics.json"
        stats = (example_path / stats_file).as_posix()
        # NOTE(review): `inputs`/`outputs` written to the config include the
        # `example_path` prefix while `stats` is relative to the example
        # folder — confirm this mix is intended given that the printed
        # instructions suggest running from within `example_path`.
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            # `escape=True` yields a shell-quoted variant for display only;
            # the unescaped variant is passed to `subprocess.run` as a list.
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                f"--blockwise={self.blockwise}",
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # A CLI config file in the current working directory is moved out of
        # the way while the example runs and restored afterwards.
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Sucessfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current workind directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        # Save every predicted sample. If whole-sample prediction fails, the
        # error is wrapped with a hint to try blockwise processing.
        try:
            for out_sample, out_path in self._yield_predictions(self.blockwise):
                save_sample(out_path, out_sample)
        except Exception as e:
            if not self.blockwise:
                raise RuntimeError(
                    f"Prediction failed ({e}).\nConsider using blockwise processing, "
                    + "e.g. with `--blockwise=10` to process inputs in blocks."
                ) from e
            raise

    def _yield_predictions(self, blockwise: Union[bool, int]):
        """Yield `(predicted sample, output paths)` for each input sample.

        Args:
            blockwise: Falsy to predict each sample as a whole; `True` or an
                integer block size parameter to predict with blocking.

        Yields nothing if `example` is set (the example is run instead) or if
        `preview` is set (the planned inputs/outputs are printed instead).

        Raises:
            ValueError: If input/output paths are not distinct or their number
                does not match the model description.
            FileExistsError: If an output exists and `overwrite` is not set.
        """
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # ids of required inputs (for non-v0_5 inputs all are required)
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            # Resolve the path pattern(s) of input sample #i to one path per
            # input tensor and validate their number and distinctness.
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            # Resolve the output pattern(s) to one path per output tensor and
            # sample, then validate them.
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]
            # check for distinctness and correct number within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + "Make sure to include '{{sample_id}}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            # Lazily load one input sample at a time.
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        # Dataset measures are loaded from / written to `self.stats`.
        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
            devices=self.devices,
        )

        if blockwise:
            predict_method = partial(
                pp.predict_sample_with_blocking,
                ns=None if isinstance(blockwise, bool) else blockwise,
            )
        else:
            predict_method = pp.predict_sample_without_blocking

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            # Validate inputs up front for whole-sample prediction. This is
            # keyed on the `blockwise` argument (not `self.blockwise`) so it
            # matches the prediction method selected above, also when a
            # subclass requests a whole-sample pass explicitly.
            if not blockwise and not isinstance(
                pp.model_description, v0_4.ModelDescr
            ):
                try:
                    _ = pp.model_description.validate_input_tensors(
                        sample_in.as_arrays()
                    )
                except Exception as e:
                    # Only warn; the prediction is attempted regardless.
                    logger.warning(
                        "Input sample '{}' failed validation for whole-sample prediction: {}\n"
                        + "Consider using blockwise processing, e.g. with `--blockwise=10` to process inputs in blocks.",
                        sample_in.id,
                        e,
                    )

            yield (predict_method(sample_in), sp_out)

833 

834 

class PredictBlockArtifactsCmd(PredictCmd):
    """Command to inspect block artifacts by subtracting the combined, blockwise predictions from a whole sample prediction.

    Note:
    - This command intentionally uses a small blocksize (default: 1) to create block artifacts for testing purposes.
    - Typical sources of block artifacts include:
        - Described halo is smaller than the model's receptive field
        - Normalization layers inside the network cannot aggregate statistics over the whole sample.
    """

    blockwise: Union[Literal[True], int] = 1
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
      The blockize parameter determines the block size along axes with parameterized input size
      by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.

    Defaults to a small blocksize to intentionally create block artifacts for testing purposes.
     """

    def cli_cmd(self):
        # Predict every sample twice (as a whole and blockwise), compare the
        # statistics attached to both results and save the difference of the
        # predictions to the regular output path.
        for (out_sample, out_path), (out_sample_blockwise, _) in zip(
            self._yield_predictions(False), self._yield_predictions(self.blockwise)
        ):
            diff_sample = self._subtract_samples(out_sample, out_sample_blockwise)
            for k, v_a in out_sample.stat.items():
                v_b = out_sample_blockwise.stat.get(k)
                if v_b is None:
                    logger.error(
                        "measure '{}' not found in blockwise prediction statistics", k
                    )
                elif not np.array_equal(v_a, v_b):
                    # `array_equal` reduces to a single bool, so this also
                    # works for non-scalar measure values.
                    logger.error(
                        "measure '{}' has different values (whole sample!=blockwise): {}!={}",
                        k,
                        v_a,
                        v_b,
                    )

            save_sample(out_path, diff_sample)

    @staticmethod
    def _subtract_samples(a: Sample, b: Sample) -> Sample:
        # Member-wise difference `a - b`; id and statistics are taken from `a`.
        return Sample(
            members={t: a.members[t] - b.members[t] for t in a.members},
            id=a.id,
            stat=a.stat,
        )

884 

885 

class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        model_descr = ensure_description_is_model(self.descr)
        # v0_4 model descriptions are rejected.
        if isinstance(model_descr, v0_4.ModelDescr):
            unsupported = f"model format {model_descr.format_version} not supported."
            raise TypeError(unsupported + " Please update the model first.")

        conversion_options = dict(
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        updated_model_descr = add_weights(model_descr, **conversion_options)
        # Report the validation summary of the updated description.
        self.log(updated_model_descr)

921 

922 

class EmptyCache(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        # Thin wrapper around `empty_cache` from `bioimageio.spec.utils`.
        _ = empty_cache()

928 

929 

# File names of the optional CLI config files
# (used for the settings sources and by the `predict --example` workflow).
JSON_FILE, YAML_FILE = "bioimageio-cli.json", "bioimageio-cli.yaml"

932 

933 

class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    # NOTE: with `use_attribute_docstrings=True` the attribute docstrings below
    # serve as field descriptions (i.e. CLI help text); treat them as
    # user-facing strings, not as internal comments.

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    """Check a resource's metadata format"""

    test: CliSubCommand[TestCmd]
    """Test a bioimageio resource (beyond meta data formatting)"""

    package: CliSubCommand[PackageCmd]
    """Package a resource"""

    predict: CliSubCommand[PredictCmd]
    """Predict with a model resource"""

    predict_block_artifacts: CliSubCommand[PredictBlockArtifactsCmd] = Field(
        alias="predict-block-artifacts"
    )
    """Save the difference between predicting blockwise and whole sample to check for block artifacts."""

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Select the settings sources used to populate the CLI.

        Returns the command line source, the init arguments and the YAML and
        JSON config file sources (in this order). The environment, dotenv and
        secrets sources passed in by pydantic-settings are deliberately not
        part of the returned tuple.
        """
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,  # keep line breaks in help texts
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        """Log the raw input (omitting `None` values) and pass it on unchanged."""
        # A "before" validator may receive arbitrary input; only mappings can be
        # filtered per key, anything else is logged as is instead of raising
        # an AttributeError from inside a logging helper.
        if isinstance(data, Mapping):
            loaded = {k: v for k, v in data.items() if v is not None}
        else:
            loaded = data

        logger.info("loaded CLI input:\n{}", pformat(loaded))
        return data

    def cli_cmd(self) -> None:
        """Log the parsed settings (omitting `None` values) and run the subcommand."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)

1016 

1017 

# Append library and spec format versions to the CLI description.
# `python -OO` strips docstrings (`__doc__` is None) as well as `assert`
# statements, so fall back to an empty string explicitly instead of asserting
# that the docstring is a `str`.
Bioimageio.__doc__ = (Bioimageio.__doc__ or "") + f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""

1031 

1032 

def _get_sample_ids(
    input_paths: Sequence[Mapping[MemberId, Path]],
) -> Sequence[SampleId]:
    """Get unique sample ids for given input paths.

    For every sample the resolved, suffix-less posix paths of all its inputs
    are reduced to their longest common part. The resulting per-sample strings
    are then shortened while they stay unique across all samples.

    Args:
        input_paths: For each sample a mapping from input member id to the
            path of the corresponding input file.

    Returns:
        One sample id per entry of `input_paths` (same order).

    Raises:
        ValueError: If no samples are given, a sample has no input paths, the
            input paths of a sample have nothing in common, or two samples
            result in the same id.
    """

    matcher = SequenceMatcher()

    def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
        """extract a common sequence from multiple sequences
        (order sensitive; strips whitespace and slashes)
        """
        if not seqs:
            # raise the documented error type instead of an IndexError below
            raise ValueError("no sequences (sample without input paths?)")

        common = seqs[0]

        for seq in seqs[1:]:
            if not seq:
                continue

            matcher.set_seqs(common, seq)
            i, _, size = matcher.find_longest_match()
            common = common[i : i + size]

        if isinstance(common, str):
            common = common.strip().strip("/")
        else:
            common = [cs for c in common if (cs := c.strip().strip("/"))]

        if not common:
            raise ValueError(f"failed to find common sequence for {seqs}")

        return common

    def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
        """get a shorter sequence whose entries are still unique
        (order sensitive, not minimal sequence)
        """
        if not seqs:
            raise ValueError("no sequences to shorten")

        min_seq_len = min(len(s) for s in seqs)

        # cut from the start: keep the shortest tail that is still unique
        for start in range(min_seq_len - 1, -1, -1):
            shortened = [s[start:] for s in seqs]
            if len(set(shortened)) == len(seqs):
                min_seq_len -= start
                break
        else:
            # not even the complete entries are unique -> report the repeats
            seen: Set[Sequence[str]] = set()
            dupes: List[Sequence[str]] = []
            for s in seqs:
                if s in seen:
                    dupes.append(s)
                else:
                    seen.add(s)

            raise ValueError(f"Found duplicate entries {dupes}")

        # cut from the end: a trim is only accepted if the entries stay unique.
        # Once a trimmed prefix collides, every shorter prefix collides as well,
        # so a rejected trim must not be applied (previously it was, and
        # duplicate ids could be returned). Ids of inputs that were handled
        # correctly before are unchanged.
        end = min_seq_len - 1
        if end > 1:
            trimmed = [s[:end] for s in shortened]
            if len(set(trimmed)) == len(seqs):
                shortened = trimmed

        return shortened

    full_tensor_ids = [
        sorted(
            p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
        )
        for input_sample_paths in input_paths
    ]
    try:
        long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
        sample_ids = get_shorter_diff(long_sample_ids)
    except ValueError as e:
        # keep the original error as the cause for easier debugging
        raise ValueError(f"failed to extract sample ids: {e}") from e

    return sample_ids