Coverage for src/bioimageio/core/_resource_tests.py: 73%

390 statements  

coverage.py v7.13.4, created at 2026-02-13 09:46 +0000

import hashlib
import os
import platform
import subprocess
import sys
import warnings
from contextlib import nullcontext
from io import StringIO
from itertools import product
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Callable,
    Dict,
    Hashable,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    overload,
)

import numpy as np
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args

from bioimageio.spec import (
    AnyDatasetDescr,
    AnyModelDescr,
    BioimageioCondaEnv,
    DatasetDescr,
    InvalidDescr,
    LatestResourceDescr,
    ModelDescr,
    ResourceDescr,
    ValidationContext,
    build_description,
    dump_description,
    get_conda_env,
    load_description,
    save_bioimageio_package,
)
from bioimageio.spec._description_impl import DISCOVER
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
from bioimageio.spec._internal.io import is_yaml_value
from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
from bioimageio.spec._internal.types import (
    AbsoluteTolerance,
    FormatVersionPlaceholder,
    MismatchedElementsPerMillion,
    RelativeTolerance,
)
from bioimageio.spec._internal.validation_context import get_validation_context
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
    ErrorEntry,
    InstalledPackage,
    ValidationDetail,
    ValidationSummary,
    WarningEntry,
)

from . import __version__
from ._prediction_pipeline import create_prediction_pipeline
from ._settings import settings
from .axis import AxisId, BatchSize
from .common import MemberId, SupportedWeightsFormat
from .digest_spec import get_test_input_sample, get_test_output_sample
from .io import save_tensor
from .sample import Sample

CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda"

class DeprecatedKwargs(TypedDict):
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    relative_tolerance: NotRequired[RelativeTolerance]
    decimal: NotRequired[Optional[int]]

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable framework determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported,
            based on **weight_formats**.
            E.g. this allows avoiding a tensorflow import when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                    # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))
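
# Usage sketch (illustrative, not executed here): seed only the frameworks relevant to
# the weight formats under test before a reproducibility check.
#
#     enable_determinism("seed_only", weight_formats=["pytorch_state_dict"])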

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference"""
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
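
# Usage sketch (illustrative; the RDF path is a hypothetical placeholder):
#
#     summary = test_model("path/to/rdf.yaml", weight_format="onnx")
#     print(summary.status)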

def default_run_command(args: Sequence[str]):
    logger.info("running '{}'...", " ".join(args))
    _ = subprocess.check_call(args)

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            descr = load_description(source, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )

        return descr.validation_summary
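
# Usage sketch (illustrative; `runtime_env` is experimental, see the docstring above;
# the RDF path is a hypothetical placeholder):
#
#     summary = test_description(
#         "path/to/rdf.yaml",
#         runtime_env="as-described",  # test in a conda env generated from the weights entry
#         stop_early=False,
#     )
#     print(summary.status)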

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing"
                + " (and no default conda envs are available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception as e:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd = []
    cmd_error = None
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)
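
# Note (illustrative): for a model tested with e.g. its onnx weights entry, the subprocess
# invocation assembled above corresponds roughly to
#
#     conda run -n <sha256-of-env-yaml> bioimageio test <source> \
#         --summary=<working_dir>/summary.json --determinism=seed_only \
#         --weight-format=onnx --stop-early
#
# where the environment name is the SHA256 hexdigest of the serialized conda env description,
# so identical environment specs map to the same, reusable environment.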

@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if stop_early and rd.validation_summary.status != "passed":
                break

            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status != "passed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
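
# Worked example (illustrative): an element counts as mismatched in _test_model_inference
# if abs(actual - expected) > atol + rtol * abs(expected). With rtol=1e-3 and atol=1e-4,
# an expected value of 2.0 tolerates deviations up to 1e-4 + 1e-3 * 2.0 = 2.1e-3:
#
#     expected = np.array([2.0], dtype=np.float32)
#     actual = np.array([2.003], dtype=np.float32)
#     mismatched = abs(actual - expected) > 1e-4 + 1e-3 * abs(expected)
#     assert bool(mismatched[0])  # 3e-3 > 2.1e-3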

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            for m, expected in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    expected_np = expected.data.to_numpy().astype(np.float32)
                    del expected
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()
                    if not mismatched_elements:
                        continue

                    actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
                    try:
                        save_tensor(actual_output_path, actual)
                    except Exception as e:
                        logger.error(
                            "Failed to save actual output tensor to {}: {}",
                            actual_output_path,
                            e,
                        )

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    msg = f"Output '{m}' disagrees with expected values."
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    msg = (
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected_np.size} expected values"
                        + f" ({mismatched_ppm:.1f} ppm)."
                        + f"\n Max relative difference not accounted for by absolute tolerance ({atol:.2e}): {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance ({rtol:.2e}): {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                        + f"\n Saved actual output to {actual_output_path}."
                    )
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shape in generate_test_cases():
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(expected_output_shape):
                        error = (
                            f"Expected {len(expected_output_shape)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in expected_output_shape.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            diff: Dict[AxisId, int] = {}
                            for a, s in res.sizes.items():
                                if isinstance((e_aid := exp[AxisId(a)]), int):
                                    if s != e_aid:
                                        diff[AxisId(a)] = s
                                elif (
                                    s < e_aid.min
                                    or e_aid.max is not None
                                    and s > e_aid.max
                                ):
                                    diff[AxisId(a)] = s
                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
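
# Example (illustrative): for a model with at least one parameterized axis and an
# arbitrary batch axis, the tested (batch_size, n) grid above expands to
#
#     test_cases = {(b, n) for b, n in product([1, 2], [0, 1, 2])}
#     # -> {(1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)}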

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )

# TODO: Implement `debug_model()`
# def debug_model(
#     model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
#     *,
#     weight_format: Optional[WeightsFormat] = None,
#     devices: Optional[List[str]] = None,
# ):
#     """Run the model test and return dict with inputs, results, expected results and intermediates.
#
#     Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
#     """
#     inputs_raw: Optional = None
#     inputs_processed: Optional = None
#     outputs_raw: Optional = None
#     outputs: Optional = None
#     expected: Optional = None
#     diff: Optional = None
#
#     model = load_description(
#         model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
#     )
#     if not isinstance(model, Model):
#         raise ValueError(f"Not a bioimageio.model: {model_rdf}")
#
#     prediction_pipeline = create_prediction_pipeline(
#         bioimageio_model=model, devices=devices, weight_format=weight_format
#     )
#     inputs = [
#         xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
#         for in_path, input_spec in zip(model.test_inputs, model.inputs)
#     ]
#     input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
#
#     # keep track of the non-processed inputs
#     inputs_raw = [deepcopy(input) for input in inputs]
#
#     computed_measures = {}
#
#     prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
#     inputs_processed = list(input_dict.values())
#     outputs_raw = prediction_pipeline.predict(*inputs_processed)
#     output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
#     prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
#     outputs = list(output_dict.values())
#
#     if isinstance(outputs, (np.ndarray, xr.DataArray)):
#         outputs = [outputs]
#
#     expected = [
#         xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
#         for out_path, output_spec in zip(model.test_outputs, model.outputs)
#     ]
#     if len(outputs) != len(expected):
#         error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
#         print(error)
#     else:
#         diff = []
#         for res, exp in zip(outputs, expected):
#             diff.append(res - exp)
#
#     return {
#         "inputs": inputs_raw,
#         "inputs_processed": inputs_processed,
#         "outputs_raw": outputs_raw,
#         "outputs": outputs,
#         "expected": expected,
#         "diff": diff,
#     }