Coverage for src/bioimageio/core/_resource_tests.py: 52%

370 statements  

coverage.py v7.10.7, created at 2025-10-14 08:35 +0000

import hashlib
import os
import platform
import subprocess
import sys
import warnings
from io import StringIO
from itertools import product
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Callable,
    Dict,
    Hashable,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    overload,
)

import numpy as np
from bioimageio.spec import (
    AnyDatasetDescr,
    AnyModelDescr,
    BioimageioCondaEnv,
    DatasetDescr,
    InvalidDescr,
    LatestResourceDescr,
    ModelDescr,
    ResourceDescr,
    ValidationContext,
    build_description,
    dump_description,
    get_conda_env,
    load_description,
    save_bioimageio_package,
)
from bioimageio.spec._description_impl import DISCOVER
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
from bioimageio.spec._internal.io import is_yaml_value
from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
from bioimageio.spec._internal.types import (
    AbsoluteTolerance,
    FormatVersionPlaceholder,
    MismatchedElementsPerMillion,
    RelativeTolerance,
)
from bioimageio.spec._internal.validation_context import get_validation_context
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
    ErrorEntry,
    InstalledPackage,
    ValidationDetail,
    ValidationSummary,
    WarningEntry,
)
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args

from bioimageio.core import __version__
from bioimageio.core.io import save_tensor

from ._prediction_pipeline import create_prediction_pipeline
from ._settings import settings
from .axis import AxisId, BatchSize
from .common import MemberId, SupportedWeightsFormat
from .digest_spec import get_test_input_sample, get_test_output_sample
from .sample import Sample

CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda"


class DeprecatedKwargs(TypedDict):
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    relative_tolerance: NotRequired[RelativeTolerance]
    decimal: NotRequired[Optional[int]]

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable framework determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported
            based on **weight_formats**.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                    # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))
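# Illustrative usage sketch (not part of the original module), assuming both functions
# are re-exported at the bioimageio.core package level; the weight format name is only
# an example:
#
#     from bioimageio.core import enable_determinism, test_model
#
#     enable_determinism("seed_only", weight_formats=["pytorch_state_dict"])
#     summary = test_model("path/to/rdf.yaml", weight_format="pytorch_state_dict")
#     print(summary.status)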

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference"""
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
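# Illustrative usage sketch (not part of the original module): testing a packaged model
# and inspecting the resulting ValidationSummary; the file path is a placeholder.
#
#     summary = test_model("my_model_package.zip", weight_format="onnx")
#     print(summary.status)
#     for detail in summary.details:
#         print(detail.status, detail.name)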

def default_run_command(args: Sequence[str]):
    logger.info("running '{}'...", " ".join(args))
    _ = subprocess.check_call(args)

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction on test tensors for models.

    Args:
        source: model description source.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in:
            - `"currently-active"`: Use the active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    try:
        run_command(["thiscommandshouldalwaysfail", "please"])
    except Exception:
        pass
    else:
        raise RuntimeError(
            "given run_command does not raise an exception for a failing command"
        )

    td_kwargs: Dict[str, Any] = (
        dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
    )
    with TemporaryDirectory(**td_kwargs) as _d:
        working_dir = Path(_d)
        if isinstance(source, (dict, ResourceDescrBase)):
            file_source = save_bioimageio_package(
                source, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )
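# Illustrative usage sketch (not part of the original module): running the tests in a
# separate conda environment derived from the model description; the source path is a
# placeholder.
#
#     summary = test_description(
#         "my_model_package.zip",
#         weight_format="torchscript",
#         runtime_env="as-described",  # build a conda env from the weights entry
#     )
#     print(summary.status)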

def _test_in_env(
    source: PermissiveFileSource,
    *,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    descr = load_description(source)

    if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise NotImplementedError("Not yet implemented for non-model resources")

    if weight_format is None:
        all_present_wfs = [
            wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
        ]
        ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
        logger.info(
            "Found weight formats {}. Start testing all{}...",
            all_present_wfs,
            f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
        )
        summary = _test_in_env(
            source,
            working_dir=working_dir / all_present_wfs[0],
            weight_format=all_present_wfs[0],
            devices=devices,
            determinism=determinism,
            conda_env=conda_env,
            run_command=run_command,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        for wf in all_present_wfs[1:]:
            additional_summary = _test_in_env(
                source,
                working_dir=working_dir / wf,
                weight_format=wf,
                devices=devices,
                determinism=determinism,
                conda_env=conda_env,
                run_command=run_command,
                expected_type=expected_type,
                sha256=sha256,
                stop_early=stop_early,
                **deprecated,
            )
            for d in additional_summary.details:
                # TODO: filter redundant details; group details
                summary.add_detail(d)
        return summary

    if weight_format == "pytorch_state_dict":
        wf = descr.weights.pytorch_state_dict
    elif weight_format == "torchscript":
        wf = descr.weights.torchscript
    elif weight_format == "keras_hdf5":
        wf = descr.weights.keras_hdf5
    elif weight_format == "onnx":
        wf = descr.weights.onnx
    elif weight_format == "tensorflow_saved_model_bundle":
        wf = descr.weights.tensorflow_saved_model_bundle
    elif weight_format == "tensorflow_js":
        raise RuntimeError(
            "testing 'tensorflow_js' is not supported by bioimageio.core"
        )
    else:
        assert_never(weight_format)

    assert wf is not None
    if conda_env is None:
        conda_env = get_conda_env(entry=wf)

    # remove the name as we create a name based on the hash of the env description
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception as e:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("wrote conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that the environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            summary = descr.validation_summary
            summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=("weights", weight_format),
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return summary

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd = []
    cmd_error = None
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Failed to run command '{' '.join(cmd)}': {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        return ValidationSummary(
            name="calling bioimageio test command",
            source_name=str(source),
            status="failed",
            type="unknown",
            format_version="unknown",
            details=[
                ValidationDetail(
                    name="run 'bioimageio test'",
                    errors=[
                        ErrorEntry(
                            loc=(),
                            type="bioimageio cli",
                            msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                        )
                    ],
                    status="failed",
                )
            ],
            env=set(),
        )

    return ValidationSummary.load_json(summary_path)
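# Illustrative sketch (not part of the original module): the conda calls issued by
# `_test_in_env` roughly correspond to the following shell session, where <hash> stands
# for the SHA-256 of the serialized environment YAML (exact flags depend on the settings
# and arguments above).
#
#     conda env create --yes --file=env.yaml --name=<hash>
#     conda run -n <hash> bioimageio test package.zip --summary=summary.json --determinism=seed_only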

@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction on test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if stop_early and rd.validation_summary.status == "failed":
                break

            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status == "failed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    return rd
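# Illustrative usage sketch (not part of the original module): loading and testing in one
# step when the tested description object itself is needed afterwards; the file path is
# a placeholder.
#
#     rd = load_description_and_test("rdf.yaml", expected_type="model")
#     for detail in rd.validation_summary.details:
#         print(detail.status, detail.name)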

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
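# Illustrative worked example (not part of the original module): with rtol=1e-3 and
# atol=1e-4, an expected value of 2.0 tolerates an absolute deviation of up to
# atol + rtol * |expected| = 1e-4 + 1e-3 * 2.0 = 2.1e-3; larger deviations count as
# mismatched elements and are weighed against `mismatched_elements_per_million`.
#
#     import numpy as np
#     expected = np.array([2.0]); actual = np.array([2.003])
#     mismatched = np.abs(actual - expected) > 1e-4 + 1e-3 * np.abs(expected)  # -> [True]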

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            for m, expected in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    expected_np = expected.data.to_numpy().astype(np.float32)
                    del expected
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()
                    if not mismatched_elements:
                        continue

                    actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
                    try:
                        save_tensor(actual_output_path, actual)
                    except Exception as e:
                        logger.error(
                            "Failed to save actual output tensor to {}: {}",
                            actual_output_path,
                            e,
                        )

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    msg = f"Output '{m}' disagrees with expected values."
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    msg = (
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected_np.size} expected values"
                        + f" ({mismatched_ppm:.1f} ppm)."
                        + f"\n Max relative difference: {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                        + f"\n Saved actual output to {actual_output_path}."
                    )
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
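# Illustrative sketch (not part of the original module): when a mismatch is reported,
# the prediction is saved as `actual_output_<member>_<weight_format>.npy` relative to
# the current working directory and can be inspected manually. File names below are
# placeholders, assuming the tensor was written as a plain .npy array.
#
#     import numpy as np
#     actual = np.load("actual_output_output0_onnx.npy")
#     expected = np.load("test_output_0.npy")
#     print(np.abs(actual - expected).max())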

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shape in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_output_shape):
                    error = (
                        f"Expected {len(expected_output_shape)} outputs,"
                        + f" but got {len(result.members)}"
                    )

                else:
                    for m, exp in expected_output_shape.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            if isinstance((e_aid := exp[AxisId(a)]), int):
                                if s != e_aid:
                                    diff[AxisId(a)] = s
                            elif (
                                s < e_aid.min or e_aid.max is not None and s > e_aid.max
                            ):
                                diff[AxisId(a)] = s
                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
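# Illustrative sketch (not part of the original module): for a model with one
# parameterized axis (e.g. size = min 32 + n * step 16, values here are made up) and an
# arbitrary batch axis, the generated (batch_size, n) test cases are {1, 2} x {0, 1, 2},
# i.e. axis sizes 32, 48 and 64 are each tested with batch sizes 1 and 2; test cases
# that resolve to an already tested input size are skipped.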

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )

# TODO: Implement `debug_model()`
# def debug_model(
#     model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
#     *,
#     weight_format: Optional[WeightsFormat] = None,
#     devices: Optional[List[str]] = None,
# ):
#     """Run the model test and return dict with inputs, results, expected results and intermediates.
#
#     Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
#     """
#     inputs_raw: Optional = None
#     inputs_processed: Optional = None
#     outputs_raw: Optional = None
#     outputs: Optional = None
#     expected: Optional = None
#     diff: Optional = None
#
#     model = load_description(
#         model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
#     )
#     if not isinstance(model, Model):
#         raise ValueError(f"Not a bioimageio.model: {model_rdf}")
#
#     prediction_pipeline = create_prediction_pipeline(
#         bioimageio_model=model, devices=devices, weight_format=weight_format
#     )
#     inputs = [
#         xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
#         for in_path, input_spec in zip(model.test_inputs, model.inputs)
#     ]
#     input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
#
#     # keep track of the non-processed inputs
#     inputs_raw = [deepcopy(input) for input in inputs]
#
#     computed_measures = {}
#
#     prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
#     inputs_processed = list(input_dict.values())
#     outputs_raw = prediction_pipeline.predict(*inputs_processed)
#     output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
#     prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
#     outputs = list(output_dict.values())
#
#     if isinstance(outputs, (np.ndarray, xr.DataArray)):
#         outputs = [outputs]
#
#     expected = [
#         xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
#         for out_path, output_spec in zip(model.test_outputs, model.outputs)
#     ]
#     if len(outputs) != len(expected):
#         error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
#         print(error)
#     else:
#         diff = []
#         for res, exp in zip(outputs, expected):
#             diff.append(res - exp)
#
#     return {
#         "inputs": inputs_raw,
#         "inputs_processed": inputs_processed,
#         "outputs_raw": outputs_raw,
#         "outputs": outputs,
#         "expected": expected,
#         "diff": diff,
#     }