Coverage for src / bioimageio / core / _resource_tests.py: 73%

418 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from contextlib import nullcontext 

8from copy import deepcopy 

9from io import StringIO 

10from itertools import product 

11from pathlib import Path 

12from tempfile import TemporaryDirectory 

13from typing import ( 

14 Any, 

15 Callable, 

16 Dict, 

17 Hashable, 

18 List, 

19 Literal, 

20 Optional, 

21 Sequence, 

22 Set, 

23 Tuple, 

24 Union, 

25 overload, 

26) 

27 

28import numpy as np 

29from loguru import logger 

30from numpy.typing import NDArray 

31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

32 

33from bioimageio.core.tensor import Tensor 

34from bioimageio.spec import ( 

35 AnyDatasetDescr, 

36 AnyModelDescr, 

37 BioimageioCondaEnv, 

38 DatasetDescr, 

39 InvalidDescr, 

40 LatestResourceDescr, 

41 ModelDescr, 

42 ResourceDescr, 

43 ValidationContext, 

44 build_description, 

45 dump_description, 

46 get_conda_env, 

47 load_description, 

48 save_bioimageio_package, 

49) 

50from bioimageio.spec._description_impl import DISCOVER 

51from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

52from bioimageio.spec._internal.io import is_yaml_value 

53from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

54from bioimageio.spec._internal.types import ( 

55 AbsoluteTolerance, 

56 FormatVersionPlaceholder, 

57 MismatchedElementsPerMillion, 

58 RelativeTolerance, 

59) 

60from bioimageio.spec._internal.validation_context import get_validation_context 

61from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity 

62from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

63from bioimageio.spec.model import v0_4, v0_5 

64from bioimageio.spec.model.v0_5 import WeightsFormat 

65from bioimageio.spec.summary import ( 

66 ErrorEntry, 

67 InstalledPackage, 

68 ValidationDetail, 

69 ValidationSummary, 

70 WarningEntry, 

71) 

72 

73from . import __version__ 

74from ._prediction_pipeline import create_prediction_pipeline 

75from ._settings import settings 

76from .axis import AxisId, BatchSize 

77from .common import MemberId, SupportedWeightsFormat 

78from .digest_spec import get_test_input_sample, get_test_output_sample 

79from .io import save_tensor 

80from .sample import Sample 

81 

# Name of the conda executable to invoke in subprocesses:
# on Windows the entry point is a batch file, so subprocess lookup needs "conda.bat".
CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda"

83 

84 

class DeprecatedKwargs(TypedDict):
    """Deprecated test keyword arguments, kept for backwards compatibility.

    Prefer specifying reproducibility tolerances in the model description;
    see `_get_tolerance` for how these values are resolved (and for the
    deprecation warning emitted when `decimal` is given).
    """

    # deprecated: absolute tolerance for comparing actual vs. expected outputs
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # deprecated: relative tolerance for comparing actual vs. expected outputs
    relative_tolerance: NotRequired[RelativeTolerance]
    # deprecated: if given, reverts comparison to the old decimal-based behavior
    decimal: NotRequired[Optional[int]]

89 

90 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.

    May degrade performance. Only recommended for testing reproducibility!

    Seeds numpy, torch, tensorflow and keras (whichever are importable and
    relevant for **weight_formats**) and, if **mode** == "full", additionally
    requests deterministic algorithms from torch/tensorflow.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' determinism features (might degrade performance or throw exceptions)
        weight_formats: Limit deep learning framework imports based on
            **weight_formats**, e.g. to avoid importing tensorflow when
            testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """

    def wants(*formats):
        # with no restriction given, every framework is configured
        return weight_formats is None or any(f in weight_formats for f in formats)

    def seed_numpy():
        try:
            import numpy.random
        except ImportError:
            return
        numpy.random.seed(0)

    def seed_torch():
        try:
            import torch
        except ImportError:
            return
        _ = torch.manual_seed(0)
        torch.use_deterministic_algorithms(mode == "full")

    def seed_tensorflow():
        # disable oneDNN optimizations before (attempting) the tensorflow import
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
        try:
            import tensorflow as tf
        except ImportError:
            return
        tf.random.set_seed(0)
        if mode == "full":
            tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??

    def seed_keras():
        try:
            import keras  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            return
        keras.utils.set_random_seed(0)

    framework_steps = (
        (True, seed_numpy),
        (wants("pytorch_state_dict", "torchscript"), seed_torch),
        (wants("tensorflow_saved_model_bundle", "keras_hdf5"), seed_tensorflow),
        (wants("keras_hdf5"), seed_keras),
    )
    for applicable, step in framework_steps:
        if not applicable:
            continue
        try:
            # missing packages are silently skipped (inside `step`);
            # any other failure is only logged, never raised
            step()
        except Exception as e:
            logger.debug(str(e))

174 

175 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Thin wrapper around `test_description` that additionally asserts the
    resource to be of type 'model'. See `test_description` for details on
    the parameters and the returned validation summary.
    """
    forwarded: Dict[str, Any] = dict(
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
    )
    forwarded.update(deprecated)
    return test_description(source, **forwarded)

199 

200 

def default_run_command(args: Sequence[str]):
    """Log and run **args** in a subprocess.

    Raises `subprocess.CalledProcessError` on a non-zero exit code, as required
    by the **run_command** contract of `test_description`.
    """
    display = " ".join(args)
    logger.info("running '{}'...", display)
    _ = subprocess.check_call(args)

204 

205 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate and test against
            ("discover" keeps the described format version).
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.

    Returns:
        The validation summary of the tested resource description.
    """
    # fast path: run everything directly in this interpreter, no conda involved
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    # otherwise determine the conda environment description to test in
    if runtime_env == "as-described":
        conda_env = None  # derived per weights entry later (see _test_in_env)
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # sanity check a custom **run_command**: the error handling below relies on
    # it raising for failing commands
    if run_command is not default_run_command:
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    # keep intermediate files (verbose) only if the caller provided a working
    # dir; otherwise use a temporary directory that is cleaned up afterwards
    verbose = working_dir is not None
    if working_dir is None:
        # ignore_cleanup_errors is only available from Python 3.10 on
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        # obtain a (possibly invalid) resource description from **source**
        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            descr = load_description(source, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            # static validation failed already; skip dynamic testing
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # package in-memory sources to disk so the subprocess can read them
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # run the actual tests in a (possibly newly created) conda environment;
        # details are added to descr.validation_summary
        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

    return descr.validation_summary

333 

334 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    Creates (or reuses) a conda environment named after the SHA256 hash of its
    description, runs the `bioimageio test` CLI inside it and merges the
    relevant details of the produced summary file into
    `descr.validation_summary`.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            # formats bioimageio.core cannot test (see the 'tensorflow_js'
            # branch below, which raises for them)
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # fix: actually skip the formats announced as ignored;
                    # recursing with them would raise RuntimeError below
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        # resolve the weights entry for the requested format
        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            wf = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            # derive a default conda environment from the weights entry
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            # no defaults exist for non-model resources
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    # derive a reproducible environment name from the serialized description,
    # so identical descriptions share one environment across runs
    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    # make sure conda itself is available
    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe whether the environment already exists
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        # it does not: create it from the dumped env description
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error: Optional[str] = None
    # NOTE(review): two spellings of the summary path CLI argument are tried —
    # presumably to stay compatible with different bioimageio CLI versions
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        # a summary file means the CLI ran (test results may still be failures)
        if summary_path.exists():
            break
    else:
        # neither argument spelling produced a summary file
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)

551 

552 

# `load_description_and_test` overloads: narrow the return type from the
# combination of **format_version** ("latest" vs. discover/other) and
# **expected_type** ("model", "dataset", or unrestricted).
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

647 

648 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        # re-serialize an already loaded description if it was not validated
        # with the requested format version (accepting exact and major.minor
        # matches) or if io checks were skipped when it was loaded
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        # already loaded with matching format version and io checks performed
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    # record the bioimageio.core version used for testing in the summary
    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        # test either the requested weight format or all present ones
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )
            if stop_early and rd.validation_summary.status != "passed":
                break

            # parametrized (resized input) tests require a v0.5 description
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status != "passed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

755 

756 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Resolve the comparison tolerances for output tensor **m** tested with
    weights format **wf**.

    Returns:
        (relative tolerance, absolute tolerance, tolerated mismatched elements
        per million).

    For v0.5 models the tolerances come from the model config, where explicit
    `reproducibility_tolerance` entries take precedence over legacy
    `test_kwargs`. Otherwise the (deprecated) test kwargs are used, falling
    back to 1e-3 defaults.
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        extra = model.config.bioimageio.model_extra
        if extra is not None:
            for legacy_wf, test_kwargs in extra.get("test_kwargs", {}).items():
                if wf != legacy_wf:
                    continue
                applicable = v0_5.ReproducibilityTolerance(
                    relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                    absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                )
                break

        # explicit reproducibility tolerances take precedence; first entry
        # matching both weights format and output tensor wins
        matching = (
            a
            for a in model.config.bioimageio.reproducibility_tolerance
            if (not a.weights_formats or wf in a.weights_formats)
            and (not a.output_ids or m in a.output_ids)
        )
        applicable = next(matching, applicable)

        return (
            applicable.relative_tolerance,
            applicable.absolute_tolerance,
            applicable.mismatched_elements_per_million,
        )

    decimal = deprecated.get("decimal")
    if decimal is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            " `relative_tolerance` and `absolute_tolerance`, with different"
            " validation logic, using `numpy.testing.assert_allclose, see"
            " 'https://numpy.org/doc/stable/reference/generated/"
            " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            " will cause validation to revert to the old behaviour."
        )
        # old decimal-based behaviour: purely absolute comparison
        return 0, 1.5 * 10 ** (-decimal), 0

    # use given (deprecated) test kwargs
    return (
        deprecated.get("relative_tolerance", 1e-3),
        deprecated.get("absolute_tolerance", 1e-3),
        0,
    )

808 

809 

810def _test_model_inference( 

811 model: Union[v0_4.ModelDescr, v0_5.ModelDescr], 

812 weight_format: SupportedWeightsFormat, 

813 devices: Optional[Sequence[str]], 

814 stop_early: bool, 

815 *, 

816 working_dir: Optional[Union[os.PathLike[str], str]], 

817 verbose: bool, 

818 **deprecated: Unpack[DeprecatedKwargs], 

819) -> None: 

820 test_name = f"Reproduce test outputs from test inputs ({weight_format})" 

821 logger.debug("starting '{}'", test_name) 

822 error_entries: List[ErrorEntry] = [] 

823 warning_entries: List[WarningEntry] = [] 

824 

825 def add_error_entry(msg: str, with_traceback: bool = False): 

826 error_entries.append( 

827 ErrorEntry( 

828 loc=("weights", weight_format), 

829 msg=msg, 

830 type="bioimageio.core", 

831 with_traceback=with_traceback, 

832 ) 

833 ) 

834 

835 def add_warning_entry(msg: str, severity: WarningSeverity): 

836 warning_entries.append( 

837 WarningEntry( 

838 loc=("weights", weight_format), 

839 msg=msg, 

840 type="bioimageio.core", 

841 severity=severity, 

842 ) 

843 ) 

844 

845 def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]: 

846 saved_paths: List[Path] = [] 

847 if working_dir is not None and verbose: 

848 for p in [ 

849 Path(working_dir) / f"{name}_{weight_format}{suffix}" 

850 for suffix in (".npy", ".tiff") 

851 ]: 

852 try: 

853 save_tensor(p, tensor) 

854 except Exception as e: 

855 logger.error( 

856 "Failed to save tensor {}: {}", 

857 p, 

858 e, 

859 ) 

860 else: 

861 saved_paths.append(p) 

862 

863 return saved_paths 

864 

865 try: 

866 test_input = get_test_input_sample(model) 

867 expected = get_test_output_sample(model) 

868 

869 with create_prediction_pipeline( 

870 bioimageio_model=model, devices=devices, weight_format=weight_format 

871 ) as prediction_pipeline: 

872 prediction_pipeline.apply_preprocessing(test_input) 

873 test_input_preprocessed = deepcopy(test_input) 

874 results_not_postprocessed = ( 

875 prediction_pipeline.predict_sample_without_blocking( 

876 test_input, skip_postprocessing=True, skip_preprocessing=True 

877 ) 

878 ) 

879 results = deepcopy(results_not_postprocessed) 

880 prediction_pipeline.apply_postprocessing(results) 

881 

882 if len(results.members) != len(expected.members): 

883 add_error_entry( 

884 f"Expected {len(expected.members)} outputs, but got {len(results.members)}" 

885 ) 

886 

887 else: 

888 intermediate_paths: List[Path] = [] 

889 for m, t in test_input_preprocessed.members.items(): 

890 intermediate_paths.extend( 

891 save_to_working_dir(f"test_input_preprocessed_{m}", t) 

892 ) 

893 if intermediate_paths: 

894 logger.debug("Saved preprocessed test inputs to {}", intermediate_paths) 

895 

896 for m, expected in expected.members.items(): 

897 actual = results.members.get(m) 

898 if actual is None: 

899 add_error_entry("Output tensors for test case may not be None") 

900 if stop_early: 

901 break 

902 else: 

903 continue 

904 

905 if actual.dims != (dims := expected.dims): 

906 add_error_entry( 

907 f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}" 

908 ) 

909 if stop_early: 

910 break 

911 else: 

912 continue 

913 

914 if actual.tagged_shape != expected.tagged_shape: 

915 add_error_entry( 

916 f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}" 

917 ) 

918 if stop_early: 

919 break 

920 else: 

921 continue 

922 

923 try: 

924 output_paths = save_to_working_dir(f"actual_output_{m}", actual) 

925 if m in results_not_postprocessed.members: 

926 output_paths.extend( 

927 save_to_working_dir( 

928 f"actual_output_{m}_not_postprocessed", 

929 results_not_postprocessed.members[m], 

930 ) 

931 ) 

932 

933 expected_np = expected.data.to_numpy().astype(np.float32) 

934 del expected 

935 actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) 

936 

937 rtol, atol, mismatched_tol = _get_tolerance( 

938 model, wf=weight_format, m=m, **deprecated 

939 ) 

940 rtol_value = rtol * abs(expected_np) 

941 abs_diff = abs(actual_np - expected_np) 

942 mismatched = abs_diff > atol + rtol_value 

943 mismatched_elements = mismatched.sum().item() 

944 

945 mismatched_ppm = mismatched_elements / expected_np.size * 1e6 

946 abs_diff[~mismatched] = 0 # ignore non-mismatched elements 

947 

948 r_max_idx_flat = ( 

949 r_diff := (abs_diff / (abs(expected_np) + 1e-6)) 

950 ).argmax() 

951 r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) 

952 r_max = r_diff[r_max_idx].item() 

953 r_actual = actual_np[r_max_idx].item() 

954 r_expected = expected_np[r_max_idx].item() 

955 

956 # Calculate the max absolute difference with the relative tolerance subtracted 

957 abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value 

958 a_max_idx = np.unravel_index( 

959 abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape 

960 ) 

961 

962 a_max = abs_diff[a_max_idx].item() 

963 a_actual = actual_np[a_max_idx].item() 

964 a_expected = expected_np[a_max_idx].item() 

965 except Exception as e: 

966 msg = f"Error while checking if '{m}' disagrees with expected values: {e}" 

967 add_error_entry(msg) 

968 if stop_early: 

969 break 

970 else: 

971 if mismatched_elements: 

972 msg = ( 

973 f"Output '{m}': {mismatched_elements} of " 

974 + f"{expected_np.size} elements disagree with expected values." 

975 + f" ({mismatched_ppm:.1f} ppm)." 

976 ) 

977 else: 

978 msg = f"Output `{m}`: all elements agree with expected values." 

979 

980 msg += ( 

981 f"\n Max relative difference not accounted for by absolute tolerance ({atol:.2e}): {r_max:.2e}" 

982 + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" 

983 + f" at {dict(zip(dims, r_max_idx))}" 

984 + f"\n Max absolute difference not accounted for by relative tolerance ({rtol:.2e}): {a_max:.2e}" 

985 + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" 

986 ) 

987 if output_paths: 

988 msg += f"\n Saved (intermediate) outputs to {output_paths}." 

989 

990 if mismatched_ppm > mismatched_tol: 

991 add_error_entry(msg) 

992 if stop_early: 

993 break 

994 else: 

995 add_warning_entry( 

996 msg, severity=WARNING if mismatched_elements else INFO 

997 ) 

998 

999 except Exception as e: 

1000 if get_validation_context().raise_errors: 

1001 raise e 

1002 

1003 add_error_entry(str(e), with_traceback=True) 

1004 

1005 model.validation_summary.add_detail( 

1006 ValidationDetail( 

1007 name=test_name, 

1008 loc=("weights", weight_format), 

1009 status="failed" if error_entries else "passed", 

1010 recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]), 

1011 errors=error_entries, 

1012 warnings=warning_entries, 

1013 ) 

1014 ) 

1015 

1016 

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference for several input-size combinations of a v0.5 model.

    For every tested combination of batch size ``b`` and parameterized-size
    factor ``n`` one ``ValidationDetail`` is appended to
    ``model.validation_summary`` with status "passed" or "failed".
    Exceptions are re-raised only when the active validation context has
    ``raise_errors`` set; otherwise they are recorded as a failed detail.

    Args:
        model: model description whose inputs/outputs define the axis sizes.
        weight_format: weights entry to run inference with.
        devices: optional device list forwarded to the prediction pipeline.
        stop_early: if True, stop testing further (b, n) cases after the
            first failure.
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    # batch sizes explicitly given on any BatchAxis (None means "arbitrary")
    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    # cartesian product of batch sizes and parameterized-size factors
    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        # Yields (n, batch_size, resized_inputs, expected_output_shapes),
        # skipping cases whose resolved input sizes were already tested.
        # NOTE: closes over `test_input`, which is assigned in the outer
        # `try` below before this generator is first consumed.
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            # map every parameterized input axis to the same factor n
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            # deduplicate: different (b, n) pairs may resolve to identical sizes
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            # NOTE(review): local name 'exptected_output_shape' is misspelled
            # ("exptected"); harmless, but worth fixing in a follow-up.
            for n, batch_size, inputs, exptected_output_shape in generate_test_cases():
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(exptected_output_shape):
                        error = (
                            f"Expected {len(exptected_output_shape)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in exptected_output_shape.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            # collect axes whose actual size disagrees with
                            # the expected (fixed or ranged) size
                            diff: Dict[AxisId, int] = {}
                            for a, s in res.sizes.items():
                                if isinstance((e_aid := exp[AxisId(a)]), int):
                                    if s != e_aid:
                                        diff[AxisId(a)] = s
                                elif (
                                    # `and` binds tighter than `or`: size must be
                                    # >= min and (if max is set) <= max
                                    s < e_aid.min
                                    or e_aid.max is not None
                                    and s > e_aid.max
                                ):
                                    diff[AxisId(a)] = s
                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                # record one detail per tested (batch_size, n) combination
                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        # optionally re-raise; otherwise record the failure with traceback
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1192 

1193 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record whether *rd* has the resource type *expected_type*.

    Appends a ``ValidationDetail`` to ``rd.validation_summary`` with status
    "passed" when ``rd.type`` equals *expected_type* and "failed" (with a
    single error entry) otherwise.
    """
    type_matches = rd.type == expected_type
    if type_matches:
        errors = []
    else:
        errors = [
            ErrorEntry(
                loc=("type",),
                type="type",
                msg=f"Expected type {expected_type}, found {rd.type}",
            )
        ]
    detail = ValidationDetail(
        name="Has expected resource type",
        status="passed" if type_matches else "failed",
        loc=("type",),
        errors=errors,
    )
    rd.validation_summary.details.append(detail)

1216 

1217 

1218# TODO: Implement `debug_model()` 

1219# def debug_model( 

1220# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1221# *, 

1222# weight_format: Optional[WeightsFormat] = None, 

1223# devices: Optional[List[str]] = None, 

1224# ): 

1225# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1226 

1227# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1228# """ 

1229# inputs_raw: Optional = None 

1230# inputs_processed: Optional = None 

1231# outputs_raw: Optional = None 

1232# outputs: Optional = None 

1233# expected: Optional = None 

1234# diff: Optional = None 

1235 

1236# model = load_description( 

1237# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1238# ) 

1239# if not isinstance(model, Model): 

1240# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1241 

1242# prediction_pipeline = create_prediction_pipeline( 

1243# bioimageio_model=model, devices=devices, weight_format=weight_format 

1244# ) 

1245# inputs = [ 

1246# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1247# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1248# ] 

1249# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1250 

1251# # keep track of the non-processed inputs 

1252# inputs_raw = [deepcopy(input) for input in inputs] 

1253 

1254# computed_measures = {} 

1255 

1256# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1257# inputs_processed = list(input_dict.values()) 

1258# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1259# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1260# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1261# outputs = list(output_dict.values()) 

1262 

1263# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1264# outputs = [outputs] 

1265 

1266# expected = [ 

1267# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1268# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1269# ] 

1270# if len(outputs) != len(expected): 

1271# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1272# print(error) 

1273# else: 

1274# diff = [] 

1275# for res, exp in zip(outputs, expected): 

1276# diff.append(res - exp) 

1277 

1278# return { 

1279# "inputs": inputs_raw, 

1280# "inputs_processed": inputs_processed, 

1281# "outputs_raw": outputs_raw, 

1282# "outputs": outputs, 

1283# "expected": expected, 

1284# "diff": diff, 

1285# }