Coverage for src/bioimageio/core/_resource_tests.py: 50%

383 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from io import StringIO 

8from itertools import product 

9from pathlib import Path 

10from tempfile import TemporaryDirectory 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Hashable, 

16 List, 

17 Literal, 

18 Optional, 

19 Sequence, 

20 Set, 

21 Tuple, 

22 Union, 

23 overload, 

24) 

25 

26import numpy as np 

27from bioimageio.spec import ( 

28 AnyDatasetDescr, 

29 AnyModelDescr, 

30 BioimageioCondaEnv, 

31 DatasetDescr, 

32 InvalidDescr, 

33 LatestResourceDescr, 

34 ModelDescr, 

35 ResourceDescr, 

36 ValidationContext, 

37 build_description, 

38 dump_description, 

39 get_conda_env, 

40 load_description, 

41 save_bioimageio_package, 

42) 

43from bioimageio.spec._description_impl import DISCOVER 

44from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

45from bioimageio.spec._internal.io import is_yaml_value 

46from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

47from bioimageio.spec._internal.types import ( 

48 AbsoluteTolerance, 

49 FormatVersionPlaceholder, 

50 MismatchedElementsPerMillion, 

51 RelativeTolerance, 

52) 

53from bioimageio.spec._internal.validation_context import get_validation_context 

54from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

55from bioimageio.spec.model import v0_4, v0_5 

56from bioimageio.spec.model.v0_5 import WeightsFormat 

57from bioimageio.spec.summary import ( 

58 ErrorEntry, 

59 InstalledPackage, 

60 ValidationDetail, 

61 ValidationSummary, 

62 WarningEntry, 

63) 

64from loguru import logger 

65from numpy.typing import NDArray 

66from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

67 

68from bioimageio.core import __version__ 

69from bioimageio.core.io import save_tensor 

70 

71from ._prediction_pipeline import create_prediction_pipeline 

72from ._settings import settings 

73from .axis import AxisId, BatchSize 

74from .common import MemberId, SupportedWeightsFormat 

75from .digest_spec import get_test_input_sample, get_test_output_sample 

76from .sample import Sample 

77 

# Executable used for all conda invocations; Windows uses the `conda.bat` entry point.
CONDA_CMD = "conda" if platform.system() != "Windows" else "conda.bat"

79 

80 

class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments still accepted by the test functions.

    All keys are optional. They are forwarded down to `_get_tolerance`, which
    only reads them for model descriptions that are not `v0_5.ModelDescr`.
    """

    # absolute tolerance for output comparison (`_get_tolerance` default: 1e-5)
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for output comparison (`_get_tolerance` default: 1e-3)
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy: if not None, `_get_tolerance` warns and uses
    # atol = 1.5 * 10**-decimal with rtol = 0 instead of the two values above
    decimal: NotRequired[Optional[int]]

85 

86 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- set seeds and enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported and
            configured based on these weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            Default (`None`): configure all installed frameworks.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Frameworks that are not installed are skipped silently; any other error
          raised while configuring a framework is logged at debug level and ignored.
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                # note: in "seed_only" mode this explicitly switches deterministic
                # algorithms off again
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            # set before importing tensorflow; presumably disables oneDNN custom
            # ops whose floating point results may vary between runs — TODO confirm
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))

170 

171 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` that asserts the resource is a
    model (`expected_type="model"`) and runs in the currently active Python
    environment (the default `runtime_env` of `test_description`).

    Args:
        source: A loaded model description or a source to load one from.
        weight_format: Weight format to test. Default: all present weight formats.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            (Any sequence is accepted, e.g. a list or tuple.)
        determinism: Modes to improve reproducibility of test outputs.
        sha256: Expected SHA256 value of **source**.
        stop_early: Do not run further subtests after a failed one.
        **deprecated: Deprecated tolerance arguments, see `DeprecatedKwargs`.

    Returns:
        The validation summary of the tested model description.
    """
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )

193 

194 

def default_run_command(args: Sequence[str]):
    """Execute a terminal command in a subprocess and wait for it to finish.

    This is the default **run_command** used by `test_description`.

    Args:
        args: The command and its arguments, e.g. `["conda", "--version"]`.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero code.
        OSError: If the executable cannot be started (e.g. FileNotFoundError).
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    _ = subprocess.run(args, check=True)

198 

199 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: Resource description source: a loaded resource description,
            a file source to load it from, or raw (YAML) content as a dict.
        format_version: Format version to validate and test with.
            Only used with **runtime_env** `"currently-active"`.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object
            or raw content given as a dict.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.

    Returns:
        The validation summary of the (possibly invalid) resource description.

    Raises:
        RuntimeError: If **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # sanity check: failures of the real (conda) commands are only noticed if
    # run_command raises for a failing command
    try:
        run_command(["thiscommandshouldalwaysfail", "please"])
    except Exception:
        pass
    else:
        raise RuntimeError(
            "given run_command does not raise an exception for a failing command"
        )

    # `ignore_cleanup_errors` is only available from Python 3.10 on
    td_kwargs: Dict[str, Any] = (
        dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
    )
    with TemporaryDirectory(**td_kwargs) as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            # check the expected hash here as well, as documented for **sha256**
            descr = load_description(source, sha256=sha256, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # in-memory sources need to be written to disk for the `bioimageio` CLI
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )

    return descr.validation_summary

314 

315 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    - Model without **weight_format**: calls itself once per present weights
      format, skipping 'tensorflow_js' (not testable by bioimageio.core).
    - Model with **weight_format** and no **conda_env**: the environment is
      derived from that weights entry via `get_conda_env`.
    - Non-model without **conda_env**: warns and tests nothing.

    The conda environment is named after the SHA-256 hex digest of its YAML dump
    and is only created if `conda run -n <name> python --version` fails. Then
    `bioimageio test` is run in it and those details of the resulting summary
    whose `loc` starts with the tested location are added to **descr**'s summary.
    **conda_env** itself is not modified.

    Note: **weight_format** (other than for filtering details), **devices**,
    **sha256** and **deprecated** are currently not forwarded to the command.

    Raises:
        RuntimeError: If conda is not available or 'tensorflow_js' is requested.
        ValueError: If the conda environment cannot be dumped to valid YAML.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # not supported by bioimageio.core; testing it would raise below
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value;
    # work on a copy to not mutate a caller-provided environment description
    conda_env = conda_env.model_copy()
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # reuse the environment if it already exists
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd = []
    cmd_error = None
    # try both CLI spellings of the summary argument (presumably for
    # compatibility with different bioimageio.core versions in the env)
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                loc=test_loc,
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=test_loc,
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc:
            descr.validation_summary.add_detail(detail)

505 

506 

# Typing-only overloads of `load_description_and_test`: the declared return type
# is narrowed by `format_version` ("latest" vs. anything else) and by
# `expected_type` ("model", "dataset" or unspecified).


# format_version="latest", expected_type="model" -> latest model description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest", expected_type="dataset" -> latest dataset description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest", any/no expected_type -> any latest resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# other/discovered format_version, expected_type="model" -> model description of any format version
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# other/discovered format_version, expected_type="dataset" -> dataset description of any format version
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# fallback: any format_version and expected_type -> any resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

595 

596 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    A loaded description is re-validated from its dumped content unless it
    already has the requested format version (or **format_version** is
    "discover") and was validated with io checks. For model descriptions the
    requested (or all present) weight formats are tested after seeding via
    `enable_determinism`.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if not isinstance(source, ResourceDescrBase):
        root = Path()
        file_name = None
    else:
        root = source.root
        file_name = source.file_name
        loaded_version = source.format_version
        compatible_versions = (
            DISCOVER,
            loaded_version,
            ".".join(loaded_version.split(".")[:2]),
        )
        must_reload = format_version not in compatible_versions
        if not must_reload:
            # only inspect the validation context if the format version matches
            first_context = source.validation_summary.details[0].context
            must_reload = first_context is None or not first_context.perform_io_checks

        if must_reload:
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        io_context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(perform_io_checks=True)  # make sure we perform io checks though
        rd = build_description(
            source, format_version=format_version, context=io_context
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if not isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        # TODO: add execution of jupyter notebooks
        # TODO: add more tests
        return rd

    if weight_format is not None:
        weight_formats: List[SupportedWeightsFormat] = [weight_format]
    else:
        weight_formats = [
            w for w, we in rd.weights if we is not None
        ]  # pyright: ignore[reportAssignmentType]

    enable_determinism(determinism, weight_formats=weight_formats)

    def should_stop() -> bool:
        # abort remaining subtests once anything failed (if requested)
        return stop_early and rd.validation_summary.status != "passed"

    for w in weight_formats:
        _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
        if should_stop():
            break

        if isinstance(rd, v0_4.ModelDescr):
            continue

        _test_model_inference_parametrized(rd, w, devices, stop_early=stop_early)
        if should_stop():
            break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests
    return rd

694 

695 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output **m** computed with weights **wf**.

    - v0.5 models: start from `v0_5.ReproducibilityTolerance()` defaults, replaced by
      legacy `config.bioimageio.test_kwargs[wf]` if present (defaults there:
      rtol 1e-3, atol 1e-4), and finally by the first entry of
      `config.bioimageio.reproducibility_tolerance` matching **wf** and **m**
      (an empty `weights_formats`/`output_ids` list matches anything).
    - Other models with deprecated `decimal` given: warn and use
      atol = 1.5 * 10**-decimal, rtol = 0.
    - Other models otherwise: deprecated `absolute_tolerance` (default 1e-5) and
      `relative_tolerance` (default 1e-3).
      No mismatched elements are tolerated for non-v0.5 models.

    Returns:
        (relative tolerance, absolute tolerance, mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

747 

748 

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Run **model** on its test inputs and compare the results to its test outputs.

    Adds one `ValidationDetail` (loc `("weights", weight_format)`) to
    `model.validation_summary`. Outputs with more mismatching elements (in ppm)
    than tolerated are reported as errors, smaller mismatches as warnings.
    Actual outputs of mismatching tensors are saved to
    `actual_output_{member id}_{weight_format}.npy` relative to the current
    working directory. Exceptions are recorded as errors, unless the validation
    context has `raise_errors` set, in which case they are re-raised.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            for m, expected_tensor in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected_tensor.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected_tensor.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected_tensor.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected_tensor.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    expected_np = expected_tensor.data.to_numpy().astype(np.float32)
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    # element-wise tolerance: atol + rtol * |expected|
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()
                    if not mismatched_elements:
                        # all elements within tolerance (skips the else clause below)
                        continue

                    actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
                    try:
                        save_tensor(actual_output_path, actual)
                    except Exception as e:
                        logger.error(
                            "Failed to save actual output tensor to {}: {}",
                            actual_output_path,
                            e,
                        )

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    # keep the cause of the failed comparison in the report
                    msg = f"Output '{m}' disagrees with expected values: {e}"
                    add_error_entry(msg, with_traceback=True)
                    if stop_early:
                        break
                else:
                    msg = (
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected_np.size} expected values"
                        + f" ({mismatched_ppm:.1f} ppm)."
                        + f"\n Max relative difference: {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                        + f"\n Saved actual output to {actual_output_path}."
                    )
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        # mismatch within the tolerated ppm: report, but pass
                        add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )

907 

908 

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Test inference with the test input resized to various valid input shapes.

    Test cases are all combinations of
    - batch sizes: the fixed sizes of the inputs' batch axes; ``{1, 2}`` if
      all batch axes have an arbitrary size; ``{1}`` if there is no batch axis.
    - size parameters ``n``: ``{0, 1, 2}`` if any input axis has a
      `v0_5.ParameterizedSize`, ``{0}`` otherwise.
    Test cases that resolve to an already tested set of input sizes are skipped.

    For each test case one `ValidationDetail` is added to
    ``model.validation_summary``. It fails if inference raises, if the number of
    outputs differs from the number of ``model.outputs``, if an output is None,
    or if an output's sizes do not match the sizes computed by
    ``model.get_axis_sizes`` (including output axes without an expected size).

    If an exception escapes the test loop (e.g. while loading the test input,
    creating the prediction pipeline or resizing the inputs), a single failed
    detail is added instead. If the validation context has ``raise_errors`` set,
    the exception is re-raised instead.

    Args:
        model: the model description to test.
        weight_format: weights format to run inference with.
        devices: devices passed on to `create_prediction_pipeline`.
        stop_early: stop testing after the first failed test case.
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def get_ns(n: int):
        # the same size parameter n is applied to every parameterized input axis
        return {
            (t.id, a.id): n
            for t in model.inputs
            for a in t.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        }

    def generate_test_cases(test_input: Sample):
        """Yield (n, batch_size, resized inputs, expected output sizes) per unique input shape."""
        tested: Set[Hashable] = set()
        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                # different (batch_size, n) combinations may yield identical inputs
                continue

            tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: test_input.members[t.id].resize_to(
                        {
                            aid: s
                            for (tid, aid), s in input_target_sizes.items()
                            if tid == t.id
                        },
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for (
                n,
                batch_size,
                inputs,
                expected_output_shapes,
            ) in generate_test_cases(test_input):
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(expected_output_shapes):
                        error = (
                            f"Expected {len(expected_output_shapes)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in expected_output_shapes.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            diff: Dict[AxisId, int] = {}
                            for a, s in res.sizes.items():
                                expected_size = exp.get(AxisId(a))
                                if expected_size is None:
                                    # unexpected output axis: report it as a shape
                                    # mismatch instead of aborting all test cases
                                    # with a KeyError
                                    diff[AxisId(a)] = s
                                elif isinstance(expected_size, int):
                                    if s != expected_size:
                                        diff[AxisId(a)] = s
                                elif s < expected_size.min or (
                                    expected_size.max is not None
                                    and s > expected_size.max
                                ):
                                    # non-int expected sizes are bounded by .min/.max
                                    diff[AxisId(a)] = s
                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1084 

1085 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> None:
    """Record whether `rd` is of the expected resource type.

    Adds a "Has expected resource type" detail to ``rd.validation_summary``.
    The detail passes if ``rd.type == expected_type``. Otherwise it fails with
    an error entry at location ``("type",)``.

    The detail is added via ``validation_summary.add_detail``, as all other
    details in this module are, rather than by appending to
    ``validation_summary.details`` directly. This way any status bookkeeping
    done by ``add_detail`` also applies to this check.

    Args:
        rd: the loaded (possibly invalid) resource description.
        expected_type: the expected value of ``rd.type``.
    """
    has_expected_type = rd.type == expected_type
    rd.validation_summary.add_detail(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )

1108 

1109 

1110# TODO: Implement `debug_model()` 

1111# def debug_model( 

1112# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1113# *, 

1114# weight_format: Optional[WeightsFormat] = None, 

1115# devices: Optional[List[str]] = None, 

1116# ): 

1117# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1118 

1119# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1120# """ 

1121# inputs_raw: Optional = None 

1122# inputs_processed: Optional = None 

1123# outputs_raw: Optional = None 

1124# outputs: Optional = None 

1125# expected: Optional = None 

1126# diff: Optional = None 

1127 

1128# model = load_description( 

1129# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1130# ) 

1131# if not isinstance(model, Model): 

1132# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1133 

1134# prediction_pipeline = create_prediction_pipeline( 

1135# bioimageio_model=model, devices=devices, weight_format=weight_format 

1136# ) 

1137# inputs = [ 

1138# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1139# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1140# ] 

1141# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1142 

1143# # keep track of the non-processed inputs 

1144# inputs_raw = [deepcopy(input) for input in inputs] 

1145 

1146# computed_measures = {} 

1147 

1148# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1149# inputs_processed = list(input_dict.values()) 

1150# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1151# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1152# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1153# outputs = list(output_dict.values()) 

1154 

1155# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1156# outputs = [outputs] 

1157 

1158# expected = [ 

1159# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1160# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1161# ] 

1162# if len(outputs) != len(expected): 

1163# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1164# print(error) 

1165# else: 

1166# diff = [] 

1167# for res, exp in zip(outputs, expected): 

1168# diff.append(res - exp) 

1169 

1170# return { 

1171# "inputs": inputs_raw, 

1172# "inputs_processed": inputs_processed, 

1173# "outputs_raw": outputs_raw, 

1174# "outputs": outputs, 

1175# "expected": expected, 

1176# "diff": diff, 

1177# }