Coverage for bioimageio/core/cli.py: 80%

272 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1"""bioimageio CLI 

2 

3Note: Some docstrings use a hair space ' ' 

4 to place the added '(default: ...)' on a new line. 

5""" 

6 

7import json 

8import shutil 

9import subprocess 

10import sys 

11from argparse import RawTextHelpFormatter 

12from difflib import SequenceMatcher 

13from functools import cached_property 

14from pathlib import Path 

15from pprint import pformat, pprint 

16from typing import ( 

17 Any, 

18 Dict, 

19 Iterable, 

20 List, 

21 Mapping, 

22 Optional, 

23 Sequence, 

24 Set, 

25 Tuple, 

26 Type, 

27 Union, 

28) 

29 

30from loguru import logger 

31from pydantic import BaseModel, Field, model_validator 

32from pydantic_settings import ( 

33 BaseSettings, 

34 CliPositionalArg, 

35 CliSettingsSource, 

36 CliSubCommand, 

37 JsonConfigSettingsSource, 

38 PydanticBaseSettingsSource, 

39 SettingsConfigDict, 

40 YamlConfigSettingsSource, 

41) 

42from ruyaml import YAML 

43from tqdm import tqdm 

44from typing_extensions import assert_never 

45 

46from bioimageio.spec import AnyModelDescr, InvalidDescr, load_description 

47from bioimageio.spec._internal.io_basics import ZipPath 

48from bioimageio.spec._internal.types import NotEmpty 

49from bioimageio.spec.dataset import DatasetDescr 

50from bioimageio.spec.model import ModelDescr, v0_4, v0_5 

51from bioimageio.spec.notebook import NotebookDescr 

52from bioimageio.spec.utils import download, ensure_description_is_model 

53 

54from .commands import ( 

55 WeightFormatArgAll, 

56 WeightFormatArgAny, 

57 package, 

58 test, 

59 validate_format, 

60) 

61from .common import MemberId, SampleId 

62from .digest_spec import get_member_ids, load_sample_for_model 

63from .io import load_dataset_stat, save_dataset_stat, save_sample 

64from .prediction import create_prediction_pipeline 

65from .proc_setup import ( 

66 DatasetMeasure, 

67 Measure, 

68 MeasureValue, 

69 StatsCalculator, 

70 get_required_dataset_measures, 

71) 

72from .sample import Sample 

73from .stat_measures import Stat 

74from .utils import VERSION 

75 

76yaml = YAML(typ="safe") 

77 

78 

class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Base for CLI subcommands: attribute docstrings become CLI help texts and
    # boolean fields are exposed as implicit flags (pydantic-settings behavior).
    pass

81 

82 

class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True):
    # Base for argument mixins shared by several subcommands; configured the
    # same way as CmdBase (docstring-derived help, implicit boolean flags).
    pass

85 

86 

class WithSource(ArgMixin):
    """Mixin adding a positional `source` argument and lazy description loading."""

    source: CliPositionalArg[str]
    """Url/path to a bioimageio.yaml/rdf.yaml file
    or a bioimage.io resource identifier, e.g. 'affable-shark'"""

    @cached_property
    def descr(self):
        # loaded at most once per command invocation; may be an InvalidDescr
        return load_description(self.source)

    @property
    def descr_id(self) -> str:
        """a more user-friendly description id
        (replacing legacy ids with their nicknames)
        """
        if isinstance(self.descr, InvalidDescr):
            # fall back to whatever identifying attribute the invalid description has
            return str(getattr(self.descr, "id", getattr(self.descr, "name")))
        else:
            # prefer a nickname from config["bioimageio"]["nickname"] (if the
            # config entry is a dict), then the resource id, then its name
            return str(
                (
                    (bio_config := self.descr.config.get("bioimageio", {}))
                    and isinstance(bio_config, dict)
                    and bio_config.get("nickname")
                )
                or self.descr.id
                or self.descr.name
            )

113 

114 

class ValidateFormatCmd(CmdBase, WithSource):
    """validate the meta data format of a bioimageio resource."""

    def run(self):
        # exit with the return code of `commands.validate_format`
        sys.exit(validate_format(self.descr))

120 

121 

class TestCmd(CmdBase, WithSource):
    """Test a bioimageio resource (beyond meta data formatting)"""

    weight_format: WeightFormatArgAll = "all"
    """The weight format to limit testing to.

    (only relevant for model resources)"""

    devices: Optional[Union[str, Sequence[str]]] = None
    """Device(s) to use for testing"""

    decimal: int = 4
    """Precision for numerical comparisons"""

    def run(self):
        # exit with the return code of `commands.test`
        sys.exit(
            test(
                self.descr,
                weight_format=self.weight_format,
                devices=self.devices,
                decimal=self.decimal,
            )
        )

145 

146 

class PackageCmd(CmdBase, WithSource):
    """save a resource's metadata with its associated files."""

    path: CliPositionalArg[Path]
    """The path to write the (zipped) package to.
    If it does not have a `.zip` suffix
    this command will save the package as an unzipped folder instead."""

    weight_format: WeightFormatArgAll = "all"
    """The weight format to include in the package (for model descriptions only)."""

    def run(self):
        # an invalid description cannot be packaged; show the validation
        # summary so the user learns why, then abort
        if isinstance(self.descr, InvalidDescr):
            self.descr.validation_summary.display()
            raise ValueError("resource description is invalid")

        # exit with the return code of `commands.package`
        sys.exit(
            package(
                self.descr,
                self.path,
                weight_format=self.weight_format,
            )
        )

170 

171 

def _get_stat(
    model_descr: AnyModelDescr,
    dataset: Iterable[Sample],
    dataset_length: int,
    stats_path: Path,
) -> Mapping[DatasetMeasure, MeasureValue]:
    """Get the dataset statistics required by *model_descr*.

    Loads precomputed measures from *stats_path* if that file exists,
    otherwise computes them from *dataset* and saves them to *stats_path*.

    Args:
        model_descr: model description defining the required dataset measures.
        dataset: iterable of input samples used to compute missing statistics.
        dataset_length: number of samples in *dataset* (for the progress bar).
        stats_path: file to load precomputed statistics from / save them to.

    Returns:
        Mapping of required dataset measures to their values;
        empty if the model requires no dataset measures.

    Raises:
        ValueError: if *stats_path* exists but lacks a required measure.
    """
    req_dataset_meas, _ = get_required_dataset_measures(model_descr)
    if not req_dataset_meas:
        return {}

    # (previously the required measures were computed a second time here;
    # the redundant call was removed)
    if stats_path.exists():
        logger.info(f"loading precomputed dataset measures from {stats_path}")
        stat = load_dataset_stat(stats_path)
        for m in req_dataset_meas:
            if m not in stat:
                raise ValueError(f"Missing {m} in {stats_path}")

        return stat

    stats_calc = StatsCalculator(req_dataset_meas)

    for sample in tqdm(
        dataset, total=dataset_length, desc="precomputing dataset stats", unit="sample"
    ):
        stats_calc.update(sample)

    stat = stats_calc.finalize()
    save_dataset_stat(stat, stats_path)

    return stat

204 

205 

class PredictCmd(CmdBase, WithSource):
    """Run inference on your data with a bioimage.io model."""

    inputs: NotEmpty[Sequence[Union[str, NotEmpty[Tuple[str, ...]]]]] = (
        "{input_id}/001.tif",
    )
    """Model input sample paths (for each input tensor)

    The input paths are expected to have shape...
    - (n_samples,) or (n_samples,1) for models expecting a single input tensor
    - (n_samples,) containing the substring '{input_id}', or
    - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

    All substrings that are replaced by metadata from the model description:
    - '{model_id}'
    - '{input_id}'

    Example inputs to process sample 'a' and 'b'
    for a model expecting a 'raw' and a 'mask' input tensor:
    --inputs="[[\"a_raw.tif\",\"a_mask.tif\"],[\"b_raw.tif\",\"b_mask.tif\"]]"
    (Note that JSON double quotes need to be escaped.)

    Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
    may provide the arguments, e.g.:
    ```yaml
    inputs:
    - [a_raw.tif, a_mask.tif]
    - [b_raw.tif, b_mask.tif]
    ```

    `.npy` and any file extension supported by imageio are supported.
    Available formats are listed at
    https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
    Some formats have additional dependencies.

     
    """

    outputs: Union[str, NotEmpty[Tuple[str, ...]]] = (
        "outputs_{model_id}/{output_id}/{sample_id}.tif"
    )
    """Model output path pattern (per output tensor)

    All substrings that are replaced:
    - '{model_id}' (from model description)
    - '{output_id}' (from model description)
    - '{sample_id}' (extracted from input paths)

     
    """

    overwrite: bool = False
    """allow overwriting existing output files"""

    blockwise: bool = False
    """process inputs blockwise"""

    stats: Path = Path("dataset_statistics.json")
    """path to dataset statistics
    (will be written if it does not exist,
    but the model requires statistical dataset measures)
     """

    preview: bool = False
    """preview which files would be processed
    and what outputs would be generated."""

    weight_format: WeightFormatArgAny = "any"
    """The weight format to use."""

    example: bool = False
    """generate and run an example

    1. downloads example model inputs
    2. creates a `{model_id}_example` folder
    3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
    4. executes a preview dry-run
    5. executes prediction with example input

     
    """

    def _example(self):
        """Download the model's example inputs, write a CLI config file and
        run an example prediction via two `bioimageio predict` subprocesses
        (a preview dry-run followed by the actual prediction)."""
        model_descr = ensure_description_is_model(self.descr)
        input_ids = get_member_ids(model_descr.inputs)
        # v0.4 models list sample inputs directly; v0.5 models attach a sample
        # (or at least a test) tensor to each input descriptor
        example_inputs = (
            model_descr.sample_inputs
            if isinstance(model_descr, v0_4.ModelDescr)
            else [ipt.sample_tensor or ipt.test_tensor for ipt in model_descr.inputs]
        )
        if not example_inputs:
            raise ValueError(f"{self.descr_id} does not specify any example inputs.")

        inputs001: List[str] = []
        example_path = Path(f"{self.descr_id}_example")
        example_path.mkdir(exist_ok=True)

        # download each example input to <example_path>/<input_id>/001.<ext>
        for t, src in zip(input_ids, example_inputs):
            local = download(src).path
            dst = Path(f"{example_path}/{t}/001{''.join(local.suffixes)}")
            dst.parent.mkdir(parents=True, exist_ok=True)
            inputs001.append(dst.as_posix())
            if isinstance(local, Path):
                shutil.copy(local, dst)
            elif isinstance(local, ZipPath):
                _ = local.root.extract(local.at, path=dst)
            else:
                assert_never(local)

        inputs = [tuple(inputs001)]
        # NOTE: double braces keep '{output_id}'/'{sample_id}' as literal
        # placeholders to be filled in by the spawned predict command
        output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

        bioimageio_cli_path = example_path / YAML_FILE
        stats_file = "dataset_statistics.json"
        stats = (example_path / stats_file).as_posix()
        yaml.dump(
            dict(
                inputs=inputs,
                outputs=output_pattern,
                stats=stats_file,
                blockwise=self.blockwise,
            ),
            bioimageio_cli_path,
        )

        yaml_file_content = None

        # escaped double quotes
        inputs_json = json.dumps(inputs)
        inputs_escaped = inputs_json.replace('"', r"\"")
        source_escaped = self.source.replace('"', r"\"")

        def get_example_command(preview: bool, escape: bool = False):
            """Assemble the `bioimageio predict` argv; with `escape` the
            arguments are quoted for copy-pasting into a shell."""
            q: str = '"' if escape else ""

            return [
                "bioimageio",
                "predict",
                # --no-preview not supported for py=3.8
                *(["--preview"] if preview else []),
                "--overwrite",
                *(["--blockwise"] if self.blockwise else []),
                f"--stats={q}{stats}{q}",
                f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}",
                f"--outputs={q}{output_pattern}{q}",
                f"{q}{source_escaped if escape else self.source}{q}",
            ]

        # a config file in the *current* directory would override the CLI args
        # passed to the subprocesses; stash it away and restore it afterwards
        if Path(YAML_FILE).exists():
            logger.info(
                "temporarily removing '{}' to execute example prediction", YAML_FILE
            )
            yaml_file_content = Path(YAML_FILE).read_bytes()
            Path(YAML_FILE).unlink()

        try:
            _ = subprocess.run(get_example_command(True), check=True)
            _ = subprocess.run(get_example_command(False), check=True)
        finally:
            if yaml_file_content is not None:
                _ = Path(YAML_FILE).write_bytes(yaml_file_content)
                logger.debug("restored '{}'", YAML_FILE)

        print(
            "🎉 Successfully ran example prediction!\n"
            + "To predict the example input using the CLI example config file"
            + f" {example_path/YAML_FILE}, execute `bioimageio predict` from {example_path}:\n"
            + f"$ cd {str(example_path)}\n"
            + f'$ bioimageio predict "{source_escaped}"\n\n'
            + "Alternatively run the following command"
            + " in the current working directory, not the example folder:\n$ "
            + " ".join(get_example_command(False, escape=True))
            + f"\n(note that a local '{JSON_FILE}' or '{YAML_FILE}' may interfere with this)"
        )

    def run(self):
        """Run prediction (or the example workflow if `--example` is set)."""
        if self.example:
            return self._example()

        model_descr = ensure_description_is_model(self.descr)

        input_ids = get_member_ids(model_descr.inputs)
        output_ids = get_member_ids(model_descr.outputs)

        # required inputs only (v0.5 inputs may be optional)
        minimum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
            if not isinstance(ipt, v0_5.InputTensorDescr) or not ipt.optional
        )
        # all inputs, including optional ones
        maximum_input_ids = tuple(
            str(ipt.id) if isinstance(ipt, v0_5.InputTensorDescr) else str(ipt.name)
            for ipt in model_descr.inputs
        )

        def expand_inputs(i: int, ipt: Union[str, Tuple[str, ...]]) -> Tuple[str, ...]:
            """Expand a single input spec (pattern string or explicit tuple)
            to one path per input tensor, validating count and uniqueness."""
            if isinstance(ipt, str):
                ipts = tuple(
                    ipt.format(model_id=self.descr_id, input_id=t) for t in input_ids
                )
            else:
                ipts = tuple(
                    p.format(model_id=self.descr_id, input_id=t)
                    for t, p in zip(input_ids, ipt)
                )

            if len(set(ipts)) < len(ipts):
                if len(minimum_input_ids) == len(maximum_input_ids):
                    n = len(minimum_input_ids)
                else:
                    n = f"{len(minimum_input_ids)}-{len(maximum_input_ids)}"

                raise ValueError(
                    f"[input sample #{i}] Include '{{input_id}}' in path pattern or explicitly specify {n} distinct input paths (got {ipt})"
                )

            if len(ipts) < len(minimum_input_ids):
                raise ValueError(
                    f"[input sample #{i}] Expected at least {len(minimum_input_ids)} inputs {minimum_input_ids}, got {ipts}"
                )

            if len(ipts) > len(maximum_input_ids):
                raise ValueError(
                    f"Expected at most {len(maximum_input_ids)} inputs {maximum_input_ids}, got {ipts}"
                )

            return ipts

        inputs = [expand_inputs(i, ipt) for i, ipt in enumerate(self.inputs, start=1)]

        sample_paths_in = [
            {t: Path(p) for t, p in zip(input_ids, ipts)} for ipts in inputs
        ]

        sample_ids = _get_sample_ids(sample_paths_in)

        def expand_outputs():
            """Expand the output spec to one path tuple per sample,
            validating count and uniqueness per sample."""
            if isinstance(self.outputs, str):
                outputs = [
                    tuple(
                        Path(
                            self.outputs.format(
                                model_id=self.descr_id, output_id=t, sample_id=s
                            )
                        )
                        for t in output_ids
                    )
                    for s in sample_ids
                ]
            else:
                outputs = [
                    tuple(
                        Path(p.format(model_id=self.descr_id, output_id=t, sample_id=s))
                        for t, p in zip(output_ids, self.outputs)
                    )
                    for s in sample_ids
                ]

            for i, out in enumerate(outputs, start=1):
                if len(set(out)) < len(out):
                    raise ValueError(
                        f"[output sample #{i}] Include '{{output_id}}' in path pattern or explicitly specify {len(output_ids)} distinct output paths (got {out})"
                    )

                if len(out) != len(output_ids):
                    raise ValueError(
                        f"[output sample #{i}] Expected {len(output_ids)} outputs {output_ids}, got {out}"
                    )

            return outputs

        outputs = expand_outputs()

        sample_paths_out = [
            {MemberId(t): Path(p) for t, p in zip(output_ids, out)} for out in outputs
        ]

        # refuse to clobber existing outputs unless --overwrite was given
        if not self.overwrite:
            for sample_paths in sample_paths_out:
                for p in sample_paths.values():
                    if p.exists():
                        raise FileExistsError(
                            f"{p} already exists. use --overwrite to (re-)write outputs anyway."
                        )
        if self.preview:
            print("🛈 bioimageio prediction preview structure:")
            pprint(
                {
                    "{sample_id}": dict(
                        inputs={"{input_id}": "<input path>"},
                        outputs={"{output_id}": "<output path>"},
                    )
                }
            )
            print("🔎 bioimageio prediction preview output:")
            pprint(
                {
                    s: dict(
                        inputs={t: p.as_posix() for t, p in sp_in.items()},
                        outputs={t: p.as_posix() for t, p in sp_out.items()},
                    )
                    for s, sp_in, sp_out in zip(
                        sample_ids, sample_paths_in, sample_paths_out
                    )
                }
            )
            return

        def input_dataset(stat: Stat):
            """Lazily load one sample per input path mapping."""
            for s, sp_in in zip(sample_ids, sample_paths_in):
                yield load_sample_for_model(
                    model=model_descr,
                    paths=sp_in,
                    stat=stat,
                    sample_id=s,
                )

        # dataset statistics (loaded or computed via a first pass over the data)
        stat: Dict[Measure, MeasureValue] = dict(
            _get_stat(
                model_descr, input_dataset({}), len(sample_ids), self.stats
            ).items()
        )

        pp = create_prediction_pipeline(
            model_descr,
            weight_format=None if self.weight_format == "any" else self.weight_format,
        )
        predict_method = (
            pp.predict_sample_with_blocking
            if self.blockwise
            else pp.predict_sample_without_blocking
        )

        for sample_in, sp_out in tqdm(
            zip(input_dataset(dict(stat)), sample_paths_out),
            total=len(inputs),
            desc=f"predict with {self.descr_id}",
            unit="sample",
        ):
            sample_out = predict_method(sample_in)
            save_sample(sp_out, sample_out)

546 

547 

# Names of optional CLI config files picked up from the current working
# directory (registered as json_file/yaml_file in Bioimageio.model_config).
JSON_FILE = "bioimageio-cli.json"
YAML_FILE = "bioimageio-cli.yaml"

550 

551 

class Bioimageio(
    BaseSettings,
    cli_parse_args=True,
    cli_prog_name="bioimageio",
    cli_use_class_docs_for_groups=True,
    cli_implicit_flags=True,
    use_attribute_docstrings=True,
):
    """bioimageio - CLI for bioimage.io resources 🦒"""

    # allow providing arguments via bioimageio-cli.json / bioimageio-cli.yaml
    model_config = SettingsConfigDict(
        json_file=JSON_FILE,
        yaml_file=YAML_FILE,
    )

    validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format")
    "Check a resource's metadata format"

    test: CliSubCommand[TestCmd]
    "Test a bioimageio resource (beyond meta data formatting)"

    package: CliSubCommand[PackageCmd]
    "Package a resource"

    predict: CliSubCommand[PredictCmd]
    "Predict with a model resource"

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        # Source priority: CLI args > init kwargs > YAML config > JSON config.
        # env/dotenv/secret sources are deliberately not returned.
        cli: CliSettingsSource[BaseSettings] = CliSettingsSource(
            settings_cls,
            cli_parse_args=True,
            # RawTextHelpFormatter preserves the multi-line help texts
            formatter_class=RawTextHelpFormatter,
        )
        sys_args = pformat(sys.argv)
        logger.info("starting CLI with arguments:\n{}", sys_args)
        return (
            cli,
            init_settings,
            YamlConfigSettingsSource(settings_cls),
            JsonConfigSettingsSource(settings_cls),
        )

    @model_validator(mode="before")
    @classmethod
    def _log(cls, data: Any):
        # log the merged raw input before validation (dropping unset entries)
        logger.info(
            "loaded CLI input:\n{}",
            pformat({k: v for k, v in data.items() if v is not None}),
        )
        return data

    def run(self):
        """Execute the subcommand that was invoked."""
        logger.info(
            "executing CLI command:\n{}",
            pformat({k: v for k, v in self.model_dump().items() if v is not None}),
        )
        # exactly one subcommand is set by pydantic-settings CLI parsing
        cmd = self.validate_format or self.test or self.package or self.predict
        assert cmd is not None
        cmd.run()

619 

620 

# Append version information to the CLI help text (shown by `bioimageio -h`).
# NOTE(review): both "library versions" lines use bioimageio.core's VERSION —
# confirm that reporting it for bioimageio.spec as well is intentional.
assert isinstance(Bioimageio.__doc__, str)
Bioimageio.__doc__ += f"""

library versions:
  bioimageio.core {VERSION}
  bioimageio.spec {VERSION}

spec format versions:
        model RDF {ModelDescr.implemented_format_version}
      dataset RDF {DatasetDescr.implemented_format_version}
     notebook RDF {NotebookDescr.implemented_format_version}

"""

634 

635 

636def _get_sample_ids( 

637 input_paths: Sequence[Mapping[MemberId, Path]], 

638) -> Sequence[SampleId]: 

639 """Get sample ids for given input paths, based on the common path per sample. 

640 

641 Falls back to sample01, samle02, etc...""" 

642 

643 matcher = SequenceMatcher() 

644 

645 def get_common_seq(seqs: Sequence[Sequence[str]]) -> Sequence[str]: 

646 """extract a common sequence from multiple sequences 

647 (order sensitive; strips whitespace and slashes) 

648 """ 

649 common = seqs[0] 

650 

651 for seq in seqs[1:]: 

652 if not seq: 

653 continue 

654 matcher.set_seqs(common, seq) 

655 i, _, size = matcher.find_longest_match() 

656 common = common[i : i + size] 

657 

658 if isinstance(common, str): 

659 common = common.strip().strip("/") 

660 else: 

661 common = [cs for c in common if (cs := c.strip().strip("/"))] 

662 

663 if not common: 

664 raise ValueError(f"failed to find common sequence for {seqs}") 

665 

666 return common 

667 

668 def get_shorter_diff(seqs: Sequence[Sequence[str]]) -> List[Sequence[str]]: 

669 """get a shorter sequence whose entries are still unique 

670 (order sensitive, not minimal sequence) 

671 """ 

672 min_seq_len = min(len(s) for s in seqs) 

673 # cut from the start 

674 for start in range(min_seq_len - 1, -1, -1): 

675 shortened = [s[start:] for s in seqs] 

676 if len(set(shortened)) == len(seqs): 

677 min_seq_len -= start 

678 break 

679 else: 

680 seen: Set[Sequence[str]] = set() 

681 dupes = [s for s in seqs if s in seen or seen.add(s)] 

682 raise ValueError(f"Found duplicate entries {dupes}") 

683 

684 # cut from the end 

685 for end in range(min_seq_len - 1, 1, -1): 

686 shortened = [s[:end] for s in shortened] 

687 if len(set(shortened)) == len(seqs): 

688 break 

689 

690 return shortened 

691 

692 full_tensor_ids = [ 

693 sorted( 

694 p.resolve().with_suffix("").as_posix() for p in input_sample_paths.values() 

695 ) 

696 for input_sample_paths in input_paths 

697 ] 

698 try: 

699 long_sample_ids = [get_common_seq(t) for t in full_tensor_ids] 

700 sample_ids = get_shorter_diff(long_sample_ids) 

701 except ValueError as e: 

702 raise ValueError(f"failed to extract sample ids: {e}") 

703 

704 return sample_ids