Coverage for src / bioimageio / core / _resource_tests.py: 73%

395 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-18 11:12 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from contextlib import nullcontext 

8from io import StringIO 

9from itertools import product 

10from pathlib import Path 

11from tempfile import TemporaryDirectory 

12from typing import ( 

13 Any, 

14 Callable, 

15 Dict, 

16 Hashable, 

17 List, 

18 Literal, 

19 Optional, 

20 Sequence, 

21 Set, 

22 Tuple, 

23 Union, 

24 overload, 

25) 

26 

27import numpy as np 

28from loguru import logger 

29from numpy.typing import NDArray 

30from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

31 

32from bioimageio.spec import ( 

33 AnyDatasetDescr, 

34 AnyModelDescr, 

35 BioimageioCondaEnv, 

36 DatasetDescr, 

37 InvalidDescr, 

38 LatestResourceDescr, 

39 ModelDescr, 

40 ResourceDescr, 

41 ValidationContext, 

42 build_description, 

43 dump_description, 

44 get_conda_env, 

45 load_description, 

46 save_bioimageio_package, 

47) 

48from bioimageio.spec._description_impl import DISCOVER 

49from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

50from bioimageio.spec._internal.io import is_yaml_value 

51from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

52from bioimageio.spec._internal.types import ( 

53 AbsoluteTolerance, 

54 FormatVersionPlaceholder, 

55 MismatchedElementsPerMillion, 

56 RelativeTolerance, 

57) 

58from bioimageio.spec._internal.validation_context import get_validation_context 

59from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

60from bioimageio.spec.model import v0_4, v0_5 

61from bioimageio.spec.model.v0_5 import WeightsFormat 

62from bioimageio.spec.summary import ( 

63 ErrorEntry, 

64 InstalledPackage, 

65 ValidationDetail, 

66 ValidationSummary, 

67 WarningEntry, 

68) 

69 

70from . import __version__ 

71from ._prediction_pipeline import create_prediction_pipeline 

72from ._settings import settings 

73from .axis import AxisId, BatchSize 

74from .common import MemberId, SupportedWeightsFormat 

75from .digest_spec import get_test_input_sample, get_test_output_sample 

76from .io import save_tensor 

77from .sample import Sample 

78 

# Executable used for all conda invocations; on Windows conda is called via its batch wrapper.
CONDA_CMD = "conda" if platform.system() != "Windows" else "conda.bat"

80 

81 

class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments still accepted by the test functions.

    In this module they are passed through to `_get_tolerance`, which ignores
    them for v0_5 model descriptions.
    """

    # absolute tolerance used when comparing test outputs
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance used when comparing test outputs
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy number of decimals; if not None, atol = 1.5 * 10**-decimal and rtol = 0
    decimal: NotRequired[Optional[int]]

86 

87 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seeds the random generators of NumPy and (if installed and relevant for
    **weight_formats**) PyTorch, TensorFlow and Keras with 0, and (if
    **mode**=="full") requests PyTorch and TensorFlow to use deterministic
    algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds (note: for PyTorch this also explicitly
              turns deterministic algorithms off), or
            - 'full' -- set seeds and enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported,
            based on the given weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            `None` (default) considers all frameworks:
            - PyTorch for 'pytorch_state_dict' and 'torchscript',
            - TensorFlow for 'tensorflow_saved_model_bundle' and 'keras_hdf5',
            - Keras for 'keras_hdf5'.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Frameworks that are not installed are skipped; any other error while
          seeding/configuring a framework is only logged at debug level.
        - The environment variable `TF_ENABLE_ONEDNN_OPTS` is set to "0" whenever
          TensorFlow is considered.
        - TensorFlow op determinism cannot be switched off again by a later call
          with **mode** == "seed_only".
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                # False for "seed_only": actively disables deterministic algorithms
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            # disable oneDNN custom ops (set before importing tensorflow)
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))

171 

172 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` that asserts the resource
    `type` is "model" and runs in the currently active Python environment.
    See `test_description` for the meaning of all arguments.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
    return summary

194 

195 

def default_run_command(args: Sequence[str]):
    """Run **args** as a subprocess (no shell) and raise if it exits non-zero.

    Default `run_command` of `test_description`; the call is logged before it runs.
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    subprocess.check_call(args)

199 

200 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate and test with.
            Only applied if **runtime_env** is `"currently-active"`.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
                (Ignored if **source** already is a loaded `ResourceDescr` object
                or a dict of YAML content.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
            A custom **run_command** is first checked to raise for a failing command.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.

    Returns:
        The validation summary of the (possibly invalid) loaded description,
        extended by the test results.

    Raises:
        RuntimeError: If a custom **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        # sanity check: a custom run_command has to signal failures by raising
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    verbose = working_dir is not None
    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            # check the expected SHA256 of the file source (as documented for **sha256**)
            descr = load_description(source, sha256=sha256, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # in-memory sources are packaged so the subprocess can load them from disk
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

    return descr.validation_summary

328 

329 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    The conda environment is named after the SHA-256 hash of its YAML dump and is
    only created if `conda run -n <name> python --version` fails. The test itself
    runs `bioimageio test` inside that environment; from the summary it writes to
    **working_dir**, all details located under the tested weight format (and all
    failed details) are added to `descr.validation_summary`.

    Args:
        source: File source passed on to the `bioimageio test` command.
        descr: Loaded description of **source**; its validation summary is extended.
        working_dir: Directory for the conda env YAML and the command's summary file.
        weight_format: Weight format to test. If `None` and **descr** is a model,
            all present weight formats except 'tensorflow_js' are tested, each in
            its own sub-directory of **working_dir**.
        conda_env: Environment to test in. For models it defaults to the
            environment derived from the tested weights entry; for other
            resources a warning is issued and nothing is tested if `None`.
            The given object is not modified.
        devices: Not forwarded to the `bioimageio test` command.
        determinism: Forwarded to `bioimageio test`.
        run_command: Executes a command; must raise if the command fails.
        stop_early: Forwarded to `bioimageio test`.
        expected_type: Forwarded to `bioimageio test`.
        sha256: Not forwarded to the `bioimageio test` command.
        verbose: Unused.
        **deprecated: Only forwarded to the per-weight-format recursive calls.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            # formats bioimageio.core cannot test (testing them raises below)
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # Work on a copy: the name is removed (we create a name based on the env
    # description hash value) and the caller's environment object must stay intact.
    conda_env = conda_env.model_copy(deep=True)
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe for an existing environment of the same (hash-derived) name
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd = []
    cmd_error = None
    # the summary option is named differently across bioimageio CLI versions
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)

538 

539 

# Typing overloads of `load_description_and_test`: the return type is narrowed by
# `format_version` ("latest" vs. any) and `expected_type` ("model", "dataset", other).


# format_version="latest", expected_type="model" -> latest model description (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest", expected_type="dataset" -> latest dataset description (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest", any/no expected_type -> any latest resource description (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# any format_version, expected_type="model" -> model description of any version (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# any format_version, expected_type="dataset" -> dataset description of any version (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# fallback: any format_version, any/no expected_type -> any resource description (or invalid)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

634 

635 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    An already loaded description is re-validated (from its dumped content) if it
    does not match **format_version** or was not validated with io checks;
    otherwise it is tested (and its validation summary extended) in place.
    For models, every requested (default: every present) weight format is tested
    for reproducing the test outputs, and v0_5 models additionally with
    parametrized input sizes.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        details = source.validation_summary.details
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            # without any recorded detail we cannot tell whether io checks were performed
            or not details
            or (c := details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        # only import/seed the frameworks needed for the tested weight formats
        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )
            if stop_early and rd.validation_summary.status != "passed":
                break

            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status != "passed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

742 

743 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the reproducibility tolerance for output **m** of weight format **wf**.

    For v0_5 models (**deprecated** is ignored), later sources take precedence:
        1. `v0_5.ReproducibilityTolerance()` defaults,
        2. legacy `config.bioimageio` extra `test_kwargs[wf]`
           (`relative_tolerance`/`absolute_tolerance`, each defaulting to 1e-3),
        3. the first `config.bioimageio.reproducibility_tolerance` entry matching
           **wf** and **m** (empty `weights_formats`/`output_ids` match anything).

    For other models the deprecated keyword arguments are used:
        - `decimal` (if not None): atol = 1.5 * 10**-decimal, rtol = 0 (with a warning),
        - otherwise `absolute_tolerance`/`relative_tolerance` (default 1e-3 each).
    In both of these cases no mismatched elements are tolerated.

    Returns:
        (relative tolerance, absolute tolerance, mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

795 

796 

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    *,
    working_dir: Optional[Union[os.PathLike[str], str]],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Reproduce the model's test outputs from its test inputs using **weight_format**.

    Adds a single `ValidationDetail` (loc `("weights", weight_format)`) to
    `model.validation_summary`.

    An output element counts as mismatched unless it is finite in both arrays and
    `|actual - expected| <= atol + rtol * |expected|`, or both values are equal
    (e.g. identical infinities), or both are NaN (tolerances from `_get_tolerance`).
    If the mismatched elements exceed the tolerated parts per million, an error
    entry is recorded, otherwise a warning entry.

    Args:
        model: Model description providing test inputs/outputs; its validation
            summary is extended.
        weight_format: Weight format to run inference with.
        devices: Devices passed to `create_prediction_pipeline`.
        stop_early: Stop checking further outputs after the first error.
        working_dir: If given and **verbose**, mismatching actual outputs are saved
            there as `actual_output_{member}_{weight_format}.npy`.
        verbose: See **working_dir**.
        **deprecated: Deprecated tolerance arguments forwarded to `_get_tolerance`.

    Raises:
        Exceptions from loading test samples or running inference are re-raised if
        the current validation context has `raise_errors` set; otherwise they are
        recorded as an error entry.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            # `expected` is the whole sample; `expected_tensor` one of its members
            for m, expected_tensor in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected_tensor.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected_tensor.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected_tensor.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected_tensor.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    expected_np = expected_tensor.data.to_numpy().astype(np.float32)
                    del expected_tensor
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    # Any comparison with NaN is False, so `abs_diff > tolerance` alone
                    # would silently accept NaN outputs (and e.g. `1` vs. `inf`, as
                    # `inf > inf` is False). Only count elements as matching that are
                    # finite and within tolerance, exactly equal (identical infinities),
                    # or NaN in both arrays (like numpy.testing.assert_allclose).
                    matched = (
                        (
                            np.isfinite(actual_np)
                            & np.isfinite(expected_np)
                            & (abs_diff <= atol + rtol_value)
                        )
                        | (actual_np == expected_np)
                        | (np.isnan(actual_np) & np.isnan(expected_np))
                    )
                    mismatched = ~matched
                    mismatched_elements = mismatched.sum().item()
                    if not mismatched_elements:
                        continue

                    if working_dir is not None and verbose:
                        actual_output_path = (
                            Path(working_dir) / f"actual_output_{m}_{weight_format}.npy"
                        )
                        try:
                            save_tensor(actual_output_path, actual)
                        except Exception as e:
                            logger.error(
                                "Failed to save actual output tensor to {}: {}",
                                actual_output_path,
                                e,
                            )
                    else:
                        actual_output_path = None

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[matched] = 0  # ignore non-mismatched elements

                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    # report why the comparison failed instead of swallowing it
                    add_error_entry(
                        f"Output '{m}' could not be compared with expected values: {e}",
                        with_traceback=True,
                    )
                    if stop_early:
                        break
                else:
                    msg = (
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected_np.size} expected values"
                        + f" ({mismatched_ppm:.1f} ppm)."
                        + f"\n Max relative difference not accounted for by absolute tolerance ({atol:.2e}): {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance ({rtol:.2e}): {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                    )
                    if actual_output_path is not None:
                        msg += f"\n Saved actual output to {actual_output_path}."

                    # a few mismatches may be tolerated (ppm threshold) -> warning only
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )

965 

966 

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference on resized test inputs for several batch sizes and size parameters.

    For every combination of batch size B and size parameter n, the model's test
    input sample is resized to the input sizes returned by `model.get_axis_sizes`.
    The resized sample is passed through the prediction pipeline, and the sizes of
    the resulting outputs are compared with the expected output sizes. One
    `ValidationDetail` per combination is added to `model.validation_summary`.
    Combinations that resolve to an already tested set of input sizes are skipped.

    Batch sizes tested: the fixed sizes of all input batch axes, {1, 2} if batch
    axes exist but none has a fixed size, or {1} if there is no batch axis.
    Size parameters tested: {0, 1, 2} if any input axis has a `ParameterizedSize`,
    otherwise {0}.

    Args:
        model: model description to test.
        weight_format: weights format used to create the prediction pipeline.
        devices: devices forwarded to `create_prediction_pipeline`.
        stop_early: stop after the first failing test case.

    Raises:
        Exception: an unexpected error (e.g. while loading the test input or
            creating the pipeline) is re-raised if the current validation context
            has `raise_errors` set. Otherwise it is recorded as a single failed
            validation detail.
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def get_ns(n: int):
        """Map every parameterized input axis `(tensor id, axis id)` to `n`."""
        return {
            (t.id, a.id): n
            for t in model.inputs
            for a in t.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        }

    def generate_test_cases(test_input: Sample):
        """Yield `(n, batch_size, resized_inputs, expected_output_shapes)` per test case.

        A test case is skipped if its resolved input sizes equal those of an
        earlier test case, e.g. when n does not affect any input axis.
        """
        tested: Set[Hashable] = set()
        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue

            tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: test_input.members[t.id].resize_to(
                        {
                            aid: s
                            for (tid, aid), s in input_target_sizes.items()
                            if tid == t.id
                        },
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    def check_outputs(
        n: int,
        result: Sample,
        expected_output_shapes: Dict[Any, Dict[AxisId, Any]],
    ) -> Optional[str]:
        """Return an error message if `result` does not match the expected sizes, else None."""
        if len(result.members) != len(expected_output_shapes):
            return (
                f"Expected {len(expected_output_shapes)} outputs,"
                + f" but got {len(result.members)}"
            )

        for m, exp in expected_output_shapes.items():
            res = result.members.get(m)
            if res is None:
                return "Output tensors may not be None for test case"

            actual_sizes = {AxisId(a): s for a, s in res.sizes.items()}
            diff: Dict[AxisId, int] = {}
            for a, s in actual_sizes.items():
                e_size = exp.get(a)
                if e_size is None:
                    # Axis not described for this output. Report it as a
                    # mismatch instead of aborting all test cases with a KeyError.
                    diff[a] = s
                elif isinstance(e_size, int):
                    if s != e_size:
                        diff[a] = s
                elif s < e_size.min or (e_size.max is not None and s > e_size.max):
                    # non-int expected sizes carry `min`/`max` bounds (`max` may be None)
                    diff[a] = s

            # expected axes that the actual output lacks entirely
            missing_axes = [a for a in exp if a not in actual_sizes]
            if diff or missing_axes:
                error = (
                    f"(n={n}) Expected output shape {exp},"
                    + f" but got {res.sizes} (diff: {diff})"
                )
                if missing_axes:
                    error += f" (missing axes: {missing_axes})"

                return error

        return None

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shapes in generate_test_cases(
                test_input
            ):
                error: Optional[str]
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    error = check_outputs(n, result, expected_output_shapes)

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1142 

1143 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record whether the resource description `rd` has the expected `type`.

    Appends a "Has expected resource type" validation detail (located at
    `("type",)`) to `rd.validation_summary.details`. The detail passes when
    `rd.type == expected_type`. Otherwise it fails with a single error entry
    naming both types.

    NOTE(review): the other tests in this module call
    `validation_summary.add_detail(...)`, but this one appends to `details`
    directly. Confirm that bypassing `add_detail` (and any status bookkeeping
    it may do) is intended.
    """
    status: Literal["passed", "failed"]
    errors: List[ErrorEntry]
    if rd.type != expected_type:
        status = "failed"
        errors = [
            ErrorEntry(
                loc=("type",),
                type="type",
                msg=f"Expected type {expected_type}, found {rd.type}",
            )
        ]
    else:
        status = "passed"
        errors = []

    detail = ValidationDetail(
        name="Has expected resource type",
        status=status,
        loc=("type",),
        errors=errors,
    )
    rd.validation_summary.details.append(detail)

1166 

1167 

1168# TODO: Implement `debug_model()` 

1169# def debug_model( 

1170# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1171# *, 

1172# weight_format: Optional[WeightsFormat] = None, 

1173# devices: Optional[List[str]] = None, 

1174# ): 

1175# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1176 

1177# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1178# """ 

1179# inputs_raw: Optional = None 

1180# inputs_processed: Optional = None 

1181# outputs_raw: Optional = None 

1182# outputs: Optional = None 

1183# expected: Optional = None 

1184# diff: Optional = None 

1185 

1186# model = load_description( 

1187# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1188# ) 

1189# if not isinstance(model, Model): 

1190# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1191 

1192# prediction_pipeline = create_prediction_pipeline( 

1193# bioimageio_model=model, devices=devices, weight_format=weight_format 

1194# ) 

1195# inputs = [ 

1196# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1197# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1198# ] 

1199# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1200 

1201# # keep track of the non-processed inputs 

1202# inputs_raw = [deepcopy(input) for input in inputs] 

1203 

1204# computed_measures = {} 

1205 

1206# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1207# inputs_processed = list(input_dict.values()) 

1208# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1209# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1210# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1211# outputs = list(output_dict.values()) 

1212 

1213# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1214# outputs = [outputs] 

1215 

1216# expected = [ 

1217# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1218# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1219# ] 

1220# if len(outputs) != len(expected): 

1221# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1222# print(error) 

1223# else: 

1224# diff = [] 

1225# for res, exp in zip(outputs, expected): 

1226# diff.append(res - exp) 

1227 

1228# return { 

1229# "inputs": inputs_raw, 

1230# "inputs_processed": inputs_processed, 

1231# "outputs_raw": outputs_raw, 

1232# "outputs": outputs, 

1233# "expected": expected, 

1234# "diff": diff, 

1235# }