Coverage for src/bioimageio/core/_resource_tests.py: 57%

439 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from contextlib import nullcontext 

8from copy import deepcopy 

9from io import StringIO 

10from itertools import product 

11from pathlib import Path 

12from tempfile import TemporaryDirectory 

13from typing import ( 

14 Any, 

15 Callable, 

16 Dict, 

17 Hashable, 

18 List, 

19 Literal, 

20 Optional, 

21 Sequence, 

22 Set, 

23 Tuple, 

24 Union, 

25 overload, 

26) 

27 

28import numpy as np 

29from loguru import logger 

30from numpy.typing import NDArray 

31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

32 

33from bioimageio.spec import ( 

34 AnyDatasetDescr, 

35 AnyModelDescr, 

36 BioimageioCondaEnv, 

37 DatasetDescr, 

38 InvalidDescr, 

39 LatestResourceDescr, 

40 ModelDescr, 

41 ResourceDescr, 

42 ValidationContext, 

43 build_description, 

44 dump_description, 

45 get_conda_env, 

46 load_description, 

47 save_bioimageio_package, 

48) 

49from bioimageio.spec._description_impl import DISCOVER 

50from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

51from bioimageio.spec._internal.io import is_yaml_value 

52from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

53from bioimageio.spec._internal.types import ( 

54 AbsoluteTolerance, 

55 FormatVersionPlaceholder, 

56 MismatchedElementsPerMillion, 

57 RelativeTolerance, 

58) 

59from bioimageio.spec._internal.validation_context import get_validation_context 

60from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity 

61from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

62from bioimageio.spec.model import v0_4, v0_5 

63from bioimageio.spec.model.v0_5 import WeightsFormat 

64from bioimageio.spec.summary import ( 

65 ErrorEntry, 

66 InstalledPackage, 

67 ValidationDetail, 

68 ValidationSummary, 

69 WarningEntry, 

70) 

71 

72from . import __version__ 

73from ._prediction_pipeline import create_prediction_pipeline 

74from ._settings import settings 

75from .axis import AxisId, BatchSize 

76from .common import MemberId, SupportedWeightsFormat 

77from .digest_spec import get_test_input_sample, get_test_output_sample 

78from .io import save_tensor 

79from .sample import Sample 

80from .tensor import Tensor 

81 

# Name of the conda executable to invoke; on Windows conda is exposed as a batch file.
CONDA_CMD = "conda" if platform.system() != "Windows" else "conda.bat"

83 

84 

class DeprecatedKwargs(TypedDict, total=False):
    """Deprecated keyword arguments still accepted by the test functions.

    All keys are optional.
    """

    # absolute tolerance used when comparing reproduced test outputs
    absolute_tolerance: AbsoluteTolerance
    # relative tolerance used when comparing reproduced test outputs
    relative_tolerance: RelativeTolerance
    # legacy decimal precision; when given, reverts to the old comparison logic
    decimal: Optional[int]

89 

90 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported (and
            configured) based on weight_formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Configuration is best effort: import or configuration errors of a
          framework are only logged at debug level.
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            # must be set before tensorflow is imported to take effect
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    # keras is needed for both keras weight formats ('keras_hdf5' and 'keras_v3')
    if (
        weight_formats is None
        or "keras_hdf5" in weight_formats
        or "keras_v3" in weight_formats
    ):
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))

174 

175 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Shortcut for `test_description` that additionally requires the resource to be
    of type "model". All other arguments are forwarded unchanged.

    Returns:
        The validation summary produced by `test_description`.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
        **deprecated,
    )
    return summary

199 

200 

def default_run_command(args: Sequence[str]):
    """Run a command (given as argument list, no shell) in a subprocess.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero code.
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    subprocess.check_call(args)

204 

205 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate and test with.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
                (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory. Created if it does not exist.

    Returns:
        The validation summary of the tested resource description.

    Raises:
        RuntimeError: if a custom **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        # sanity check: failures are detected by exceptions raised by run_command
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    verbose = working_dir is not None
    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)
        # a user-provided working dir may not exist yet; we write into it below
        working_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(
                source, format_version=format_version, context=context
            )
        else:
            descr = load_description(
                source,
                format_version=format_version,
                sha256=sha256,
                perform_io_checks=True,
            )

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # elevate status valid-format to passed and start testing
        if descr.validation_summary.status == "valid-format":
            descr.validation_summary.status = "passed"

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

    return descr.validation_summary

335 

336 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    For models without an explicit **weight_format** every present (and testable)
    weight format is tested in its own sub directory of **working_dir**.
    The conda environment is named by the SHA256 hash of its YAML dump; it is only
    created if it cannot already be run. The `bioimageio test` CLI is then run in
    that environment and relevant details of its summary are merged into
    **descr**'s validation summary.

    Raises:
        ValueError: if a requested weight format is not supported by the
            description's format version or not present in the description.
        RuntimeError: if 'tensorflow_js' is requested explicitly or conda is not
            available.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats.
            # `getattr` with default: older (v0.4) weights descriptions lack fields
            # for newer formats listed in the (v0.5) `WeightsFormat` literal.
            all_present_wfs = [
                wf
                for wf in get_args(WeightsFormat)
                if getattr(descr.weights, wf, None)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # testing these would raise a RuntimeError below
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            weights_entry = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            weights_entry = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            weights_entry = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            weights_entry = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            weights_entry = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            weights_entry = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        if weights_entry is None:
            raise ValueError(
                f"Weight format '{weight_format}' is not present in the model description."
            )

        if conda_env is None:
            conda_env = get_conda_env(entry=weights_entry)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # Work on a copy so the caller's environment object (possibly shared between
    # several weight formats) is not mutated.
    conda_env = deepcopy(conda_env)
    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe for an existing environment of this name
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    if summary_path.exists():
        # A stale summary (e.g. from an earlier run in a reused, user-provided
        # working dir) would be mistaken for the result of this run.
        logger.warning("Removing existing summary file {}", summary_path)
        summary_path.unlink()

    cmd: List[str] = []
    cmd_errors: List[str] = []
    # try both spellings of the summary argument (CLI versions differ)
    for summary_path_arg_name in ("summary", "summary-path"):
        cmd = (
            [
                CONDA_CMD,
                "run",
                "-n",
                env_name,
                "bioimageio",
                "test",
                str(source),
                f"--{summary_path_arg_name}={summary_path.as_posix()}",
                f"--determinism={determinism}",
            ]
            + ([f"--weight-format={weight_format}"] if weight_format else [])
            + ([f"--expected-type={expected_type}"] if expected_type else [])
            + (["--stop-early"] if stop_early else [])
        )
        try:
            run_command(cmd)
        except Exception as e:
            # a failing test run may still have written a summary file
            cmd_errors.append(f"Command '{' '.join(cmd)}' returned with error: {e}.")

        if summary_path.exists():
            break
    else:
        for cmd_error in cmd_errors:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)

553 

554 

@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        details = source.validation_summary.details
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            # without recorded validation details we cannot tell whether io checks
            # were performed -> re-validate to be on the safe side
            or not details
            or (c := details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        has_expected_type = _test_expected_resource_type(rd, expected_type)
        if not has_expected_type:
            # unexpected type -> invalid format
            rd.validation_summary.status = "failed"
            return rd

    # elevate status valid-format to passed and start testing
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            present_wfs = [w for w, we in rd.weights if we is not None]
            # consistent with `_test_in_env`: 'tensorflow_js' weights cannot be
            # tested by bioimageio.core, so do not report them as failing
            untestable_wfs = [w for w in present_wfs if w in ("tensorflow_js",)]
            if untestable_wfs:
                logger.info(
                    "Skipping untestable weight format(s): {}", untestable_wfs
                )

            weight_formats: List[SupportedWeightsFormat] = [
                w for w in present_wfs if w not in untestable_wfs
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            passed_recreate_test_outputs = _test_recreate_test_outputs(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )

            if stop_early and not passed_recreate_test_outputs:
                break

            if not isinstance(rd, v0_4.ModelDescr):
                passed_parametrized_inference = _test_parametrized_inference(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and not passed_parametrized_inference:
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

766 

767 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output `m` computed with weights `wf`.

    For v0.5 models the tolerance is taken from the model's config:
    a matching `reproducibility_tolerance` entry wins over legacy
    `test_kwargs`, which win over the `ReproducibilityTolerance()` defaults.
    For older models the deprecated keyword arguments are used
    (`decimal` takes precedence over `absolute_tolerance`/`relative_tolerance`).

    Returns:
        (relative tolerance, absolute tolerance, mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            # `or {}`: an explicit `test_kwargs: null` must not crash
            legacy_test_kwargs = (
                model.config.bioimageio.model_extra.get("test_kwargs") or {}
            )
            for weights_format, test_kwargs in legacy_test_kwargs.items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

819 

820 

def evaluate_mismatched_elements(
    actual: "Tensor", expected: "Tensor", rtol: float, atol: float, name: str
) -> Tuple[float, str, Optional[str]]:
    """Compare `actual` against `expected` element-wise (in float32).

    An element counts as mismatched if
    ``|actual - expected| > atol + rtol * |expected|``.

    Args:
        actual: computed tensor.
        expected: reference tensor; its `dims` are used to report locations.
        rtol: relative tolerance.
        atol: absolute tolerance.
        name: output name used in the messages.

    Returns:
        - mismatched elements per million (``-1`` if the comparison failed),
        - a human readable report (empty if the comparison failed),
        - an error message if the comparison failed, ``None`` otherwise.
    """
    try:
        expected_np = expected.data.to_numpy().astype(np.float32)
        dims = expected.dims
        del expected  # drop this function's reference; only the numpy copy is used
        actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
        del actual

        rtol_value = rtol * abs(expected_np)
        abs_diff = abs(actual_np - expected_np)
        mismatched = abs_diff > atol + rtol_value
        mismatched_elements = mismatched.sum().item()

        mismatched_ppm = mismatched_elements / expected_np.size * 1e6
        abs_diff[~mismatched] = 0  # ignore non-mismatched elements

        r_max_idx_flat = (r_diff := (abs_diff / (abs(expected_np) + 1e-6))).argmax()
        r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
        r_max = r_diff[r_max_idx].item()
        r_actual = actual_np[r_max_idx].item()
        r_expected = expected_np[r_max_idx].item()

        # Calculate the max absolute difference with the relative tolerance subtracted
        abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
        a_max_idx = np.unravel_index(abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape)

        a_max = abs_diff[a_max_idx].item()
        a_actual = actual_np[a_max_idx].item()
        a_expected = expected_np[a_max_idx].item()
    except Exception as e:
        mismatched_ppm = -1
        msg = ""
        error_msg = (
            f"Error while checking if '{name}' disagrees with expected values: {e}"
        )
    else:
        error_msg = None
        if mismatched_elements:
            msg = (
                f"Output '{name}': {mismatched_elements} of "
                + f"{expected_np.size} elements disagree with expected values ("
                + (
                    # 1% == 10_000 ppm
                    f"{mismatched_ppm / 10_000:.1f}%"
                    if mismatched_ppm >= 1_000
                    else f"{mismatched_ppm:.1f} ppm"
                )
                + "). "
            )
        else:
            msg = f"Output `{name}`: all elements agree with expected values. "

        msg += (
            f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}"
            + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
            + f" at {dict(zip(dims, r_max_idx))} "
            + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}"
            + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
        )

    return mismatched_ppm, msg, error_msg

883 

884 

def _test_recreate_test_outputs(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    *,
    working_dir: Optional[Union[os.PathLike[str], str]],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> bool:
    """Check that running the model on its test inputs reproduces its test outputs.

    The test input sample is preprocessed, run through the prediction pipeline
    for `weight_format` (with pre-/postprocessing, input padding and output
    cropping skipped), and the raw results are then postprocessed.  Each
    expected test output is compared with the corresponding result for
    presence, dims and tagged shape.  Element values are then compared by
    `evaluate_mismatched_elements`, using tolerances from `_get_tolerance`.

    A mismatch rate (ppm) above the tolerated amount is recorded as an error;
    otherwise a WARNING (some mismatches) or INFO (no mismatches) entry is
    recorded.  The outcome is added to `model.validation_summary` as one
    `ValidationDetail`.

    Args:
        model: model description to test.
        weight_format: weights format to run inference with.
        devices: devices forwarded to `create_prediction_pipeline`.
        stop_early: stop comparing further outputs after the first error.
        working_dir: directory to save preprocessed inputs and outputs to
            (as `.npy` and `.tiff`); only used when `verbose` is set.
        verbose: enables saving tensors to `working_dir`.
        **deprecated: deprecated tolerance kwargs forwarded to `_get_tolerance`.

    Returns:
        True if any error entry was recorded, False otherwise.
        NOTE(review): the flag is True on *failure* despite the `-> bool`
        "test" naming; confirm callers interpret it that way.

    Raises:
        Re-raises any exception from the test itself when the current
        validation context has `raise_errors` set; otherwise the exception
        is recorded as an error entry with traceback.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    weights_loc = ("weights", weight_format)
    collected_errors: List[ErrorEntry] = []
    collected_warnings: List[WarningEntry] = []

    def record_error(msg: str, with_traceback: bool = False):
        collected_errors.append(
            ErrorEntry(
                loc=weights_loc,
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def record_warning(msg: str, severity: WarningSeverity):
        collected_warnings.append(
            WarningEntry(
                loc=weights_loc,
                msg=msg,
                type="bioimageio.core",
                severity=severity,
            )
        )

    def dump_tensor(name: str, tensor: Tensor) -> List[Path]:
        """Save `tensor` as .npy and .tiff in `working_dir`; return the paths written.

        Does nothing unless `working_dir` is given and `verbose` is set.
        Failures to save are logged, not raised.
        """
        if working_dir is None or not verbose:
            return []

        written: List[Path] = []
        for ext in (".npy", ".tiff"):
            target = Path(working_dir) / f"{name}_{weight_format}{ext}"
            try:
                save_tensor(target, tensor)
            except Exception as err:
                logger.error(
                    "Failed to save tensor {}: {}",
                    target,
                    err,
                )
            else:
                written.append(target)

        return written

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as pipeline:
            # `apply_preprocessing` is used for its effect on `test_input`
            # (its return value is ignored), so snapshot the result here.
            pipeline.apply_preprocessing(test_input)
            preprocessed_input = deepcopy(test_input)
            raw_results = pipeline.predict_sample_without_blocking(
                test_input,
                skip_postprocessing=True,
                skip_preprocessing=True,
                skip_input_padding=True,
                skip_output_cropping=True,
            )
            # keep `raw_results` untouched so it can be saved for debugging
            results = deepcopy(raw_results)
            pipeline.apply_postprocessing(results)

        if len(results.members) != len(expected.members):
            record_error(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )
        else:
            saved_inputs: List[Path] = []
            for member_id, member_tensor in preprocessed_input.members.items():
                saved_inputs.extend(
                    dump_tensor(f"test_input_preprocessed_{member_id}", member_tensor)
                )
            if saved_inputs:
                logger.debug("Saved preprocessed test inputs to {}", saved_inputs)

            for m, expected_tensor in expected.members.items():
                actual = results.members.get(m)

                # structural checks first; any failure skips the value comparison
                problem: Optional[str]
                if actual is None:
                    problem = "Output tensors for test case may not be None"
                elif actual.dims != expected_tensor.dims:
                    problem = f"Output '{m}' has dims {actual.dims}, but expected {expected_tensor.dims}"
                elif actual.tagged_shape != expected_tensor.tagged_shape:
                    problem = f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected_tensor.tagged_shape}"
                else:
                    problem = None

                if problem is not None:
                    record_error(problem)
                    if stop_early:
                        break
                    continue

                output_paths: Optional[List[Path]]
                try:
                    output_paths = dump_tensor(f"actual_output_{m}", actual)
                    if m in raw_results.members:
                        output_paths.extend(
                            dump_tensor(
                                f"actual_output_{m}_not_postprocessed",
                                raw_results.members[m],
                            )
                        )
                except Exception as err:
                    logger.error(f"Failed to save actual output tensor for '{m}': {err}")
                    output_paths = None

                rtol, atol, mismatched_tol = _get_tolerance(
                    model, wf=weight_format, m=m, **deprecated
                )
                mismatched_ppm, msg, error_msg = evaluate_mismatched_elements(
                    actual, expected_tensor, rtol, atol, m
                )
                if error_msg is not None:
                    record_error(error_msg)
                    if stop_early:
                        break

                if output_paths:
                    msg += f"\n Saved (intermediate) outputs to {output_paths}."

                if mismatched_ppm > mismatched_tol:
                    record_error(msg)
                    if stop_early:
                        break
                else:
                    severity = WARNING if mismatched_ppm != 0 else INFO
                    record_warning(msg, severity=severity)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        record_error(str(e), with_traceback=True)

    weights_entry = dict(model.weights)[weight_format]
    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=weights_loc,
            status="failed" if collected_errors else "passed",
            recommended_env=get_conda_env(entry=weights_entry),
            errors=collected_errors,
            warnings=collected_warnings,
        )
    )
    return bool(collected_errors)

1055 

1056 

def _test_parametrized_inference(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference on resized test inputs for several (batch size, n) combinations.

    Size parameter values n in {0, 1, 2} are tried when any input axis has a
    `ParameterizedSize`; otherwise only n=0 is used.  Batch sizes are the
    fixed sizes declared by batch axes, {1, 2} if batch axes only allow
    arbitrary sizes, or {1} if there is no batch axis.  Combinations that
    resolve to identical input sizes are run only once.

    For each combination the test input is resized to the sizes from
    `model.get_axis_sizes`, run without input padding or output cropping,
    and the output count and sizes are checked against the expected output
    sizes.  One `ValidationDetail` is added to `model.validation_summary` per
    combination.  Any other exception is either re-raised (when the
    validation context has `raise_errors` set) or recorded as a single
    failed detail.

    Args:
        model: model description (format 0.5) to test.
        weight_format: weights format to run inference with.
        devices: devices forwarded to `create_prediction_pipeline`.
        stop_early: stop after the first failing combination.
    """
    has_parameterized_input = any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    )
    # without parameterized sizes n has no effect, so n=0 suffices
    ns: Set[v0_5.ParameterizedSize_N] = {0, 1, 2} if has_parameterized_input else {0}

    declared_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if not declared_batch_sizes:
        # no batch axis
        batch_sizes = {1}
    else:
        # a size of None means "arbitrary"; fall back to {1, 2} if that is all we have
        batch_sizes = {bs for bs in declared_batch_sizes if bs is not None} or {1, 2}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = set(
        product(sorted(batch_sizes), sorted(ns))
    )
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def iter_cases():
        """Yield (n, batch_size, resized inputs, expected output sizes) per unique input size."""
        seen: Set[Hashable] = set()
        parameterized_axes = [
            (t.id, a.id)
            for t in model.inputs
            for a in t.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        ]

        for batch_size, n in sorted(test_cases):
            input_sizes, output_sizes = model.get_axis_sizes(
                {key: n for key in parameterized_axes}, batch_size=batch_size
            )
            size_key = tuple((k, input_sizes[k]) for k in sorted(input_sizes))
            if size_key in seen:
                continue
            seen.add(size_key)

            resized = Sample(
                members={
                    t.id: test_input.members[t.id].resize_to(
                        {aid: s for (tid, aid), s in input_sizes.items() if tid == t.id},
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_shapes = {
                t.id: {aid: s for (tid, aid), s in output_sizes.items() if tid == t.id}
                for t in model.outputs
            }
            yield n, batch_size, resized, expected_shapes

    def describe_output_mismatch(result: Sample, expected_shapes, n: int) -> Optional[str]:
        """Return an error message if `result` does not match `expected_shapes`, else None."""
        if len(result.members) != len(expected_shapes):
            return f"Expected {len(expected_shapes)} outputs, but got {len(result.members)}"

        for m, exp in expected_shapes.items():
            res = result.members.get(m)
            if res is None:
                return "Output tensors may not be None for test case"

            diff: Dict[AxisId, int] = {}
            for a, s in res.sizes.items():
                e_aid = exp[AxisId(a)]
                if isinstance(e_aid, int):
                    out_of_spec = s != e_aid
                else:
                    # not a fixed size: check against the allowed [min, max] range
                    out_of_spec = s < e_aid.min or (
                        e_aid.max is not None and s > e_aid.max
                    )
                if out_of_spec:
                    diff[AxisId(a)] = s

            if diff:
                return f"(n={n}) Expected output shape {exp}, but got {res.sizes} (diff: {diff})"

        return None

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as pipeline:
            for n, batch_size, inputs, expected_shapes in iter_cases():
                error: Optional[str]
                try:
                    result = pipeline.predict_sample_without_blocking(
                        inputs, skip_input_padding=True, skip_output_cropping=True
                    )
                except Exception as e:
                    error = str(e)
                else:
                    # Checked outside the `try`: exceptions raised here (e.g. a
                    # KeyError for an output axis missing from the expected
                    # sizes) go to the outer handler, not into this case.
                    error = describe_output_mismatch(result, expected_shapes, n)

                case_loc = ("weights", weight_format)
                if error is None:
                    status = "passed"
                    case_errors = []
                else:
                    status = "failed"
                    case_errors = [
                        ErrorEntry(
                            loc=case_loc,
                            msg=error,
                            type="bioimageio.core",
                        )
                    ]

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with batch_size: {batch_size} and size parameter n: {n}",
                        loc=case_loc,
                        status=status,
                        errors=case_errors,
                    )
                )
                if error is not None and stop_early:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1234 

1235 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> bool:
    """Record whether the resource description `rd` has type `expected_type`.

    Appends a "Has expected resource type" `ValidationDetail` to
    `rd.validation_summary.details`.  It is "passed" if the types match;
    otherwise it is "failed" with one error entry located at `("type",)`.

    Args:
        rd: (possibly invalid) resource description to check.
        expected_type: the resource type `rd.type` should equal.

    Returns:
        True if `rd.type` equals `expected_type`, False otherwise.
    """
    # Compare by value: `is` tests object identity, and equal strings
    # (e.g. ones built at runtime or parsed from YAML) need not be the
    # same object, which would make the check fail spuriously.
    has_expected_type = rd.type == expected_type
    # NOTE(review): appended directly to `details` rather than via
    # `validation_summary.add_detail` (used by the other tests in this file);
    # confirm whether skipping `add_detail` affects the summary status.
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )
    return has_expected_type

1259 

1260 

1261# TODO: Implement `debug_model()` 

1262# def debug_model( 

1263# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1264# *, 

1265# weight_format: Optional[WeightsFormat] = None, 

1266# devices: Optional[List[str]] = None, 

1267# ): 

1268# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1269 

1270# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1271# """ 

1272# inputs_raw: Optional = None 

1273# inputs_processed: Optional = None 

1274# outputs_raw: Optional = None 

1275# outputs: Optional = None 

1276# expected: Optional = None 

1277# diff: Optional = None 

1278 

1279# model = load_description( 

1280# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1281# ) 

1282# if not isinstance(model, Model): 

1283# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1284 

1285# prediction_pipeline = create_prediction_pipeline( 

1286# bioimageio_model=model, devices=devices, weight_format=weight_format 

1287# ) 

1288# inputs = [ 

1289# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1290# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1291# ] 

1292# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1293 

1294# # keep track of the non-processed inputs 

1295# inputs_raw = [deepcopy(input) for input in inputs] 

1296 

1297# computed_measures = {} 

1298 

1299# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1300# inputs_processed = list(input_dict.values()) 

1301# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1302# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1303# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1304# outputs = list(output_dict.values()) 

1305 

1306# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1307# outputs = [outputs] 

1308 

1309# expected = [ 

1310# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1311# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1312# ] 

1313# if len(outputs) != len(expected): 

1314# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1315# print(error) 

1316# else: 

1317# diff = [] 

1318# for res, exp in zip(outputs, expected): 

1319# diff.append(res - exp) 

1320 

1321# return { 

1322# "inputs": inputs_raw, 

1323# "inputs_processed": inputs_processed, 

1324# "outputs_raw": outputs_raw, 

1325# "outputs": outputs, 

1326# "expected": expected, 

1327# "diff": diff, 

1328# }