Coverage for src/bioimageio/core/cli.py: 80%

445 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1"""bioimageio CLI 

2 

3Note: Some docstrings use a hair space ' ' 

4 to place the added '(default: ...)' on a new line. 

5""" 

6 

7import json 

8import shutil 

9import subprocess 

10import sys 

11from abc import ABC 

12from argparse import RawTextHelpFormatter 

13from difflib import SequenceMatcher 

14from functools import cached_property, partial 

15from io import StringIO 

16from pathlib import Path 

17from pprint import pformat, pprint 

18from typing import ( 

19 Annotated, 

20 Any, 

21 Dict, 

22 Iterable, 

23 List, 

24 Literal, 

25 Mapping, 

26 Optional, 

27 Sequence, 

28 Set, 

29 Tuple, 

30 Type, 

31 Union, 

32) 

33 

34import numpy as np 

35import rich.markdown 

36from loguru import logger 

37from pydantic import ( 

38 AliasChoices, 

39 BaseModel, 

40 Field, 

41 PlainSerializer, 

42 WithJsonSchema, 

43 model_validator, 

44) 

45from pydantic_settings import ( 

46 BaseSettings, 

47 CliApp, 

48 CliPositionalArg, 

49 CliSettingsSource, 

50 CliSubCommand, 

51 JsonConfigSettingsSource, 

52 PydanticBaseSettingsSource, 

53 SettingsConfigDict, 

54 YamlConfigSettingsSource, 

55) 

56from tqdm import tqdm 

57from typing_extensions import assert_never 

58 

59import bioimageio.spec 

60from bioimageio.core import __version__ 

61from bioimageio.spec import ( 

62 AnyModelDescr, 

63 InvalidDescr, 

64 ResourceDescr, 

65 load_description, 

66 save_bioimageio_package, 

67 save_bioimageio_package_as_folder, 

68 save_bioimageio_yaml_only, 

69 settings, 

70 update_format, 

71 update_hashes, 

72) 

73from bioimageio.spec._internal.io import is_yaml_value 

74from bioimageio.spec._internal.io_utils import open_bioimageio_yaml 

75from bioimageio.spec._internal.types import FormatVersionPlaceholder, NotEmpty 

76from bioimageio.spec.dataset import DatasetDescr 

77from bioimageio.spec.model import ModelDescr, v0_4, v0_5 

78from bioimageio.spec.notebook import NotebookDescr 

79from bioimageio.spec.utils import ( 

80 empty_cache, 

81 ensure_description_is_model, 

82 get_reader, 

83 write_yaml, 

84) 

85 

86from ._prediction_pipeline import ( 

87 create_prediction_pipeline, 

88 create_remote_prediction_pipeline, 

89) 

90from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test 

91from .common import MemberId, SampleId, SupportedWeightsFormat 

92from .digest_spec import get_member_ids, load_sample_for_model 

93from .io import load_stat, save_sample, save_stat 

94from .proc_setup import get_required_dataset_measures 

95from .remote_backends import create_remote_model_adapter 

96from .sample import Sample 

97from .stat_calculators import StatsCalculator 

98from .stat_measures import Measure, MeasureValue, Stat 

99from .utils import compare 

100from .weight_converters._add_weights import add_weights 

101 

# Accepted spellings of the weight format option
# (singular/plural, dash/underscore), shared by several commands below.
WEIGHT_FORMAT_ALIASES = AliasChoices(
    "weight-format", "weights-format", "weight_format", "weights_format"
)

108 

109 

class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Common base class of the CLI (sub)command models defined in this module.
    ...

112 

113 

class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Common base class of the argument mixins (e.g. `WithSource`) that
    # commands combine with `CmdBase`.
    ...

116 

117 

class WithSummaryLogging(ArgMixin):
    # Mixin that adds the `summary` option and a `log` helper to a command.

    summary: List[Union[Literal["display"], Path]] = Field(
        default_factory=lambda: ["display"],
        examples=[
            Path("summary.md"),
            Path("bioimageio_summaries/"),
            ["display", Path("summary.md")],
        ],
    )
    """Display the validation summary or save it as JSON, Markdown or HTML.
    The format is chosen based on the suffix: `.json`, `.md`, `.html`.
    If a folder is given (path w/o suffix) the summary is saved in all formats.
    Choose/add `"display"` to render the validation summary to the terminal.
    """

    def log(self, descr: Union[ResourceDescr, InvalidDescr]):
        """Log the validation summary of `descr` to all `summary` targets."""
        summary_targets = self.summary
        descr.validation_summary.log(summary_targets)

135 

136 

class WithSource(ArgMixin):
    # Mixin that adds the positional `source` argument and lazy access to the
    # resource description loaded from it.

    source: CliPositionalArg[str]
    """Url/path to a (folder with a) bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        """The description loaded from `source` (computed once, then cached)."""
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # Look up `name` with a default, too: a default-less
            # `getattr(self.descr, "name")` is evaluated eagerly and would raise
            # an AttributeError for an invalid description without a `name`,
            # even if an `id` is available.
            fallback = getattr(self.descr, "name", None)
            return str(getattr(self.descr, "id", fallback))

        nickname = None
        if (
            isinstance(self.descr.config, v0_5.Config)
            and (bio_config := self.descr.config.bioimageio)
            and bio_config.model_extra is not None
        ):
            nickname = bio_config.model_extra.get("nickname")

        # Preference order: nickname, then id, then name.
        return str(nickname or self.descr.id or self.descr.name)

163 

164 

class ValidateFormatCmd(CmdBase, WithSource, WithSummaryLogging):
    """Validate the meta data format of a bioimageio resource."""

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Wether or not to perform validations that requires downloading remote files.
    Note: Default value is set by `BIOIMAGEIO_PERFORM_IO_CHECKS` environment variable.
    """

    @cached_property
    def descr(self):
        """The description loaded from `source`, honoring `perform_io_checks`."""
        return load_description(self.source, perform_io_checks=self.perform_io_checks)

    def cli_cmd(self):
        """Log the validation summary and exit with 0 (valid) or 1 (invalid)."""
        self.log(self.descr)
        status = self.descr.validation_summary.status
        is_valid = status in ("valid-format", "passed")
        sys.exit(0 if is_valid else 1)

186 

187 

class TestCmd(CmdBase, WithSource, WithSummaryLogging):
    """Test a bioimageio resource (beyond meta data formatting)."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    runtime_env: Union[Literal["currently-active", "as-described"], Path] = Field(
        "currently-active", alias="runtime-env"
    )
    """The python environment to run the tests in
    - `"currently-active"`: use active Python interpreter
    - `"as-described"`: generate a conda environment YAML file based on the model
    weights description.
    - A path to a conda environment YAML.
    Note: The `bioimageio.core` dependency will be added automatically if not present.
    """

    working_dir: Optional[Path] = Field(None, alias="working-dir")
    """(for debugging) Directory to save any temporary files."""

    determinism: Literal["seed_only", "full"] = "seed_only"
    """Modes to improve reproducibility of test outputs."""

    stop_early: bool = Field(
        False, alias="stop-early", validation_alias=AliasChoices("stop-early", "x")
    )
    """Do not run further subtests after a failed one."""

    format_version: Union[FormatVersionPlaceholder, str] = Field(
        "discover", alias="format-version"
    )
    """The format version to use for testing.
    - 'latest': Use the latest implemented format version for the given resource type (may trigger auto updating)
    - 'discover': Use the format version as described in the resource description
    - '0.4', '0.5', ...: Use the specified format version (may trigger auto updating)
    """

    def cli_cmd(self):
        """Run `test` on the resource and exit with its return value."""
        # NOTE(review): `stop_early` is declared above but not forwarded to
        # `test` here -- confirm whether `test` accepts such an argument.
        exit_code = test(
            self.descr,
            weight_format=self.weight_format,
            devices=self.devices,
            summary=self.summary,
            runtime_env=self.runtime_env,
            determinism=self.determinism,
            format_version=self.format_version,
            working_dir=self.working_dir,
        )
        sys.exit(exit_code)

249 

250 

class PackageCmd(CmdBase, WithSource, WithSummaryLogging):
    """Save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = Field(
        "all",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to include in the package (for model descriptions only)."""

    def cli_cmd(self):
        """Package the resource and exit with the return value of `package`.

        Raises:
            ValueError: If the description is invalid
                (its validation summary is logged first).
        """
        descr = self.descr
        if isinstance(descr, InvalidDescr):
            self.log(descr)
            raise ValueError(f"Invalid {descr.type} description.")

        exit_code = package(descr, self.path, weight_format=self.weight_format)
        sys.exit(exit_code)

278 

279 

def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Stat:
    """Load or compute the dataset measures required by `model_descr`.

    Args:
        model_descr: The model description to determine required measures for.
        dataset: Samples to compute the measures from
            (only iterated if `stats_path` does not exist).
        dataset_length: Number of samples in `dataset` (progress bar total).
        stats_path: File to load precomputed measures from if it exists;
            otherwise the computed measures are saved to it.

    Returns:
        An empty mapping if no dataset measures are required,
        otherwise the loaded or newly computed measures.

    Raises:
        ValueError: If `stats_path` exists but lacks a required measure.
    """
    # The required measures are determined once; the result is reused below.
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    if stats_path.exists():
        logger.info("loading precomputed dataset measures from {}", stats_path)
        loaded = load_stat(stats_path)
        for m in req_dataset_meas:
            if m not in loaded:
                raise ValueError(f"Missing {m} in {stats_path}")

        return loaded

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    computed: Dict[Measure, MeasureValue] = dict(stats_calc.finalize().items())
    save_stat(computed, stats_path)
    return computed

311 

312 

class UpdateCmdBase(CmdBase, WithSource, ABC):
    # Shared options and output handling of the update commands;
    # subclasses provide the `updated` description.

    output: Union[Literal["display", "stdout"], Path] = "display"
    """Output updated bioimageio.yaml to the terminal or write to a file.
    Notes:
    - `"display"`: Render to the terminal with syntax highlighting.
    - `"stdout"`: Write to sys.stdout without syntax highligthing.
    (More convenient for copying the updated bioimageio.yaml from the terminal.)
    """

    diff: Union[bool, Path] = Field(True, alias="diff")
    """Output a diff of original and updated bioimageio.yaml.
    If a given path has an `.html` extension, a standalone HTML file is written,
    otherwise the diff is saved in unified diff format (pure text).
    """

    exclude_unset: bool = Field(True, alias="exclude-unset")
    """Exclude fields that have not explicitly be set."""

    exclude_defaults: bool = Field(False, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly)."""

    @cached_property
    def updated(self) -> Union[ResourceDescr, InvalidDescr]:
        """The updated description (to be implemented by subclasses)."""
        raise NotImplementedError

    def cli_cmd(self):
        """Serialize `updated`, report the diff and emit the result to `output`.

        Raises:
            ValueError: If `output` is a non-YAML path (package/folder)
                and the updated description is invalid.
        """
        original_yaml = open_bioimageio_yaml(self.source).unparsed_content
        assert isinstance(original_yaml, str)

        with StringIO() as buffer:
            save_bioimageio_yaml_only(
                self.updated,
                buffer,
                exclude_unset=self.exclude_unset,
                exclude_defaults=self.exclude_defaults,
            )
            updated_yaml = buffer.getvalue()

        # --- diff ---
        diff_to_file = isinstance(self.diff, Path)
        if diff_to_file and self.diff.suffix == ".html":
            diff_format = "html"
        else:
            diff_format = "unified"

        diff = compare(
            original_yaml.split("\n"),
            updated_yaml.split("\n"),
            diff_format=diff_format,
        )

        if diff_to_file:
            self.diff.write_text(diff, encoding="utf-8")
        elif self.diff:
            diff_md = f"## Diff\n\n````````diff\n{diff}\n````````"
            rich.console.Console().print(rich.markdown.Markdown(diff_md))

        # --- updated description ---
        if isinstance(self.output, Path):
            target = self.output
            if target.suffix in (".yaml", ".yml"):
                # metadata only; this is allowed for invalid descriptions, too
                target.write_text(updated_yaml, encoding="utf-8")
                logger.info(f"written updated description to {self.output}")
            elif isinstance(self.updated, InvalidDescr):
                raise ValueError(
                    f"Cannot save invalid description package to {self.output}."
                    + " To save the metadata only, choose an output with a '.yaml' extension."
                )
            elif not target.suffix:
                save_bioimageio_package_as_folder(self.updated, output_path=target)
            else:
                save_bioimageio_package(self.updated, output_path=target)

        elif self.output == "display":
            updated_md = f"## Updated bioimageio.yaml\n\n```yaml\n{updated_yaml}\n```"
            rich.console.Console().print(rich.markdown.Markdown(updated_md))
        elif self.output == "stdout":
            print(updated_yaml)
        else:
            assert_never(self.output)

        if isinstance(self.updated, InvalidDescr):
            logger.warning("Update resulted in invalid description")
            self.updated.validation_summary.display()

396 

397 

class UpdateFormatCmd(UpdateCmdBase):
    """Update the metadata format to the latest format version."""

    exclude_defaults: bool = Field(True, alias="exclude-defaults")
    """Exclude fields that have the default value (even if set explicitly).

    Note:
    The update process sets most unset fields explicitly with their default value.
    """

    perform_io_checks: bool = Field(
        settings.perform_io_checks, alias="perform-io-checks"
    )
    """Wether or not to attempt validation that may require file download.
    If `True` file hash values are added if not present."""

    @cached_property
    def updated(self):
        """The description returned by `update_format` for `source`."""
        update_options = dict(
            exclude_defaults=self.exclude_defaults,
            perform_io_checks=self.perform_io_checks,
        )
        return update_format(self.source, **update_options)

421 

422 

class UpdateHashesCmd(UpdateCmdBase):
    """Create a bioimageio.yaml description with updated file hashes."""

    @cached_property
    def updated(self):
        """The description returned by `update_hashes` for `source`."""
        source = self.source
        return update_hashes(source)

429 

430 

class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[List[Union[str, NotEmpty[List[str]]]]] = Field(
        default_factory=lambda: ["{input_id}/001.tif"],
        validation_alias=AliasChoices("inputs", "input"),
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\\"a_raw.tif\\",\\"a_mask.tif\\"],[\\"b_raw.tif\\",\\"b_mask.tif\\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Aavailable formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.

     
    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = Field(
        "outputs_{model_id}/{output_id}/{sample_id}.tif",
        validation_alias=AliasChoices("outputs", "output"),
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

     """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: Union[bool, int] = False
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
    The blockize parameter determines the block size along axes with parameterized input size
    by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.
    - If `False`, inputs are processed as a whole without blocking.

     """

    stats: Annotated[
        Path,
        WithJsonSchema({"type": "string"}),
        PlainSerializer(lambda p: p.as_posix(), return_type=str),
    ] = Path("precomputed_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist
    and the model requires statistical dataset measures)
     """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = Field(
        "any",
        alias="weight-format",
        validation_alias=WEIGHT_FORMAT_ALIASES,
    )
    """The weight format to use."""

    devices: Optional[List[str]] = Field(
        None, validation_alias=AliasChoices("devices", "device")
    )
    """Device(s) to use"""

    server: Optional[str] = None
    """The URL or Hugging Face space name of a running bioimageio (gradio) server instance to use as a remote backend for prediction."""

    pre_post_processing_location: Literal["local", "remote"] = Field(
        "local", alias="pre-post-processing-location"
    )
    """Where to run preprocessing/postprocessing operations when using `--server`.

    - `local`: Run preprocessing/postprocessing locally and only model inference on the server.
    - `remote`: Run preprocessing/postprocessing on the server as well.

     
    """

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input

     
    """

    def _example(self):
        """Generate and run an example prediction in `{descr_id}_example`.

        Copies the example inputs of the model into the example folder, writes a
        CLI config file next to them and runs `bioimageio predict` twice in a
        subprocess (first as preview, then for real).

        Raises:
            ValueError: If the model does not specify any example inputs.
            subprocess.CalledProcessError: If one of the example runs fails.
        """
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [
                t
                for ipt in model_descr.inputs
                if (t := ipt.sample_tensor or ipt.test_tensor)
            ]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # copy each example input to '<example_path>/<input id>/001<suffix>'
        for t, src in zip(input_ids, example_inputs):
            reader = get_reader(src)
            dst = Path(f"{example_path}/{t}/001{reader.suffix}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            with dst.open("wb") as f:
                shutil.copyfileobj(reader, f)

        inputs = [inputs001]
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "precomputed_statistics.json"
        stats = (example_path / stats_file).as_posix()
        cli_example_args = dict(
            inputs=inputs,
            outputs=output_pattern,
            stats=stats_file,
            blockwise=self.blockwise,
        )
        assert is_yaml_value(cli_example_args), cli_example_args
        write_yaml(
            cli_example_args,
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            # `escape=True` yields the quoted variant that is printed for the
            # user; the unescaped variant is passed to `subprocess.run`.
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                f"--blockwise={self.blockwise}",
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # A CLI config file in the working directory is moved out of the way
        # for the example runs and restored afterwards.
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Sucessfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path / YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current workind directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def cli_cmd(self):
        """Predict all input samples and save each output sample.

        Raises:
            RuntimeError: If prediction fails while processing whole samples
                (the message suggests blockwise processing).
        """
        try:
            for out_sample, out_path in self._yield_predictions(self.blockwise):
                save_sample(out_path, out_sample)
        except Exception as e:
            if not self.blockwise:
                raise RuntimeError(
                    f"Prediction failed ({e}).\nConsider using blockwise processing, "
                    + "e.g. with `--blockwise=10` to process inputs in blocks."
                ) from e
            # bare `raise` re-raises the active exception unchanged
            raise

    def _yield_predictions(self, blockwise: Union[bool, int]):
        """Yield `(output sample, output paths)` for each input sample.

        This is a generator: nothing (including the `example` and `preview`
        handling) is executed before it is iterated.

        Args:
            blockwise: Blockwise processing mode to use for this run; passed
                explicitly so callers may deviate from `self.blockwise`.

        Raises:
            ValueError: If input or output paths are not distinct or their
                number does not match the model description.
            FileExistsError: If an output path exists and `overwrite` is not set.
        """
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # ids of required inputs (only v0_5 input tensors may be optional) ...
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        # ... and ids of all inputs
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Sequence[str]]) -> Tuple[str, ...]:
            # Resolve the path pattern(s) of input sample #i to one path per
            # input tensor.
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            # Resolve the output path pattern(s) to one tuple of paths
            # (one path per output tensor) for each sample.
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]
            # check for distinctness and correct number within each output sample
            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            # check for distinctness across all output samples
            all_output_paths = [p for out in outputs for p in out]
            if len(set(all_output_paths)) < len(all_output_paths):
                # plain (non f-)string: single braces render as intended
                raise ValueError(
                    "Output paths are not distinct across samples. "
                    + "Make sure to include '{sample_id}' in the output path pattern."
                )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        # first pass over the inputs (with empty statistics) to obtain the
        # dataset statistics -- only iterated if they need to be computed
        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        if self.server is not None and self.pre_post_processing_location == "remote":
            pp = create_remote_prediction_pipeline(model_descr, server=self.server)
        else:
            if self.server is None:
                model_adapter = None
            else:
                assert self.pre_post_processing_location == "local"
                model_adapter = create_remote_model_adapter(
                    model_descr, server=self.server
                )

            pp = create_prediction_pipeline(
                model_descr,
                weight_format=None
                if self.weight_format == "any"
                else self.weight_format,
                devices=self.devices,
                model_adapter=model_adapter,
            )

        if blockwise:
            predict_method = partial(
                pp.predict_sample_with_blocking,
                ns=None if isinstance(blockwise, bool) else blockwise,
            )
        else:
            predict_method = pp.predict_sample_without_blocking

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            # Decide based on the `blockwise` argument (not `self.blockwise`),
            # so that the check also runs for an explicitly requested
            # whole-sample prediction of a command that defaults to blockwise.
            if blockwise is False and not isinstance(
                pp.model_description, v0_4.ModelDescr
            ):
                try:
                    _ = pp.model_description.validate_input_tensors(
                        sample_in.as_arrays()
                    )
                except Exception as e:
                    # warn only; the prediction below is attempted regardless
                    logger.warning(
                        "Input sample '{}' failed validation for whole-sample prediction: {}\n"
                        + "Consider using blockwise processing, e.g. with `--blockwise=10` to process inputs in blocks.",
                        sample_in.id,
                        e,
                    )

            yield (predict_method(sample_in), sp_out)

861 

862 

class PredictBlockArtifactsCmd(PredictCmd):
    """Command to inspect block artifacts by subtracting the combined, blockwise predictions from a whole sample prediction.

    Note:
    - This command intentionally uses a small blocksize (default: 1) to create block artifacts for testing purposes.
    - Typical sources of block artifacts include:
    - Described halo is smaller than the model's receptive field
    - Normalization layers inside the network cannot aggregate statistics over the whole sample.
    """

    blockwise: Union[Literal[True], int] = 1
    """Process inputs blockwise

    - If an integer is given, it is used as the blocksize parameter 'n' for blockwise processing.
    The blockize parameter determines the block size along axes with parameterized input size
    by adding n*step_size to the minimum valid input size.
    - If `True`, the blocksize parameter is set to 10.

    Defaults to a small blocksize to intentionally create block artifacts for testing purposes.
     """

    def cli_cmd(self):
        """Save the difference of whole-sample and blockwise prediction per sample.

        Each input sample is predicted twice (as a whole and blockwise). Measures
        in the sample statistics that are missing from, or differ in, the
        blockwise prediction are logged as errors.
        """
        whole_predictions = self._yield_predictions(False)
        blockwise_predictions = self._yield_predictions(self.blockwise)
        for (out_sample, out_path), (out_sample_blockwise, _) in zip(
            whole_predictions, blockwise_predictions
        ):
            diff_sample = self._subtract_samples(out_sample, out_sample_blockwise)
            for k, v_a in out_sample.stat.items():
                v_b = out_sample_blockwise.stat.get(k)
                if v_b is None:
                    logger.error(
                        "measure '{}' not found in blockwise prediction statistics", k
                    )
                # report a mismatch if the values differ anywhere
                # (`np.any` also covers array-valued measures)
                elif np.any(np.not_equal(v_a, v_b)):
                    logger.error(
                        "measure '{}' has different values (whole sample!=blockwise): {}!={}",
                        k,
                        v_a,
                        v_b,
                    )

            save_sample(out_path, diff_sample)

    @staticmethod
    def _subtract_samples(a: Sample, b: Sample) -> Sample:
        """Return `a - b` per member, keeping id and statistics of `a`."""
        return Sample(
            members={t: a.members[t] - b.members[t] for t in a.members},
            id=a.id,
            stat=a.stat,
        )

912 

913 

class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
    """Add additional weights to a model description by converting from available formats."""

    output: CliPositionalArg[Path]
    """The path to write the updated model package to."""

    source_format: Optional[SupportedWeightsFormat] = Field(None, alias="source-format")
    """Exclusively use these weights to convert to other formats."""

    target_format: Optional[SupportedWeightsFormat] = Field(None, alias="target-format")
    """Exclusively add this weight format."""

    verbose: bool = False
    """Log more (error) output."""

    tracing: bool = True
    """Allow tracing when converting pytorch_state_dict to torchscript
    (still uses scripting if possible)."""

    def cli_cmd(self):
        """Add weights to the model and log the resulting validation summary.

        Raises:
            TypeError: If the model is described in the v0_4 model format.
        """
        model_descr = ensure_description_is_model(self.descr)
        if isinstance(model_descr, v0_4.ModelDescr):
            message = (
                f"model format {model_descr.format_version} not supported."
                + " Please update the model first."
            )
            raise TypeError(message)

        conversion_options = dict(
            output_path=self.output,
            source_format=self.source_format,
            target_format=self.target_format,
            verbose=self.verbose,
            allow_tracing=self.tracing,
        )
        self.log(add_weights(model_descr, **conversion_options))

949 

950 

class EmptyCacheCmd(CmdBase):
    """Empty the bioimageio cache directory."""

    def cli_cmd(self):
        # This command has no options of its own; it only delegates to `empty_cache`.
        empty_cache()

956 

957 

class ServerCmd(CmdBase):
    """Start a server to connect to with remote model adapters or remote prediction pipelines."""

    backend: Literal["gradio"] = "gradio"
    """The remote backend to use."""

    port: Optional[int] = None
    """The port to start the server on. If not given, a free port will be used."""

    def cli_cmd(self) -> None:
        # The backend's server module is imported lazily so that a missing
        # import surfaces as an actionable installation hint.
        try:
            if self.backend != "gradio":
                assert_never(self.backend)

            from .remote_backends.gradio.server import main
        except ImportError as e:
            backend_title = self.backend.capitalize()
            extra_name = f"{self.backend}-server"
            hint = (
                f"{backend_title} is not installed. Please install the '{extra_name}' extra to use this command,"
                f" e.g. with `pip install bioimageio.core[{extra_name}]`."
            )
            raise ImportError(hint) from e

        # The message below is logged once `main` has returned.
        server_url = main(port=self.port)
        logger.info("{} server shutdown at {}", self.backend.capitalize(), server_url)

983 

984 

# File names of the optional JSON/YAML config files for the CLI settings.
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"

987 

988 

class Bioimageio(
    BaseSettings,
    cli_implicit_flags=True,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    # NOTE: `use_attribute_docstrings=True` turns the string literal that follows
    # a field into that field's description, so those strings are user-facing text.

    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    """Check a resource's metadata format"""

    test: CliSubCommand[TestCmd]
    """Test a bioimageio resource (beyond meta data formatting)"""

    package: CliSubCommand[PackageCmd]
    """Package a resource"""

    predict: CliSubCommand[PredictCmd]
    """Predict with a model resource"""

    predict_block_artifacts: CliSubCommand[PredictBlockArtifactsCmd] = Field(
        alias="predict-block-artifacts"
    )
    """Save the difference between predicting blockwise and whole sample to check for block artifacts."""

    update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format")
    """Update the metadata format"""

    update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes")
    """Create a bioimageio.yaml description with updated file hashes."""

    add_weights: CliSubCommand[AddWeightsCmd] = Field(alias="add-weights")
    """Add additional weights to a model description by converting from available formats."""

    empty_cache: CliSubCommand[EmptyCacheCmd] = Field(alias="empty-cache")
    """Empty the bioimageio cache directory."""

    server: CliSubCommand[ServerCmd]
    """Start a server to connect to with remote model adapters or remote prediction pipelines."""

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Select the settings sources used for the CLI.

        Returns, in this order: the CLI arguments, the init arguments, the YAML
        config file and the JSON config file. The environment, dotenv and
        secrets-file sources passed in are not part of the returned tuple.
        Also logs `sys.argv`.
        """
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        """Log the raw input (without `None` entries) and return it unchanged."""
        # A "before" validator receives the raw input, which need not be a
        # mapping; only filter `None` values when `data` is one.
        if isinstance(data, Mapping):
            shown: Any = {k: v for k, v in data.items() if v is not None}
        else:
            shown = data

        logger.info("loaded CLI input:\n{}", pformat(shown))
        return data

    def cli_cmd(self) -> None:
        """Log the validated settings (without `None` entries) and run the chosen subcommand."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        _ = CliApp.run_subcommand(self)

1074 

1075 

1076assert isinstance(Bioimageio.__doc__, str) 

1077Bioimageio.__doc__ += f""" 

1078 

1079library versions: 

1080 bioimageio.core {__version__} 

1081 bioimageio.spec {bioimageio.spec.__version__} 

1082 

1083spec format versions: 

1084 model RDF {ModelDescr.implemented_format_version} 

1085 dataset RDF {DatasetDescr.implemented_format_version} 

1086 notebook RDF {NotebookDescr.implemented_format_version} 

1087 

1088""" 

1089 

1090 

1091def _get_sample_ids( 

1092 input_paths: Sequence[Mapping[MemberId, Path]], 

1093) -> Sequence[SampleId]: 

1094 """Get sample ids for given input paths, based on the common path per sample. 

1095 

1096 Falls back to sample01, samle02, etc...""" 

1097 

1098 matcher = SequenceMatcher() 

1099 

1100 def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]: 

1101 """extract a common sequence from multiple sequences 

1102 (order sensitive; strips whitespace and slashes) 

1103 """ 

1104 common = seqs[0] 

1105 

1106 for seq in seqs[1:]: 

1107 if not seq: 

1108 continue 

1109 matcher.set_seqs(common, seq) 

1110 i, _, size = matcher.find_longest_match() 

1111 common = common[i : i + size] 

1112 

1113 if isinstance(common, str): 

1114 common = common.strip().strip("/") 

1115 else: 

1116 common = [cs for c in common if (cs := c.strip().strip("/"))] 

1117 

1118 if not common: 

1119 raise ValueError(f"failed to find common sequence for {seqs}") 

1120 

1121 return common 

1122 

1123 def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]: 

1124 """get a shorter sequence whose entries are still unique 

1125 (order sensitive, not minimal sequence) 

1126 """ 

1127 min_seq_len = min(len(s) for s in seqs) 

1128 # cut from the start 

1129 for start in range(min_seq_len - 1, -1, -1): 

1130 shortened = [s[start:] for s in seqs] 

1131 if len(set(shortened)) == len(seqs): 

1132 min_seq_len -= start 

1133 break 

1134 else: 

1135 seen: Set[Sequence[str]] = set() 

1136 dupes = [s for s in seqs if s in seen or seen.add(s)] 

1137 raise ValueError(f"Found duplicate entries {dupes}") 

1138 

1139 # cut from the end 

1140 for end in range(min_seq_len - 1, 1, -1): 

1141 shortened = [s[:end] for s in shortened] 

1142 if len(set(shortened)) == len(seqs): 

1143 break 

1144 

1145 return shortened 

1146 

1147 full_tensor_ids = [ 

1148 sorted( 

1149 p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values() 

1150 ) 

1151 for input_sample_paths in input_paths 

1152 ] 

1153 try: 

1154 long_sample_ids = [get_common_seq(t) for t in full_tensor_ids] 

1155 sample_ids = get_shorter_diff(long_sample_ids) 

1156 except ValueError as e: 

1157 raise ValueError(f"failed to extract sample ids: {e}") 

1158 

1159 return sample_ids