Coverage for src / bioimageio / core / cli.py: 82%

400 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1"""bioimageio CLI 

2 

3Note: Some docstrings use a hair space ' ' 

4 to place the added '(default: ...)' on a new line. 

5""" 

6 

7import json 

8import shutil 

9import subprocess 

10import sys 

11from abc import ABC 

12from argparse import RawTextHelpFormatter 

13from difflib import SequenceMatcher 

14from functools import cached_property, partial 

15from io import StringIO 

16from pathlib import Path 

17from pprint import pformat, pprint 

18from typing import ( 

19 Annotated, 

20 Any, 

21 Dict, 

22 Iterable, 

23 List, 

24 Literal, 

25 Mapping, 

26 Optional, 

27 Sequence, 

28 Set, 

29 Tuple, 

30 Type, 

31 Union, 

32) 

33 

34import numpy as np 

35import rich.markdown 

36from loguru import logger 

37from pydantic import ( 

38 AliasChoices, 

39 BaseModel, 

40 Field, 

41 PlainSerializer, 

42 WithJsonSchema, 

43 model_validator, 

44) 

45from pydantic_settings import ( 

46 BaseSettings, 

47 CliApp, 

48 CliPositionalArg, 

49 CliSettingsSource, 

50 CliSubCommand, 

51 JsonConfigSettingsSource, 

52 PydanticBaseSettingsSource, 

53 SettingsConfigDict, 

54 YamlConfigSettingsSource, 

55) 

56from tqdm import tqdm 

57from typing_extensions import assert_never 

58 

59import bioimageio.spec 

60from bioimageio.core import __version__ 

61from bioimageio.spec import ( 

62 AnyModelDescr, 

63 InvalidDescr, 

64 ResourceDescr, 

65 load_description, 

66 save_bioimageio_yaml_only, 

67 settings, 

68 update_format, 

69 update_hashes, 

70) 

71from bioimageio.spec._internal.io import is_yaml_value 

72from bioimageio.spec._internal.io_utils import open_bioimageio_yaml 

73from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty 

74from bioimageio.spec.dataset import DatasetDescr 

75from bioimageio.spec.model import ModelDescr, v0_4, v0_5 

76from bioimageio.spec.notebook import NotebookDescr 

77from bioimageio.spec.utils import ( 

78 empty_cache, 

79 ensure_description_is_model, 

80 get_reader, 

81 write_yaml, 

82) 

83 

84from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test 

85from .common import MemberId, SampleId, SupportedWeightsFormat 

86from .digest_spec import get_member_ids, load_sample_for_model 

87from .io import load_dataset_stat, save_dataset_stat, save_sample 

88from .prediction import create_prediction_pipeline 

89from .proc_setup import ( 

90 DatasetMeasure, 

91 Measure, 

92 MeasureValue, 

93 StatsCalculator, 

94 get_required_dataset_measures, 

95) 

96from .sample import Sample 

97from .stat_measures import Stat 

98from .utils import compare 

99from .weight_converters._add_weights import add_weights 

100 

# Accepted CLI spellings of the weight format option; pydantic resolves any of
# these to the same field (used as `validation_alias` in the commands below).
WEIGHT_FORMAT_ALIASES = AliasChoices(
    "weight-format",
    "weights-format",
    "weight_format",
    "weights_format",
)

107 

108 

class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    """Base class for all CLI (sub)commands."""

    pass

111 

112 

class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    """Base class for argument mixins shared by multiple CLI commands."""

    pass

115 

116 

class WithSummaryLogging(ArgMixin):
    """Mixin adding the `summary` option and a helper to log validation summaries."""

    summary: List[Union[Literal["display"], Path]] = Field(
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        """Log/display `descr`'s validation summary to every target in `self.summary`."""
        _ = descr.validation_summary.log(self.summary)

134 

135 

class WithSource(ArgMixin):
    """Mixin adding the positional `source` argument and lazy description loading."""

    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        # loaded once and cached; may be an `InvalidDescr`
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # invalid descriptions may lack an `id`; fall back to `name`
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))

        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or self.descr.id or self.descr.name)

162 

163 

class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to perform validations that require downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        # overrides `WithSource.descr` to honor `perform_io_checks`
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        """Log the validation summary and exit 0 on success, 1 otherwise."""
        self.log(self.descr)
        sys.exit(
            0
            if self.descr.validation_summary.status in ("valid-format", "passed")
            else 1
        )

185 

186 

class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = None
    """Device(s) to use for testing"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
    weights description.
    - A path to a conda environment YAML.
    Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    # NOTE(review): this flag is declared (with short alias `-x`) but is NOT
    # forwarded to `test()` in `cli_cmd` below — confirm whether it should be.
    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        """Run the resource tests; exit with the return code of `test`."""
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                summary=self.summary,
                runtime_env=self.runtime_env,
                determinism=self.determinism,
                format_version=self.format_version,
                working_dir=self.working_dir,
            )
        )

246 

247 

class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        """Package the resource; exit with the return code of `package`.

        Raises:
            ValueError: If the source is not a valid resource description.
        """
        if isinstance(self.descr, InvalidDescr):
            self.log(self.descr)
            raise ValueError(f"Invalid {self.descr.type} description.")

        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )

275 

276 

def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    """Load or compute the dataset measures required by **model_descr**.

    Precomputed measures are loaded from **stats_path** if it exists; otherwise
    they are computed from **dataset** and saved to **stats_path**.

    Args:
        model_descr: Model description whose required dataset measures are needed.
        dataset: Samples to compute the statistics from (iterated once).
        dataset_length: Number of samples in `dataset` (for the progress bar).
        stats_path: Path to load precomputed statistics from / save them to.

    Returns:
        Mapping of each required dataset measure to its value
        (empty if the model requires no dataset measures).

    Raises:
        ValueError: If `stats_path` exists but lacks a required measure.
    """
    # (fixed: this call was previously performed twice in a row)
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat

309 

310 

class UpdateCmdBase(CmdBase, WithSource, ABC):
    """Shared logic for commands that emit an updated bioimageio.yaml."""

    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highlighting.
    (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly be set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        """The updated resource description (provided by subclasses)."""
        raise NotImplementedError

    def cli_cmd(self):
        """Compute the update, then emit the diff and the updated YAML as configured."""
        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)
        stream = StringIO()

        # serialize the updated description to YAML (in memory)
        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format=(
                "html"
                if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                else "unified"
            ),
        )

        # emit the diff (file or terminal, depending on `self.diff`)
        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            console = rich.console.Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(rich.markdown.Markdown(diff_md))

        # emit the updated YAML itself (file, rendered terminal, or raw stdout)
        if isinstance(self.output, Path):
            _ = self.output.write_text(updated_yaml, encoding="utf-8")
            logger.info(f"written updated description to {self.output}")
        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()

380 

381 

class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        # computed once; consumed by `UpdateCmdBase.cli_cmd`
        return update_format(
            self.source,
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )

405 

406 

class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        # recompute the hashes of all files referenced by the description
        return update_hashes(self.source)

413 

414 

class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"]
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.


    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: Union[bool, int] = False
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
    The blockize parameter determines the block size along axes with parameterized input size
    by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.
    - If `False`, inputs are processed as a whole without blocking.

    """

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
    """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input


    """

    def _example(self):
        """Download example inputs and run an example prediction via the CLI.

        Writes a `{model_id}_example/` folder with the inputs and a
        `bioimageio-cli.yaml` config, then invokes `bioimageio predict`
        as a subprocess (preview first, then the actual prediction).
        """
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        # prefer sample tensors, fall back to test tensors (v0.5)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # download each example input to {example_path}/{input_id}/001.<suffix>
        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            """assemble the example `bioimageio predict` command line"""
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # a bioimageio-cli.yaml in the cwd would override the subprocess'
        # arguments; move it out of the way and restore it afterwards
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Sucessfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current workind directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Run prediction and save each predicted sample to its output paths."""
        for out_sample, out_path in self._yield_predictions(self.blockwise):
            save_sample(out_path, out_sample)

    def _yield_predictions(self, blockwise: Union[bool, int]):
        """Yield `(predicted_sample, output_paths)` pairs for all input samples.

        Yields nothing in `--example` and `--preview` modes (returns early).
        """
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # required inputs only (v0.5 input tensors may be optional)
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            """expand one input entry to one concrete path per input tensor"""
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            """expand the output pattern(s) to one path per output tensor and sample"""
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]
            # check for distinctness and correct number within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + "Make sure to include '{{sample_id}}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            """lazily load one `Sample` per set of input paths"""
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        # load or compute the dataset statistics required by the model
        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )

        if blockwise:
            predict_method = partial(
                pp.predict_sample_with_blocking,
                ns=None if isinstance(blockwise, bool) else blockwise,
            )
        else:
            predict_method = pp.predict_sample_without_blocking

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            yield (predict_method(sample_in), sp_out)

786 

787 

class PredictBlockArtifactsCmd(PredictCmd):
    """Command to inspect block artifacts by subtracting the combined, blockwise predictions from a whole sample prediction.

    Note:
        - This command intentionally uses a small blocksize (default: 1) to create block artifacts for testing purposes.
        - Typical sources of block artifacts include:
            - Described halo is smaller than the model's receptive field
            - Normalization layers inside the network cannot aggregate statistics over the whole sample.
    """

    blockwise: Union[Literal[True], int] = 1
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
    The blockize parameter determines the block size along axes with parameterized input size
    by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.

    Defaults to a small blocksize to intentionally create block artifacts for testing purposes.
    """

    def cli_cmd(self):
        """Predict each sample whole and blockwise, then save their difference.

        Also sanity-checks that both runs produced the same sample statistics,
        logging (but not aborting on) any mismatch.
        """
        for (out_sample, out_path), (out_sample_blockwise, _) in zip(
            self._yield_predictions(False), self._yield_predictions(self.blockwise)
        ):
            diff_sample = self._subtract_samples(out_sample, out_sample_blockwise)
            for k, v_a in out_sample.stat.items():
                v_b = out_sample_blockwise.stat.get(k)
                if v_b is None:
                    logger.error(
                        "measure '{}' not found in blockwise prediction statistics", k
                    )
                elif not np.array_equal(v_a, v_b):
                    # fixed: was `not np.not_equal(v_a, v_b)`, which reported a
                    # mismatch exactly when the values were *equal*;
                    # `array_equal` also handles array-valued measures.
                    logger.error(
                        "measure '{}' has different values (whole sample!=blockwise): {}!={}",
                        k,
                        v_a,
                        v_b,
                    )

            save_sample(out_path, diff_sample)

    @staticmethod
    def _subtract_samples(a: Sample, b: Sample) -> Sample:
        """Return a new `Sample` with members `a - b` (id and stat taken from `a`)."""
        return Sample(
            members={t: a.members[t] - b.members[t] for t in a.members},
            id=a.id,
            stat=a.stat,
        )

837 

838 

class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        """Convert weights, write the updated package, and log its validation summary.

        Raises:
            TypeError: If the model uses the unsupported v0.4 format.
        """
        model_descr = ensure_description_is_model(self.descr)
        if isinstance(model_descr, v0_4.ModelDescr):
            raise TypeError(
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )
        updated_model_descr = add_weights(
            model_descr,
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        self.log(updated_model_descr)

874 

875 

class EmptyCache(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        # delegates to `bioimageio.spec.utils.empty_cache`
        empty_cache()

881 

882 

# Config file names picked up from the current working directory to provide
# default CLI arguments (see `SettingsConfigDict` in the `Bioimageio` class).
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"

885 

886 

class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    """Check a resource's metadata format"""

    test: CliSubCommand[TestCmd]
    """Test a bioimageio resource (beyond meta data formatting)"""

    package: CliSubCommand[PackageCmd]
    """Package a resource"""

    predict: CliSubCommand[PredictCmd]
    """Predict with a model resource"""

    predict_block_artifacts: CliSubCommand[PredictBlockArtifactsCmd] = Field(
        alias="predict-block-artifacts"
    )
    """Save the difference between predicting blockwise and whole sample to check for block artifacts."""

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCache] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Define settings precedence: CLI > init > bioimageio-cli.yaml > bioimageio-cli.json."""
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        # log the raw (pre-validation) CLI input for debugging
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def cli_cmd(self) -> None:
        """Log the resolved configuration and dispatch to the chosen subcommand."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)

969 

970 

# append library and spec format version information to the CLI help text
assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""

984 

985 

def _get_sample_ids(
    input_paths: Sequence[Mapping[MemberId, Path]],
) -> Sequence[SampleId]:
    """Get sample ids for given input paths, based on the common path per sample.

    Falls back to sample01, sample02, etc...

    Args:
        input_paths: One mapping of input tensor id to path per sample.

    Returns:
        One (shortened, distinct) sample id per entry of `input_paths`.

    Raises:
        ValueError: If no common sequence is found for a sample's paths or
            sample ids end up non-distinct.
    """

    matcher = SequenceMatcher()

    def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
        """extract a common sequence from multiple sequences
        (order sensitive; strips whitespace and slashes)
        """
        common = seqs[0]

        for seq in seqs[1:]:
            if not seq:
                continue
            matcher.set_seqs(common, seq)
            i, _, size = matcher.find_longest_match()
            common = common[i : i + size]

        if isinstance(common, str):
            common = common.strip().strip("/")
        else:
            common = [cs for c in common if (cs := c.strip().strip("/"))]

        if not common:
            raise ValueError(f"failed to find common sequence for {seqs}")

        return common

    def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
        """get a shorter sequence whose entries are still unique
        (order sensitive, not minimal sequence)
        """
        min_seq_len = min(len(s) for s in seqs)
        # cut from the start
        for start in range(min_seq_len - 1, -1, -1):
            shortened = [s[start:] for s in seqs]
            if len(set(shortened)) == len(seqs):
                min_seq_len -= start
                break
        else:
            # no suffix length keeps entries unique => there are duplicates
            seen: Set[Sequence[str]] = set()
            dupes = [s for s in seqs if s in seen or seen.add(s)]
            raise ValueError(f"Found duplicate entries {dupes}")

        # cut from the end
        # NOTE(review): this stops after the first trim that keeps entries
        # unique, so it trims at most one element from the end in the common
        # case — confirm whether further shortening was intended.
        for end in range(min_seq_len - 1, 1, -1):
            shortened = [s[:end] for s in shortened]
            if len(set(shortened)) == len(seqs):
                break

        return shortened

    full_tensor_ids = [
        sorted(
            p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
        )
        for input_sample_paths in input_paths
    ]
    try:
        long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
        sample_ids = get_shorter_diff(long_sample_ids)
    except ValueError as e:
        # chain the original error so its traceback is preserved
        raise ValueError(f"failed to extract sample ids: {e}") from e

    return sample_ids