Coverage for src / bioimageio / core / _resource_tests.py: 73%

426 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 23:26 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from contextlib import nullcontext 

8from copy import deepcopy 

9from io import StringIO 

10from itertools import product 

11from pathlib import Path 

12from tempfile import TemporaryDirectory 

13from typing import ( 

14 Any, 

15 Callable, 

16 Dict, 

17 Hashable, 

18 List, 

19 Literal, 

20 Optional, 

21 Sequence, 

22 Set, 

23 Tuple, 

24 Union, 

25 overload, 

26) 

27 

28import numpy as np 

29from loguru import logger 

30from numpy.typing import NDArray 

31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

32 

33from bioimageio.core.tensor import Tensor 

34from bioimageio.spec import ( 

35 AnyDatasetDescr, 

36 AnyModelDescr, 

37 BioimageioCondaEnv, 

38 DatasetDescr, 

39 InvalidDescr, 

40 LatestResourceDescr, 

41 ModelDescr, 

42 ResourceDescr, 

43 ValidationContext, 

44 build_description, 

45 dump_description, 

46 get_conda_env, 

47 load_description, 

48 save_bioimageio_package, 

49) 

50from bioimageio.spec._description_impl import DISCOVER 

51from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

52from bioimageio.spec._internal.io import is_yaml_value 

53from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

54from bioimageio.spec._internal.types import ( 

55 AbsoluteTolerance, 

56 FormatVersionPlaceholder, 

57 MismatchedElementsPerMillion, 

58 RelativeTolerance, 

59) 

60from bioimageio.spec._internal.validation_context import get_validation_context 

61from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity 

62from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

63from bioimageio.spec.model import v0_4, v0_5 

64from bioimageio.spec.model.v0_5 import WeightsFormat 

65from bioimageio.spec.summary import ( 

66 ErrorEntry, 

67 InstalledPackage, 

68 ValidationDetail, 

69 ValidationSummary, 

70 WarningEntry, 

71) 

72 

73from . import __version__ 

74from ._prediction_pipeline import create_prediction_pipeline 

75from ._settings import settings 

76from .axis import AxisId, BatchSize 

77from .common import MemberId, SupportedWeightsFormat 

78from .digest_spec import get_test_input_sample, get_test_output_sample 

79from .io import save_tensor 

80from .sample import Sample 

81 

# Executable used to invoke conda; on Windows the entry point is the `conda.bat` script.
CONDA_CMD = "conda" if platform.system() != "Windows" else "conda.bat"

83 

84 

class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments accepted (and passed through) by the test functions.

    `_get_tolerance` only consults them for model descriptions that are not
    v0.5 (i.e. v0.4 models); for v0.5 models they are ignored.
    """

    # absolute tolerance for comparing reproduced test outputs (default there: 1e-3)
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for comparing reproduced test outputs (default there: 1e-3)
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy precision; if not None it replaces the tolerances above with
    # atol = 1.5 * 10**-decimal and rtol = 0 (and triggers a warning)
    decimal: NotRequired[Optional[int]]

89 

90 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally request deterministic algorithms
              (might degrade performance or throw exceptions)
        weight_formats: Limit the deep learning frameworks that are imported
            (and configured) to those needed for these weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            Default: configure all installed frameworks.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
            (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Frameworks that are not installed are skipped silently; any other error
          while configuring a framework is only logged at debug level.
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """

    def requested(*formats: str) -> bool:
        # no explicit weight formats -> configure every (installed) framework
        return weight_formats is None or any(f in weight_formats for f in formats)

    def best_effort(configure: Callable[[], None]) -> None:
        # Determinism setup must never abort a test run: missing frameworks are
        # skipped, other failures are only logged.
        try:
            configure()
        except ImportError:
            pass
        except Exception as e:
            logger.debug(str(e))

    def seed_numpy() -> None:
        np.random.seed(0)

    def seed_torch() -> None:
        import torch

        _ = torch.manual_seed(0)
        torch.use_deterministic_algorithms(mode == "full")

    def seed_tensorflow() -> None:
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
        import tensorflow as tf

        tf.random.set_seed(0)
        if mode == "full":
            tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??

    def seed_keras() -> None:
        import keras  # pyright: ignore[reportMissingTypeStubs]

        keras.utils.set_random_seed(0)

    best_effort(seed_numpy)

    if requested("pytorch_state_dict", "torchscript"):
        best_effort(seed_torch)

    if requested("tensorflow_saved_model_bundle", "keras_hdf5"):
        best_effort(seed_tensorflow)

    # keras_v3 weights are run through Keras as well and need seeding, too
    if requested("keras_hdf5", "keras_v3"):
        best_effort(seed_keras)

174 

175 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Shortcut for `test_description` that requires the resource to be of type
    "model". Every argument is forwarded unchanged; see `test_description` for
    their meaning.

    Returns:
        The validation summary of the tested model description.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
        **deprecated,
    )
    return summary

199 

200 

def default_run_command(args: Sequence[str]):
    """Execute **args** as a subprocess (without a shell) after logging it.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero code.
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    # failures surface as exceptions; the returned exit code is always 0 here
    _ = subprocess.check_call(args)

204 

205 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate **source** with.
            Only used if **runtime_env** is `"currently-active"`.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object
            or raw YAML content.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.

    Returns:
        The validation summary of the tested description.

    Raises:
        RuntimeError: if a custom **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        # sanity check: failures of a custom run_command must surface as exceptions,
        # otherwise failing conda/CLI calls would go unnoticed
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    # keep intermediate files only if the caller chose a working directory
    verbose = working_dir is not None
    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            # Verify the expected hash here: the `bioimageio test` command run in
            # the conda environment does not receive **sha256**.
            descr = load_description(source, perform_io_checks=True, sha256=sha256)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # in-memory descriptions are packaged so the test command can load them
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # elevate status valid-format to passed and start testing
        descr.validation_summary.status = "passed"
        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

    return descr.validation_summary

335 

336 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    The environment is named after the SHA-256 hash of its YAML dump (with the
    name removed). It is created with `conda env create` if `conda run` cannot
    use it. The tests run via the `bioimageio test` command inside that
    environment. From the summary file it writes, details located under
    ("weights", weight_format) (all details for non-model resources) and all
    failed details are added to **descr**'s validation summary.

    Args:
        source: Resource source handed to the `bioimageio test` command.
        descr: Loaded description of **source**; its validation summary is
            updated in place.
        working_dir: Directory for the environment file and the test summary
            (created if needed).
        weight_format: Weight format to test. `None` tests each present weight
            format (except 'tensorflow_js') in its own subdirectory of **working_dir**.
        conda_env: Environment to test in. `None` derives one from the model
            weights entry (non-model resources require it). The passed object
            is not modified.
        run_command: Executes a command; must raise if the command fails.

    Raises:
        RuntimeError: if conda is not available, or if 'tensorflow_js' is
            requested explicitly.

    NOTE(review): **devices**, **sha256**, **verbose** and **deprecated** are not
    forwarded to the `bioimageio test` command — confirm this is intended.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            # (`getattr` default: v0.4 weights lack newer formats such as keras_v3)
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf, None)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # testing is not supported for these formats (as announced above)
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            wf = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # Work on a copy so the caller's environment description is left untouched,
    # and remove its name as we create a name based on the env description hash value.
    conda_env = conda_env.model_copy()
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe whether an environment with this (hash) name is usable already
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd = []
    cmd_error = None
    # try both spellings of the summary option (presumably differing between
    # bioimageio CLI versions); stop as soon as a summary file was written
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)

553 

554 

# Typing overloads of `load_description_and_test`: the static return type is
# narrowed based on the requested `format_version` and `expected_type`.


# format_version="latest" + expected_type="model" -> `ModelDescr` (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest" + expected_type="dataset" -> `DatasetDescr` (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest", any/no expected type -> `LatestResourceDescr` (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# any format version + expected_type="model" -> `AnyModelDescr` (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# any format version + expected_type="dataset" -> `AnyDatasetDescr` (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# fallback: any format version, any/no expected type -> `ResourceDescr` (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

649 

650 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    An already loaded description is re-validated from its dumped content if it
    does not match **format_version**, or if it was not validated with io checks.
    **sha256** is only checked for file sources (not for loaded descriptions or
    raw YAML content). For models, each selected weight format is tested by
    reproducing the test outputs and, for non-v0.4 models, by parametrized
    inference.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            # a summary without details carries no validation context -> revalidate
            or not source.validation_summary.details
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        has_expected_type = _test_expected_resource_type(rd, expected_type)
        if not has_expected_type:
            # unexpected type -> invalid format
            rd.validation_summary.status = "failed"
            return rd

    # elevate status valid-format to passed and start testing
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            passed_recreate_test_outputs = _test_recreate_test_outputs(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )

            if stop_early and not passed_recreate_test_outputs:
                break

            if not isinstance(rd, v0_4.ModelDescr):
                passed_parametrized_inference = _test_parametrized_inference(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and not passed_parametrized_inference:
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

766 

767 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output **m** computed with weights **wf**.

    For v0.5 models (deprecated keyword arguments are ignored), in increasing
    order of precedence:
        - the `v0_5.ReproducibilityTolerance` defaults,
        - a legacy `config.bioimageio.test_kwargs[wf]` entry
          (missing tolerances default to 1e-3),
        - the first `config.bioimageio.reproducibility_tolerance` entry matching
          **wf** and **m** (empty filters match everything).

    For other (v0.4) models the deprecated keyword arguments are used: a given
    `decimal` results in ``atol = 1.5 * 10**-decimal`` and ``rtol = 0`` (with a
    warning); otherwise `absolute_tolerance`/`relative_tolerance` default to 1e-3.

    Returns:
        (relative tolerance, absolute tolerance, mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            legacy_test_kwargs = (
                model.config.bioimageio.model_extra.get("test_kwargs") or {}
            ).get(wf)
            if legacy_test_kwargs is not None:
                applicable = v0_5.ReproducibilityTolerance(
                    relative_tolerance=legacy_test_kwargs.get("relative_tolerance", 1e-3),
                    absolute_tolerance=legacy_test_kwargs.get("absolute_tolerance", 1e-3),
                )

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

819 

820 

821def _test_recreate_test_outputs( 

822 model: Union[v0_4.ModelDescr, v0_5.ModelDescr], 

823 weight_format: SupportedWeightsFormat, 

824 devices: Optional[Sequence[str]], 

825 stop_early: bool, 

826 *, 

827 working_dir: Optional[Union[os.PathLike[str], str]], 

828 verbose: bool, 

829 **deprecated: Unpack[DeprecatedKwargs], 

830) -> bool: 

831 test_name = f"Reproduce test outputs from test inputs ({weight_format})" 

832 logger.debug("starting '{}'", test_name) 

833 error_entries: List[ErrorEntry] = [] 

834 warning_entries: List[WarningEntry] = [] 

835 

836 def add_error_entry(msg: str, with_traceback: bool = False): 

837 error_entries.append( 

838 ErrorEntry( 

839 loc=("weights", weight_format), 

840 msg=msg, 

841 type="bioimageio.core", 

842 with_traceback=with_traceback, 

843 ) 

844 ) 

845 

846 def add_warning_entry(msg: str, severity: WarningSeverity): 

847 warning_entries.append( 

848 WarningEntry( 

849 loc=("weights", weight_format), 

850 msg=msg, 

851 type="bioimageio.core", 

852 severity=severity, 

853 ) 

854 ) 

855 

856 def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]: 

857 saved_paths: List[Path] = [] 

858 if working_dir is not None and verbose: 

859 for p in [ 

860 Path(working_dir) / f"{name}_{weight_format}{suffix}" 

861 for suffix in (".npy", ".tiff") 

862 ]: 

863 try: 

864 save_tensor(p, tensor) 

865 except Exception as e: 

866 logger.error( 

867 "Failed to save tensor {}: {}", 

868 p, 

869 e, 

870 ) 

871 else: 

872 saved_paths.append(p) 

873 

874 return saved_paths 

875 

876 try: 

877 test_input = get_test_input_sample(model) 

878 expected = get_test_output_sample(model) 

879 

880 with create_prediction_pipeline( 

881 bioimageio_model=model, devices=devices, weight_format=weight_format 

882 ) as prediction_pipeline: 

883 prediction_pipeline.apply_preprocessing(test_input) 

884 test_input_preprocessed = deepcopy(test_input) 

885 results_not_postprocessed = ( 

886 prediction_pipeline.predict_sample_without_blocking( 

887 test_input, skip_postprocessing=True, skip_preprocessing=True 

888 ) 

889 ) 

890 results = deepcopy(results_not_postprocessed) 

891 prediction_pipeline.apply_postprocessing(results) 

892 

893 if len(results.members) != len(expected.members): 

894 add_error_entry( 

895 f"Expected {len(expected.members)} outputs, but got {len(results.members)}" 

896 ) 

897 

898 else: 

899 intermediate_paths: List[Path] = [] 

900 for m, t in test_input_preprocessed.members.items(): 

901 intermediate_paths.extend( 

902 save_to_working_dir(f"test_input_preprocessed_{m}", t) 

903 ) 

904 if intermediate_paths: 

905 logger.debug("Saved preprocessed test inputs to {}", intermediate_paths) 

906 

907 for m, expected in expected.members.items(): 

908 actual = results.members.get(m) 

909 if actual is None: 

910 add_error_entry("Output tensors for test case may not be None") 

911 if stop_early: 

912 break 

913 else: 

914 continue 

915 

916 if actual.dims != (dims := expected.dims): 

917 add_error_entry( 

918 f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}" 

919 ) 

920 if stop_early: 

921 break 

922 else: 

923 continue 

924 

925 if actual.tagged_shape != expected.tagged_shape: 

926 add_error_entry( 

927 f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}" 

928 ) 

929 if stop_early: 

930 break 

931 else: 

932 continue 

933 

934 try: 

935 output_paths = save_to_working_dir(f"actual_output_{m}", actual) 

936 if m in results_not_postprocessed.members: 

937 output_paths.extend( 

938 save_to_working_dir( 

939 f"actual_output_{m}_not_postprocessed", 

940 results_not_postprocessed.members[m], 

941 ) 

942 ) 

943 

944 expected_np = expected.data.to_numpy().astype(np.float32) 

945 del expected 

946 actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) 

947 

948 rtol, atol, mismatched_tol = _get_tolerance( 

949 model, wf=weight_format, m=m, **deprecated 

950 ) 

951 rtol_value = rtol * abs(expected_np) 

952 abs_diff = abs(actual_np - expected_np) 

953 mismatched = abs_diff > atol + rtol_value 

954 mismatched_elements = mismatched.sum().item() 

955 

956 mismatched_ppm = mismatched_elements / expected_np.size * 1e6 

957 abs_diff[~mismatched] = 0 # ignore non-mismatched elements 

958 

959 r_max_idx_flat = ( 

960 r_diff := (abs_diff / (abs(expected_np) + 1e-6)) 

961 ).argmax() 

962 r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) 

963 r_max = r_diff[r_max_idx].item() 

964 r_actual = actual_np[r_max_idx].item() 

965 r_expected = expected_np[r_max_idx].item() 

966 

967 # Calculate the max absolute difference with the relative tolerance subtracted 

968 abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value 

969 a_max_idx = np.unravel_index( 

970 abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape 

971 ) 

972 

973 a_max = abs_diff[a_max_idx].item() 

974 a_actual = actual_np[a_max_idx].item() 

975 a_expected = expected_np[a_max_idx].item() 

976 except Exception as e: 

977 msg = f"Error while checking if '{m}' disagrees with expected values: {e}" 

978 add_error_entry(msg) 

979 if stop_early: 

980 break 

981 else: 

982 if mismatched_elements: 

983 msg = ( 

984 f"Output '{m}': {mismatched_elements} of " 

985 + f"{expected_np.size} elements disagree with expected values." 

986 + f" ({mismatched_ppm:.1f} ppm). " 

987 ) 

988 else: 

989 msg = f"Output `{m}`: all elements agree with expected values. " 

990 

991 msg += ( 

992 f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}" 

993 + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" 

994 + f" at {dict(zip(dims, r_max_idx))} " 

995 + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}" 

996 + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" 

997 ) 

998 if output_paths: 

999 msg += f"\n Saved (intermediate) outputs to {output_paths}." 

1000 

1001 if mismatched_ppm > mismatched_tol: 

1002 add_error_entry(msg) 

1003 if stop_early: 

1004 break 

1005 else: 

1006 add_warning_entry( 

1007 msg, severity=WARNING if mismatched_elements else INFO 

1008 ) 

1009 

1010 except Exception as e: 

1011 if get_validation_context().raise_errors: 

1012 raise e 

1013 

1014 add_error_entry(str(e), with_traceback=True) 

1015 

1016 model.validation_summary.add_detail( 

1017 ValidationDetail( 

1018 name=test_name, 

1019 loc=("weights", weight_format), 

1020 status="failed" if error_entries else "passed", 

1021 recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]), 

1022 errors=error_entries, 

1023 warnings=warning_entries, 

1024 ) 

1025 ) 

1026 return bool(error_entries) 

1027 

1028 

def _check_parametrized_output_shapes(
    result: Sample,
    expected_output_shapes: Dict[Any, Dict[AxisId, Any]],
    n: int,
) -> Optional[str]:
    """Compare the output sizes of `result` against `expected_output_shapes`.

    Args:
        result: prediction result of one parametrized test case.
        expected_output_shapes: expected axis sizes per output member id, as
            derived from `model.get_axis_sizes`. Each size is either an exact
            `int` or an object with a `min` and an optional `max` bound.
        n: size parameter of the test case (only used in the error message).

    Returns:
        A message describing the first mismatching output, or None if the
        number of outputs and all output sizes match.
    """
    if len(result.members) != len(expected_output_shapes):
        return (
            f"Expected {len(expected_output_shapes)} outputs,"
            + f" but got {len(result.members)}"
        )

    for m, exp in expected_output_shapes.items():
        res = result.members.get(m)
        if res is None:
            return "Output tensors may not be None for test case"

        diff: Dict[AxisId, int] = {}
        for a, s in res.sizes.items():
            aid = AxisId(a)
            if aid not in exp:
                # unexpected axis: report it as a mismatch instead of raising a
                # KeyError that would abort all remaining test cases
                diff[aid] = s
                continue

            e_size = exp[aid]
            if isinstance(e_size, int):
                if s != e_size:
                    diff[aid] = s
            elif s < e_size.min or (e_size.max is not None and s > e_size.max):
                diff[aid] = s

        # expected axes the result does not have at all (e.g. a squeezed axis)
        actual_axes = {AxisId(a) for a in res.sizes}
        missing = [aid for aid in exp if aid not in actual_axes]

        if diff or missing:
            error = (
                f"(n={n}) Expected output shape {exp},"
                + f" but got {res.sizes} (diff: {diff})"
            )
            if missing:
                error += f" (missing axes: {missing})"

            return error

    return None


def _test_parametrized_inference(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference on test inputs resized to several parametrized sizes.

    Size parameters `n` are {0, 1, 2} if any input axis has a
    `ParameterizedSize`, and {0} otherwise. Batch sizes are the fixed sizes of
    the batch axes, {1, 2} if all batch axes have an arbitrary size, or {1} if
    there is no batch axis.

    For each (batch size, n) combination that yields a not-yet-tested set of
    input sizes, the test input sample is resized, run through the prediction
    pipeline, and the output sizes are compared to those given by
    `model.get_axis_sizes`. One validation detail per test case is added to
    `model.validation_summary`. An unexpected exception adds a single failed
    detail instead, or is re-raised if the validation context has
    `raise_errors` set.

    Args:
        model: model description to test.
        weight_format: weights format to run inference with.
        devices: devices passed on to the prediction pipeline.
        stop_early: stop after the first failing test case.
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def get_ns(n: int):
        # the same size parameter n is used for every parameterized input axis
        return {
            (t.id, a.id): n
            for t in model.inputs
            for a in t.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        }

    def generate_test_cases(test_input: Sample):
        """Yield (n, batch_size, resized inputs, expected output sizes) per unique input size."""
        tested: Set[Hashable] = set()
        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                # different (batch_size, n) pairs may lead to identical inputs
                continue

            tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: test_input.members[t.id].resize_to(
                        {
                            aid: s
                            for (tid, aid), s in input_target_sizes.items()
                            if tid == t.id
                        },
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for (
                n,
                batch_size,
                inputs,
                expected_output_shapes,
            ) in generate_test_cases(test_input):
                error: Optional[str]
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    error = _check_parametrized_output_shapes(
                        result, expected_output_shapes, n
                    )

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1204 

1205 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> bool:
    """Check that `rd.type` equals `expected_type` and record the outcome.

    A "Has expected resource type" validation detail (passed or failed) is
    appended to `rd.validation_summary.details`.

    Args:
        rd: resource description (possibly invalid) to check.
        expected_type: the resource type the description should have.

    Returns:
        True if `rd.type` equals `expected_type`, False otherwise.
    """
    # compare by value: `is` only matches strings that happen to be the same
    # (interned) object, so an equal but dynamically built string would fail
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )
    return has_expected_type

1229 

1230 

1231# TODO: Implement `debug_model()` 

1232# def debug_model( 

1233# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1234# *, 

1235# weight_format: Optional[WeightsFormat] = None, 

1236# devices: Optional[List[str]] = None, 

1237# ): 

1238# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1239 

1240# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1241# """ 

1242# inputs_raw: Optional = None 

1243# inputs_processed: Optional = None 

1244# outputs_raw: Optional = None 

1245# outputs: Optional = None 

1246# expected: Optional = None 

1247# diff: Optional = None 

1248 

1249# model = load_description( 

1250# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1251# ) 

1252# if not isinstance(model, Model): 

1253# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1254 

1255# prediction_pipeline = create_prediction_pipeline( 

1256# bioimageio_model=model, devices=devices, weight_format=weight_format 

1257# ) 

1258# inputs = [ 

1259# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1260# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1261# ] 

1262# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1263 

1264# # keep track of the non-processed inputs 

1265# inputs_raw = [deepcopy(input) for input in inputs] 

1266 

1267# computed_measures = {} 

1268 

1269# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1270# inputs_processed = list(input_dict.values()) 

1271# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1272# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1273# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1274# outputs = list(output_dict.values()) 

1275 

1276# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1277# outputs = [outputs] 

1278 

1279# expected = [ 

1280# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1281# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1282# ] 

1283# if len(outputs) != len(expected): 

1284# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1285# print(error) 

1286# else: 

1287# diff = [] 

1288# for res, exp in zip(outputs, expected): 

1289# diff.append(res - exp) 

1290 

1291# return { 

1292# "inputs": inputs_raw, 

1293# "inputs_processed": inputs_processed, 

1294# "outputs_raw": outputs_raw, 

1295# "outputs": outputs, 

1296# "expected": expected, 

1297# "diff": diff, 

1298# }