Coverage for src / bioimageio / core / _resource_tests.py: 74%

426 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from contextlib import nullcontext 

8from copy import deepcopy 

9from io import StringIO 

10from itertools import product 

11from pathlib import Path 

12from tempfile import TemporaryDirectory 

13from typing import ( 

14 Any, 

15 Callable, 

16 Dict, 

17 Hashable, 

18 List, 

19 Literal, 

20 Optional, 

21 Sequence, 

22 Set, 

23 Tuple, 

24 Union, 

25 overload, 

26) 

27 

28import numpy as np 

29from loguru import logger 

30from numpy.typing import NDArray 

31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

32 

33from bioimageio.spec import ( 

34 AnyDatasetDescr, 

35 AnyModelDescr, 

36 BioimageioCondaEnv, 

37 DatasetDescr, 

38 InvalidDescr, 

39 LatestResourceDescr, 

40 ModelDescr, 

41 ResourceDescr, 

42 ValidationContext, 

43 build_description, 

44 dump_description, 

45 get_conda_env, 

46 load_description, 

47 save_bioimageio_package, 

48) 

49from bioimageio.spec._description_impl import DISCOVER 

50from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

51from bioimageio.spec._internal.io import is_yaml_value 

52from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

53from bioimageio.spec._internal.types import ( 

54 AbsoluteTolerance, 

55 FormatVersionPlaceholder, 

56 MismatchedElementsPerMillion, 

57 RelativeTolerance, 

58) 

59from bioimageio.spec._internal.validation_context import get_validation_context 

60from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity 

61from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

62from bioimageio.spec.model import v0_4, v0_5 

63from bioimageio.spec.model.v0_5 import WeightsFormat 

64from bioimageio.spec.summary import ( 

65 ErrorEntry, 

66 InstalledPackage, 

67 ValidationDetail, 

68 ValidationSummary, 

69 WarningEntry, 

70) 

71 

72from . import __version__ 

73from ._prediction_pipeline import create_prediction_pipeline 

74from ._settings import settings 

75from .axis import AxisId, BatchSize 

76from .common import MemberId, SupportedWeightsFormat 

77from .digest_spec import get_test_input_sample, get_test_output_sample 

78from .io import save_tensor 

79from .sample import Sample 

80from .tensor import Tensor 

81 

# Name of the conda executable to invoke; on Windows conda is exposed as a batch script.
CONDA_CMD = "conda" if platform.system() != "Windows" else "conda.bat"

83 

84 

class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments accepted by the test functions of this module.

    They are passed down to `_get_tolerance`, which only consults them for
    v0.4 model descriptions (v0.5 descriptions define their own reproducibility
    tolerances in `config.bioimageio`).
    """

    # absolute tolerance for output comparison (1e-3 if omitted)
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for output comparison (1e-3 if omitted)
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy: if given, compare with atol = 1.5 * 10**-decimal and rtol = 0 instead
    decimal: NotRequired[Optional[int]]

89 

90 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- also enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks get imported and
            configured based on weight_formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            If `None`, every installed framework is configured.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Errors while configuring a framework are only logged (debug level).
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """

    def requested(*formats: str) -> bool:
        # no explicit selection means: configure every framework that is installed
        return weight_formats is None or any(f in weight_formats for f in formats)

    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if requested("pytorch_state_dict", "torchscript"):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if requested("tensorflow_saved_model_bundle", "keras_hdf5"):
        try:
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    # 'keras_v3' weights are run by keras as well, so they need keras seeding too
    if requested("keras_hdf5", "keras_v3"):
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))

174 

175 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` with `expected_type="model"`;
    see `test_description` for a description of the arguments.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
        **deprecated,
    )
    return summary

199 

200 

def default_run_command(args: Sequence[str]):
    """Log and run **args** as a subprocess.

    Raises `subprocess.CalledProcessError` if the command exits with a non-zero
    status, which is how callers detect a failing command.
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    subprocess.check_call(args)

204 

205 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to load/validate **source** with.
            (Only used if **runtime_env** is `"currently-active"`.)
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
                (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.

    Returns:
        The validation summary with all test details added.

    Raises:
        RuntimeError: if a custom **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        # derived from the weights description later on (in `_test_in_env`)
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        # sanity check: failures are only detected if run_command raises
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    verbose = working_dir is not None
    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            # Check the expected hash here: it is not forwarded to the
            # `bioimageio test` command run in the conda environment.
            descr = load_description(source, sha256=sha256, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # in-memory sources are written to disk for the subprocess to read
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # elevate status valid-format to passed and start testing
        descr.validation_summary.status = "passed"
        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

    return descr.validation_summary

335 

336 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    For model descriptions without an explicit **weight_format** this function
    recurses once per present weight format (skipping 'tensorflow_js'), using a
    sub directory of **working_dir** per format.
    If **conda_env** is None, a model's environment is derived from its weights
    entry; non-model descriptions require an explicit **conda_env**.

    The conda environment is named after the SHA-256 hash of its YAML dump, so
    an existing environment with the same description is reused.
    The `bioimageio test` command is then run in that environment and its
    summary details relevant to the tested location (or failed ones) are added.

    Note: **devices**, **sha256**, **verbose** and the deprecated kwargs are
    currently not forwarded to the `bioimageio test` command.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            # (`getattr` default: e.g. 'keras_v3' is not supported by v0.4 descriptions)
            all_present_wfs = [
                wf
                for wf in get_args(WeightsFormat)
                if getattr(descr.weights, wf, None)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # skipped as announced above; testing it would raise a RuntimeError
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            wf = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None, (
            f"weight format '{weight_format}' is not present in the model description"
        )
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # Remove name as we create a name based on the env description hash value.
    # Work on a copy so that a caller-provided environment is not mutated.
    conda_env = conda_env.model_copy(update=dict(name=None))

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe whether an environment with this name already exists
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error = None
    # try both spellings of the summary option (presumably for different CLI versions)
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)

553 

554 

# Overload: format_version="latest" and expected_type="model"
# -> a latest-format `ModelDescr` (or `InvalidDescr`).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# Overload: format_version="latest" and expected_type="dataset"
# -> a latest-format `DatasetDescr` (or `InvalidDescr`).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# Overload: format_version="latest" with any (or no) expected_type
# -> any latest-format resource description (or `InvalidDescr`).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# Overload: other format_version (default: discover) and expected_type="model"
# -> a model description of any format version (or `InvalidDescr`).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# Overload: other format_version (default: discover) and expected_type="dataset"
# -> a dataset description of any format version (or `InvalidDescr`).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# Fallback overload: any format_version and expected_type
# -> any resource description (or `InvalidDescr`).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

649 

650 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    An already loaded description is re-validated (from its dumped content) if
    it does not match the requested **format_version** or if there is no record
    of it having been validated with io checks.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            # without any validation details we cannot tell whether io checks ran
            or not source.validation_summary.details
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        has_expected_type = _test_expected_resource_type(rd, expected_type)
        if not has_expected_type:
            # unexpected type -> invalid format
            rd.validation_summary.status = "failed"
            return rd

    # elevate status valid-format to passed and start testing
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            passed_recreate_test_outputs = _test_recreate_test_outputs(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )

            if stop_early and not passed_recreate_test_outputs:
                break

            # parametrized inference is only tested for v0.5+ model descriptions
            if not isinstance(rd, v0_4.ModelDescr):
                passed_parametrized_inference = _test_parametrized_inference(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and not passed_parametrized_inference:
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

766 

767 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output **m** computed with weights **wf**.

    For v0.5 models the tolerance comes from the description. The first matching
    entry of `config.bioimageio.reproducibility_tolerance` wins. Otherwise a
    legacy `config.bioimageio.test_kwargs` entry for **wf** is used, and
    otherwise the `ReproducibilityTolerance` defaults. The deprecated keyword
    arguments are ignored for v0.5 models.

    For v0.4 models the deprecated keyword arguments are used. If `decimal` is
    given, the legacy behaviour applies (with a warning). Otherwise
    `absolute_tolerance`/`relative_tolerance` are used, each defaulting to 1e-3.

    Returns:
        (relative tolerance, absolute tolerance, mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            # `or {}`: tolerate an explicit `test_kwargs: null`
            legacy_test_kwargs = (
                model.config.bioimageio.model_extra.get("test_kwargs") or {}
            )
            for weights_format, test_kwargs in legacy_test_kwargs.items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        # (an empty `weights_formats`/`output_ids` matches any format/output)
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

819 

820 

821def _test_recreate_test_outputs( 

822 model: Union[v0_4.ModelDescr, v0_5.ModelDescr], 

823 weight_format: SupportedWeightsFormat, 

824 devices: Optional[Sequence[str]], 

825 stop_early: bool, 

826 *, 

827 working_dir: Optional[Union[os.PathLike[str], str]], 

828 verbose: bool, 

829 **deprecated: Unpack[DeprecatedKwargs], 

830) -> bool: 

831 test_name = f"Reproduce test outputs from test inputs ({weight_format})" 

832 logger.debug("starting '{}'", test_name) 

833 error_entries: List[ErrorEntry] = [] 

834 warning_entries: List[WarningEntry] = [] 

835 

836 def add_error_entry(msg: str, with_traceback: bool = False): 

837 error_entries.append( 

838 ErrorEntry( 

839 loc=("weights", weight_format), 

840 msg=msg, 

841 type="bioimageio.core", 

842 with_traceback=with_traceback, 

843 ) 

844 ) 

845 

846 def add_warning_entry(msg: str, severity: WarningSeverity): 

847 warning_entries.append( 

848 WarningEntry( 

849 loc=("weights", weight_format), 

850 msg=msg, 

851 type="bioimageio.core", 

852 severity=severity, 

853 ) 

854 ) 

855 

856 def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]: 

857 saved_paths: List[Path] = [] 

858 if working_dir is not None and verbose: 

859 for p in [ 

860 Path(working_dir) / f"{name}_{weight_format}{suffix}" 

861 for suffix in (".npy", ".tiff") 

862 ]: 

863 try: 

864 save_tensor(p, tensor) 

865 except Exception as e: 

866 logger.error( 

867 "Failed to save tensor {}: {}", 

868 p, 

869 e, 

870 ) 

871 else: 

872 saved_paths.append(p) 

873 

874 return saved_paths 

875 

876 try: 

877 test_input = get_test_input_sample(model) 

878 expected = get_test_output_sample(model) 

879 

880 with create_prediction_pipeline( 

881 bioimageio_model=model, devices=devices, weight_format=weight_format 

882 ) as prediction_pipeline: 

883 prediction_pipeline.apply_preprocessing(test_input) 

884 test_input_preprocessed = deepcopy(test_input) 

885 results_not_postprocessed = ( 

886 prediction_pipeline.predict_sample_without_blocking( 

887 test_input, 

888 skip_postprocessing=True, 

889 skip_preprocessing=True, 

890 skip_input_padding=True, 

891 skip_output_cropping=True, 

892 ) 

893 ) 

894 results = deepcopy(results_not_postprocessed) 

895 prediction_pipeline.apply_postprocessing(results) 

896 

897 if len(results.members) != len(expected.members): 

898 add_error_entry( 

899 f"Expected {len(expected.members)} outputs, but got {len(results.members)}" 

900 ) 

901 

902 else: 

903 intermediate_paths: List[Path] = [] 

904 for m, t in test_input_preprocessed.members.items(): 

905 intermediate_paths.extend( 

906 save_to_working_dir(f"test_input_preprocessed_{m}", t) 

907 ) 

908 if intermediate_paths: 

909 logger.debug("Saved preprocessed test inputs to {}", intermediate_paths) 

910 

911 for m, expected in expected.members.items(): 

912 actual = results.members.get(m) 

913 if actual is None: 

914 add_error_entry("Output tensors for test case may not be None") 

915 if stop_early: 

916 break 

917 else: 

918 continue 

919 

920 if actual.dims != (dims := expected.dims): 

921 add_error_entry( 

922 f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}" 

923 ) 

924 if stop_early: 

925 break 

926 else: 

927 continue 

928 

929 if actual.tagged_shape != expected.tagged_shape: 

930 add_error_entry( 

931 f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}" 

932 ) 

933 if stop_early: 

934 break 

935 else: 

936 continue 

937 

938 try: 

939 output_paths = save_to_working_dir(f"actual_output_{m}", actual) 

940 if m in results_not_postprocessed.members: 

941 output_paths.extend( 

942 save_to_working_dir( 

943 f"actual_output_{m}_not_postprocessed", 

944 results_not_postprocessed.members[m], 

945 ) 

946 ) 

947 

948 expected_np = expected.data.to_numpy().astype(np.float32) 

949 del expected 

950 actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) 

951 

952 rtol, atol, mismatched_tol = _get_tolerance( 

953 model, wf=weight_format, m=m, **deprecated 

954 ) 

955 rtol_value = rtol * abs(expected_np) 

956 abs_diff = abs(actual_np - expected_np) 

957 mismatched = abs_diff > atol + rtol_value 

958 mismatched_elements = mismatched.sum().item() 

959 

960 mismatched_ppm = mismatched_elements / expected_np.size * 1e6 

961 abs_diff[~mismatched] = 0 # ignore non-mismatched elements 

962 

963 r_max_idx_flat = ( 

964 r_diff := (abs_diff / (abs(expected_np) + 1e-6)) 

965 ).argmax() 

966 r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) 

967 r_max = r_diff[r_max_idx].item() 

968 r_actual = actual_np[r_max_idx].item() 

969 r_expected = expected_np[r_max_idx].item() 

970 

971 # Calculate the max absolute difference with the relative tolerance subtracted 

972 abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value 

973 a_max_idx = np.unravel_index( 

974 abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape 

975 ) 

976 

977 a_max = abs_diff[a_max_idx].item() 

978 a_actual = actual_np[a_max_idx].item() 

979 a_expected = expected_np[a_max_idx].item() 

980 except Exception as e: 

981 msg = f"Error while checking if '{m}' disagrees with expected values: {e}" 

982 add_error_entry(msg) 

983 if stop_early: 

984 break 

985 else: 

986 if mismatched_elements: 

987 msg = ( 

988 f"Output '{m}': {mismatched_elements} of " 

989 + f"{expected_np.size} elements disagree with expected values." 

990 + f" ({mismatched_ppm:.1f} ppm). " 

991 ) 

992 else: 

993 msg = f"Output `{m}`: all elements agree with expected values. " 

994 

995 msg += ( 

996 f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}" 

997 + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" 

998 + f" at {dict(zip(dims, r_max_idx))} " 

999 + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}" 

1000 + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" 

1001 ) 

1002 if output_paths: 

1003 msg += f"\n Saved (intermediate) outputs to {output_paths}." 

1004 

1005 if mismatched_ppm > mismatched_tol: 

1006 add_error_entry(msg) 

1007 if stop_early: 

1008 break 

1009 else: 

1010 add_warning_entry( 

1011 msg, severity=WARNING if mismatched_elements else INFO 

1012 ) 

1013 

1014 except Exception as e: 

1015 if get_validation_context().raise_errors: 

1016 raise e 

1017 

1018 add_error_entry(str(e), with_traceback=True) 

1019 

1020 model.validation_summary.add_detail( 

1021 ValidationDetail( 

1022 name=test_name, 

1023 loc=("weights", weight_format), 

1024 status="failed" if error_entries else "passed", 

1025 recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]), 

1026 errors=error_entries, 

1027 warnings=warning_entries, 

1028 ) 

1029 ) 

1030 return bool(error_entries) 

1031 

1032 

def _test_parametrized_inference(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Test inference on resized test inputs for several (batch size, n) combinations.

    For every combination of batch size and parameterized-size factor ``n``,
    the model's test input sample is resized to the input axis sizes returned
    by ``model.get_axis_sizes``. The resized sample is run through the
    prediction pipeline (skipping input padding and output cropping). The
    resulting output axis sizes are then compared with the expected output
    sizes. One ``ValidationDetail`` per test case is added to
    ``model.validation_summary``.

    Args:
        model: The v0.5 model description to test.
        weight_format: Weights format used to create the prediction pipeline.
        devices: Devices forwarded to ``create_prediction_pipeline``.
        stop_early: If True, stop after the first failing test case.

    Setup failures are handled differently. If loading the test input,
    creating the pipeline, or any other setup step raises, a single failed
    detail with a traceback is recorded instead. If
    ``get_validation_context().raise_errors`` is set, the exception is
    re-raised.
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        """Yield (n, batch_size, resized inputs, expected output sizes) per test case.

        Test cases that resolve to identical input sizes are yielded only once.
        """
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            # the same n for every parameterized input axis
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        # read by the `generate_test_cases` closure once iteration starts
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shapes in generate_test_cases():
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(
                        inputs, skip_input_padding=True, skip_output_cropping=True
                    )
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(expected_output_shapes):
                        error = (
                            f"Expected {len(expected_output_shapes)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in expected_output_shapes.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            # axis id -> actual size (None if the axis is missing)
                            diff: Dict[AxisId, Optional[int]] = {}
                            for a, s in res.sizes.items():
                                aid = AxisId(a)
                                if aid not in exp:
                                    # Report an unexpected axis as a shape mismatch
                                    # rather than raising a KeyError, which would
                                    # abort all remaining test cases.
                                    diff[aid] = s
                                    continue

                                expected_size = exp[aid]
                                if isinstance(expected_size, int):
                                    if s != expected_size:
                                        diff[aid] = s
                                elif s < expected_size.min or (
                                    expected_size.max is not None
                                    and s > expected_size.max
                                ):
                                    diff[aid] = s

                            # expected axes absent from the actual output
                            for aid in exp:
                                if aid not in res.sizes:
                                    diff[aid] = None

                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1210 

1211 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> bool:
    """Record whether the resource description `rd` has type `expected_type`.

    Appends a "Has expected resource type" detail to ``rd.validation_summary``.
    On a mismatch, the detail carries an error entry.

    Args:
        rd: The (possibly invalid) resource description to check.
        expected_type: The resource type string `rd.type` should equal.

    Returns:
        True if ``rd.type == expected_type``, False otherwise.
    """
    # Compare by value. `is` tests object identity, and equal strings are not
    # guaranteed to be the same object (e.g. a type string read from a file).
    has_expected_type = rd.type == expected_type
    # NOTE(review): other checks in this module record results via
    # `validation_summary.add_detail(...)`; confirm that appending to
    # `details` directly here is intended.
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )
    return has_expected_type

1235 

1236 

1237# TODO: Implement `debug_model()` 

1238# def debug_model( 

1239# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1240# *, 

1241# weight_format: Optional[WeightsFormat] = None, 

1242# devices: Optional[List[str]] = None, 

1243# ): 

1244# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1245 

1246# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1247# """ 

1248# inputs_raw: Optional = None 

1249# inputs_processed: Optional = None 

1250# outputs_raw: Optional = None 

1251# outputs: Optional = None 

1252# expected: Optional = None 

1253# diff: Optional = None 

1254 

1255# model = load_description( 

1256# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1257# ) 

1258# if not isinstance(model, Model): 

1259# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1260 

1261# prediction_pipeline = create_prediction_pipeline( 

1262# bioimageio_model=model, devices=devices, weight_format=weight_format 

1263# ) 

1264# inputs = [ 

1265# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1266# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1267# ] 

1268# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1269 

1270# # keep track of the non-processed inputs 

1271# inputs_raw = [deepcopy(input) for input in inputs] 

1272 

1273# computed_measures = {} 

1274 

1275# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1276# inputs_processed = list(input_dict.values()) 

1277# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1278# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1279# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1280# outputs = list(output_dict.values()) 

1281 

1282# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1283# outputs = [outputs] 

1284 

1285# expected = [ 

1286# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1287# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1288# ] 

1289# if len(outputs) != len(expected): 

1290# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1291# print(error) 

1292# else: 

1293# diff = [] 

1294# for res, exp in zip(outputs, expected): 

1295# diff.append(res - exp) 

1296 

1297# return { 

1298# "inputs": inputs_raw, 

1299# "inputs_processed": inputs_processed, 

1300# "outputs_raw": outputs_raw, 

1301# "outputs": outputs, 

1302# "expected": expected, 

1303# "diff": diff, 

1304# }