Coverage for src/bioimageio/core/cli.py: 44%

369 statements  

coverage.py v7.10.7, created at 2025-09-22 09:21 +0000

1"""bioimageio CLI 

2 

3Note: Some docstrings use a hair space ' ' 

4 to place the added '(default: ...)' on a new line. 

5""" 

6 

7import json 

8import shutil 

9import subprocess 

10import sys 

11from abc import ABC 

12from argparse import RawTextHelpFormatter 

13from difflib import SequenceMatcher 

14from functools import cached_property 

15from io import StringIO 

16from pathlib import Path 

17from pprint import pformat, pprint 

18from typing import ( 

19 Any, 

20 Dict, 

21 Iterable, 

22 List, 

23 Literal, 

24 Mapping, 

25 Optional, 

26 Sequence, 

27 Set, 

28 Tuple, 

29 Type, 

30 Union, 

31) 

32 

33import rich.markdown 

34from loguru import logger 

35from pydantic import AliasChoices, BaseModel, Field, model_validator 

36from pydantic_settings import ( 

37 BaseSettings, 

38 CliPositionalArg, 

39 CliSettingsSource, 

40 CliSubCommand, 

41 JsonConfigSettingsSource, 

42 PydanticBaseSettingsSource, 

43 SettingsConfigDict, 

44 YamlConfigSettingsSource, 

45) 

46from tqdm import tqdm 

47from typing_extensions import assert_never 

48 

49import bioimageio.spec 

50from bioimageio.core import __version__ 

51from bioimageio.spec import ( 

52 AnyModelDescr, 

53 InvalidDescr, 

54 ResourceDescr, 

55 load_description, 

56 save_bioimageio_yaml_only, 

57 settings, 

58 update_format, 

59 update_hashes, 

60) 

61from bioimageio.spec._internal.io import is_yaml_value 

62from bioimageio.spec._internal.io_utils import open_bioimageio_yaml 

63from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty 

64from bioimageio.spec.dataset import DatasetDescr 

65from bioimageio.spec.model import ModelDescr, v0_4, v0_5 

66from bioimageio.spec.notebook import NotebookDescr 

67from bioimageio.spec.utils import ensure_description_is_model, get_reader, write_yaml 

68 

69from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test 

70from .common import MemberId, SampleId, SupportedWeightsFormat 

71from .digest_spec import get_member_ids, load_sample_for_model 

72from .io import load_dataset_stat, save_dataset_stat, save_sample 

73from .prediction import create_prediction_pipeline 

74from .proc_setup import ( 

75 DatasetMeasure, 

76 Measure, 

77 MeasureValue, 

78 StatsCalculator, 

79 get_required_dataset_measures, 

80) 

81from .sample import Sample 

82from .stat_measures import Stat 

83from .utils import compare 

84from .weight_converters._add_weights import add_weights 

85 

86WEIGHT_FORMAT_ALIASES = AliasChoices( 

87 "weight-format", 

88 "weights-format", 

89 "weight_format", 

90 "weights_format", 

91) 
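# Note (illustrative, not part of the original module): these alias choices are meant
# to let the weight format option of the `test`, `package` and `predict` subcommands
# be spelled interchangeably, e.g. `--weight-format`, `--weights-format`,
# `--weight_format` or `--weights_format`, assuming pydantic-settings exposes
# validation aliases as CLI flags.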


class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    pass


class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    pass


class WithSummaryLogging(ArgMixin):
    summary: List[Union[Literal["display"], Path]] = Field(
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        _ = descr.validation_summary.log(self.summary)
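    # Illustrative note (assumption: pydantic-settings maps this field to a
    # `--summary` CLI option): `--summary summary.md` writes a Markdown report,
    # `--summary bioimageio_summaries/` writes all formats into that folder,
    # and the default `display` renders the summary to the terminal.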


class WithSource(ArgMixin):
    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))

        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        return str(nickname or self.descr.id or self.descr.name)


class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to perform validations that require downloading remote files.
    Note: Default value is set by the `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def run(self):
        self.log(self.descr)
        sys.exit(
            0
            if self.descr.validation_summary.status in ("valid-format", "passed")
            else 1
        )
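# Illustrative usage (subcommand and positional argument as declared in `Bioimageio` below):
#   bioimageio validate-format affable-shark
# The process exits with code 0 if the summary status is "valid-format" or "passed",
# and with 1 otherwise.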


class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = None
    """Device(s) to use for testing"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
      weights description.
    - A path to a conda environment YAML.
      Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def run(self):
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                summary=self.summary,
                runtime_env=self.runtime_env,
                determinism=self.determinism,
                format_version=self.format_version,
            )
        )
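# Illustrative usage (flag names follow the aliases above; boolean fields become
# implicit flags via `cli_implicit_flags`):
#   bioimageio test affable-shark --runtime-env as-described --stop-early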


class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def run(self):
        if isinstance(self.descr, InvalidDescr):
            self.log(self.descr)
            raise ValueError(f"Invalid {self.descr.type} description.")

        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )
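# Illustrative usage (both positional arguments are declared above, source first):
#   bioimageio package affable-shark affable-shark.zip
# A target path without a `.zip` suffix saves the package as an unzipped folder instead.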


def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat
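# Note: required dataset measures are computed once and cached in `stats_path`
# (`dataset_statistics.json` by default for `predict`); deleting that file forces
# recomputation on the next run.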


class UpdateCmdBase(CmdBase, WithSource, ABC):
    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highlighting.
      (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly been set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        raise NotImplementedError

    def run(self):
        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)
        stream = StringIO()

        save_bioimageio_yaml_only(
            self.updated,
            stream,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
        )
        updated_yaml = stream.getvalue()

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format=(
                "html"
                if isinstance(self.diff, Path) and self.diff.suffix == ".html"
                else "unified"
            ),
        )

        if isinstance(self.diff, Path):
            _ = self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            console = rich.console.Console()
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            console.print(rich.markdown.Markdown(diff_md))

        if isinstance(self.output, Path):
            _ = self.output.write_text(updated_yaml, encoding="utf-8")
            logger.info(f"written updated description to {self.output}")
        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            _ = self.updated.validation_summary.display()


class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
        The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Whether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        return update_format(
            self.source,
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )


class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        return update_hashes(self.source)
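# Illustrative usage (option names follow the field names/aliases above):
#   bioimageio update-format affable-shark --output updated_bioimageio.yaml --diff diff.html
#   bioimageio update-hashes affable-shark --output stdout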


class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"]
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.

     
    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

     
    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: bool = False
    """process inputs blockwise"""

    stats: Path = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist,
    but the model requires statistical dataset measures)
     """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input

     
    """

    def _example(self):
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Successfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current working directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def run(self):
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )
        predict_method = (
            pp.predict_sample_with_blocking
            if self.blockwise
            else pp.predict_sample_without_blocking
        )

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            sample_out = predict_method(sample_in)
            save_sample(sp_out, sample_out)
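# A `bioimageio-cli.yaml` (or `bioimageio-cli.json`) in the working directory is picked
# up via the `SettingsConfigDict` below and can provide the `predict` arguments instead
# of the command line, e.g. (illustrative, mirroring the `inputs` docstring above):
#   inputs:
#     - [a_raw.tif, a_mask.tif]
#     - [b_raw.tif, b_mask.tif]
#   outputs: outputs_{model_id}/{output_id}/{sample_id}.tif
#   blockwise: false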


class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def run(self):
        model_descr = ensure_description_is_model(self.descr)
        if isinstance(model_descr, v0_4.ModelDescr):
            raise TypeError(
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )
        updated_model_descr = add_weights(
            model_descr,
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        self.log(updated_model_descr)


JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"


class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    "Check a resource's metadata format"

    test: CliSubCommand[TestCmd]
    "Test a bioimageio resource (beyond meta data formatting)"

    package: CliSubCommand[PackageCmd]
    "Package a resource"

    predict: CliSubCommand[PredictCmd]
    "Predict with a model resource"

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights (converted from the available formats) to the model
    description to improve deployability."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def run(self):
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        cmd = (
            self.add_weights
            or self.package
            or self.predict
            or self.test
            or self.update_format
            or self.update_hashes
            or self.validate_format
        )
        assert cmd is not None
        cmd.run()


assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {__version__}
  bioimageio.spec {bioimageio.spec.__version__}

spec format versions:
  model RDF {ModelDescr.implemented_format_version}
  dataset RDF {DatasetDescr.implemented_format_version}
  notebook RDF {NotebookDescr.implemented_format_version}

"""


def _get_sample_ids(
    input_paths: Sequence[Mapping[MemberId, Path]],
) -> Sequence[SampleId]:
    """Get sample ids for given input paths, based on the common path per sample.

    Falls back to sample01, sample02, etc..."""

    matcher = SequenceMatcher()

    def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]:
        """extract a common sequence from multiple sequences
        (order sensitive; strips whitespace and slashes)
        """
        common = seqs[0]

        for seq in seqs[1:]:
            if not seq:
                continue
            matcher.set_seqs(common, seq)
            i, _, size = matcher.find_longest_match()
            common = common[i : i + size]

        if isinstance(common, str):
            common = common.strip().strip("/")
        else:
            common = [cs for c in common if (cs := c.strip().strip("/"))]

        if not common:
            raise ValueError(f"failed to find common sequence for {seqs}")

        return common

    def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]:
        """get a shorter sequence whose entries are still unique
        (order sensitive, not minimal sequence)
        """
        min_seq_len = min(len(s) for s in seqs)
        # cut from the start
        for start in range(min_seq_len - 1, -1, -1):
            shortened = [s[start:] for s in seqs]
            if len(set(shortened)) == len(seqs):
                min_seq_len -= start
                break
        else:
            seen: Set[Sequence[str]] = set()
            dupes = [s for s in seqs if s in seen or seen.add(s)]
            raise ValueError(f"Found duplicate entries {dupes}")

        # cut from the end
        for end in range(min_seq_len - 1, 1, -1):
            shortened = [s[:end] for s in shortened]
            if len(set(shortened)) == len(seqs):
                break

        return shortened

    full_tensor_ids = [
        sorted(
            p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values()
        )
        for input_sample_paths in input_paths
    ]
    try:
        long_sample_ids = [get_common_seq(t) for t in full_tensor_ids]
        sample_ids = get_shorter_diff(long_sample_ids)
    except ValueError as e:
        raise ValueError(f"failed to extract sample ids: {e}")

    return sample_ids
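# Roughly, for input paths {"raw": "data/a/raw.tif", "mask": "data/a/mask.tif"} and
# {"raw": "data/b/raw.tif", "mask": "data/b/mask.tif"} the per-sample common sequences
# end with ".../data/a" and ".../data/b", which `get_shorter_diff` then trims to the
# shortest still-unique ids, here "a" and "b" (illustrative walk-through, not a doctest).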