Coverage for src/bioimageio/core/_resource_tests.py: 54%

351 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-22 09:21 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import sys 

6import warnings 

7from io import StringIO 

8from itertools import product 

9from pathlib import Path 

10from tempfile import TemporaryDirectory 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Hashable, 

16 List, 

17 Literal, 

18 Optional, 

19 Sequence, 

20 Set, 

21 Tuple, 

22 Union, 

23 overload, 

24) 

25 

26import numpy as np 

27from bioimageio.spec import ( 

28 AnyDatasetDescr, 

29 AnyModelDescr, 

30 BioimageioCondaEnv, 

31 DatasetDescr, 

32 InvalidDescr, 

33 LatestResourceDescr, 

34 ModelDescr, 

35 ResourceDescr, 

36 ValidationContext, 

37 build_description, 

38 dump_description, 

39 get_conda_env, 

40 load_description, 

41 save_bioimageio_package, 

42) 

43from bioimageio.spec._description_impl import DISCOVER 

44from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

45from bioimageio.spec._internal.io import is_yaml_value 

46from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

47from bioimageio.spec._internal.types import ( 

48 AbsoluteTolerance, 

49 FormatVersionPlaceholder, 

50 MismatchedElementsPerMillion, 

51 RelativeTolerance, 

52) 

53from bioimageio.spec._internal.validation_context import get_validation_context 

54from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

55from bioimageio.spec.model import v0_4, v0_5 

56from bioimageio.spec.model.v0_5 import WeightsFormat 

57from bioimageio.spec.summary import ( 

58 ErrorEntry, 

59 InstalledPackage, 

60 ValidationDetail, 

61 ValidationSummary, 

62 WarningEntry, 

63) 

64from loguru import logger 

65from numpy.typing import NDArray 

66from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

67 

68from bioimageio.core import __version__ 

69 

70from ._prediction_pipeline import create_prediction_pipeline 

71from .axis import AxisId, BatchSize 

72from .common import MemberId, SupportedWeightsFormat 

73from .digest_spec import get_test_input_sample, get_test_output_sample 

74from .sample import Sample 

75 

76 

class DeprecatedKwargs(TypedDict):
    """Legacy keyword arguments still accepted by the test entry points.

    Every key is optional. The values are only read by `_get_tolerance`
    for models that are not `v0_5.ModelDescr` instances.
    """

    # absolute tolerance; `_get_tolerance` falls back to 1e-5 if missing
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance; `_get_tolerance` falls back to 1e-3 if missing
    relative_tolerance: NotRequired[RelativeTolerance]
    # if not None, takes precedence over both tolerances above and results in
    # an absolute tolerance of 1.5 * 10 ** (-decimal) (relative tolerance 0)
    decimal: NotRequired[Optional[int]]

81 

82 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed random generators and configure ML frameworks for reproducibility.

    May degrade performance. Only recommended for testing reproducibility!

    Every framework is configured on a best-effort basis: a framework that is
    not installed is skipped silently and any other failure is only logged at
    debug level.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally request deterministic algorithms from the
              ML frameworks (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported
            based on the given weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            `None` means no restriction.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """

    def requested(*formats: str) -> bool:
        """Return True if no restriction was given or any of **formats** is listed."""
        return weight_formats is None or any(f in weight_formats for f in formats)

    def log_failure(error: Exception) -> None:
        # configuring a framework is best effort; never fail because of it
        logger.debug(str(error))

    # NumPy is seeded regardless of the requested weight formats
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        log_failure(e)

    if requested("pytorch_state_dict", "torchscript"):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                # also switches deterministic algorithms off again for 'seed_only'
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            log_failure(e)

    if requested("tensorflow_saved_model_bundle", "keras_hdf5"):
        try:
            # set before tensorflow is imported
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            log_failure(e)

    if requested("keras_hdf5"):
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            log_failure(e)

166 

167 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` that fixes
    `expected_type="model"` and forwards all other arguments unchanged.

    Returns:
        The validation summary produced by `test_description`.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
    return summary

189 

190 

def default_run_command(args: Sequence[str]):
    """Run a command through the system shell and wait for it to finish.

    Args:
        args: The program to run followed by its arguments (one item each).

    Raises:
        subprocess.CalledProcessError: If the command returns a non-zero exit code.
    """
    import shlex  # only needed here

    logger.info("running '{}'...", " ".join(args))
    if platform.system() == "Windows":
        # `subprocess` itself joins a sequence into one command line on Windows
        command: Union[str, Sequence[str]] = args
    else:
        # On POSIX, `shell=True` with a sequence executes only the first item;
        # the remaining items become arguments of the shell itself and are
        # silently lost for the command. Pass one properly quoted string instead.
        command = shlex.join(args)

    _ = subprocess.run(command, shell=True, text=True, check=True)

194 

195 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version passed on to `load_description_and_test`.
            Only forwarded if **runtime_env** is `"currently-active"`.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.

    Returns:
        The validation summary of the tested resource.
    """
    # simple case: test right here in the running interpreter
    if runtime_env == "currently-active":
        tested = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return tested.validation_summary

    # otherwise determine the conda environment to test in
    # (`None` lets `_test_in_env` derive it from the weights description)
    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    elif isinstance(runtime_env, (str, Path)):
        env_file = Path(runtime_env)
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(env_file))
    else:
        assert_never(runtime_env)

    # `ignore_cleanup_errors` only exists for Python >= 3.10
    td_kwargs: Dict[str, Any] = {}
    if sys.version_info >= (3, 10):
        td_kwargs["ignore_cleanup_errors"] = True

    with TemporaryDirectory(**td_kwargs) as _d:
        working_dir = Path(_d)

        # an in-memory description has to be written to disk
        # to be accessible from the other environment
        if isinstance(source, (dict, ResourceDescrBase)):
            package_path = working_dir / "package.zip"
            file_source = save_bioimageio_package(source, output_path=package_path)
        else:
            file_source = source

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )

284 

285 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a model by calling `bioimageio test` inside a conda environment.

    Args:
        source: Model description source; has to be loadable from a file/URL
            as it is handed to the `bioimageio` command line interface.
        working_dir: Directory for intermediate files (`env.yaml`, `summary.json`).
            With **weight_format** `None` one sub-directory per weight format is used.
        weight_format: The weight format to test.
            `None` tests all present weight formats (except 'tensorflow_js')
            and merges the details of their summaries.
        conda_env: The conda environment to test in.
            `None` derives it from the weights entry with `get_conda_env`.
            Note: The `name` of a given environment is reset to `None`.
        devices: Only passed on to recursive calls (not to the test command).
        determinism: Passed to the test command as `--determinism`.
        run_command: Function executing a command; expected to raise on failure.
        stop_early: Adds `--stop-early` to the test command.
        expected_type: Passed to the test command as `--expected-type` if set.
        sha256: Only passed on to recursive calls (not to the test command).
        **deprecated: Only passed on to recursive calls.

    Returns:
        The summary written by the test command, or a failed summary if the
        conda environment could not be created or no summary file was produced.

    Raises:
        NotImplementedError: If **source** does not describe a model.
        RuntimeError: If conda is not available or no testable weight format
            was found/requested ('tensorflow_js' cannot be tested).
    """
    descr = load_description(source)

    if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise NotImplementedError("Not yet implemented for non-model resources")

    if weight_format is None:
        all_present_wfs = [
            wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
        ]
        ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
        logger.info(
            "Found weight formats {}. Start testing all{}...",
            all_present_wfs,
            f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
        )
        # actually skip the ignored formats as announced in the log message;
        # testing them would abort the whole run with a RuntimeError
        testable_wfs = [wf for wf in all_present_wfs if wf not in ignore_wfs]
        if not testable_wfs:
            raise RuntimeError(
                "Found no weight format that can be tested by bioimageio.core"
                + f" (present weight formats: {all_present_wfs})"
            )

        # test each weight format in its own sub-directory
        summaries = [
            _test_in_env(
                source,
                working_dir=working_dir / wf,
                weight_format=wf,
                devices=devices,
                determinism=determinism,
                conda_env=conda_env,
                run_command=run_command,
                expected_type=expected_type,
                sha256=sha256,
                stop_early=stop_early,
                **deprecated,
            )
            for wf in testable_wfs
        ]
        summary = summaries[0]
        for additional_summary in summaries[1:]:
            for d in additional_summary.details:
                # TODO: filter redundant details; group details
                summary.add_detail(d)

        return summary

    if weight_format == "pytorch_state_dict":
        wf = descr.weights.pytorch_state_dict
    elif weight_format == "torchscript":
        wf = descr.weights.torchscript
    elif weight_format == "keras_hdf5":
        wf = descr.weights.keras_hdf5
    elif weight_format == "onnx":
        wf = descr.weights.onnx
    elif weight_format == "tensorflow_saved_model_bundle":
        wf = descr.weights.tensorflow_saved_model_bundle
    elif weight_format == "tensorflow_js":
        raise RuntimeError(
            "testing 'tensorflow_js' is not supported by bioimageio.core"
        )
    else:
        assert_never(weight_format)

    assert wf is not None
    if conda_env is None:
        conda_env = get_conda_env(entry=wf)

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    # identical environment descriptions share one (hash-named) conda env
    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", "conda"])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    # The working directory may not exist yet (e.g. the per weight format
    # sub-directories created above) and is needed for `env.yaml` below.
    working_dir.mkdir(parents=True, exist_ok=True)

    try:
        # failing to activate is taken as "environment does not exist yet"
        run_command(["conda", "activate", env_name])
    except Exception:
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                ["conda", "env", "create", f"--file={path}", f"--name={env_name}"]
            )
            run_command(["conda", "activate", env_name])
        except Exception as e:
            summary = descr.validation_summary
            summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=("weights", weight_format),
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return summary

    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error = None
    # NOTE(review): two spellings of the summary option are tried in turn,
    # presumably to support different versions of the bioimageio CLI -- confirm.
    for summary_path_arg_name in ("summary", "summary-path"):
        cmd = [
            "conda",
            "run",
            "-n",
            env_name,
            "bioimageio",
            "test",
            str(source),
            f"--{summary_path_arg_name}={summary_path.as_posix()}",
            f"--determinism={determinism}",
        ]
        if expected_type:
            cmd.append(f"--expected-type={expected_type}")
        if stop_early:
            cmd.append("--stop-early")

        try:
            run_command(cmd)
        except Exception as e:
            cmd_error = f"Failed to run command '{' '.join(cmd)}': {e}."

        # a written summary counts as success even if the command itself failed
        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        return ValidationSummary(
            name="calling bioimageio test command",
            source_name=str(source),
            status="failed",
            type="unknown",
            format_version="unknown",
            details=[
                ValidationDetail(
                    name="run 'bioimageio test'",
                    errors=[
                        ErrorEntry(
                            loc=(),
                            type="bioimageio cli",
                            msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                        )
                    ],
                    status="failed",
                )
            ],
            env=set(),
        )

    return ValidationSummary.load_json(summary_path)

472 

473 

# Typing overloads of `load_description_and_test`:
# the (static) return type is narrowed by `format_version` and `expected_type`.


# format_version="latest" + expected_type="model" -> latest model description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest" + expected_type="dataset" -> latest dataset description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest" + any/no expected type -> any latest resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# any format version + expected_type="model" -> model description of any version
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# any format version + expected_type="dataset" -> dataset description of any version
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# most general case: any format version and any/no expected type
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...

562 

563 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Load a bioimage.io resource description and test it dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    root = Path()
    file_name = None
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name

        # An already loaded description is only reused if it matches the
        # requested format version and was validated with io checks.
        accepted_versions = (
            DISCOVER,
            source.format_version,
            ".".join(source.format_version.split(".")[:2]),
        )
        revalidate = format_version not in accepted_versions
        if not revalidate:
            c = source.validation_summary.details[0].context
            revalidate = c is None or not c.perform_io_checks

        if revalidate:
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        base_context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        )
        # make sure we perform io checks though
        context = base_context.replace(perform_io_checks=True)

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    # record which bioimageio.core version produced the test results
    core_package = InstalledPackage(name="bioimageio.core", version=__version__)
    rd.validation_summary.env.add(core_package)

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is not None:
            weight_formats = [weight_format]
        else:
            # test every weight format that has an entry
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]

        def abort() -> bool:
            """Whether to skip all remaining (sub)tests."""
            return stop_early and rd.validation_summary.status == "failed"

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if abort():
                break

            # parametrized inference tests are not run for v0_4 models
            if isinstance(rd, v0_4.ModelDescr):
                continue

            _test_model_inference_parametrized(rd, w, devices, stop_early=stop_early)
            if abort():
                break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    # nothing failed => upgrade a merely valid format to passed
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    return rd

664 

665 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing an actual with an expected output.

    For `v0_5` models the tolerances are taken from the model description
    (in increasing priority):
    the defaults of `v0_5.ReproducibilityTolerance`,
    legacy `test_kwargs` for the weight format **wf** and
    the first matching entry in `config.bioimageio.reproducibility_tolerance`.
    The **deprecated** keyword arguments are not considered for `v0_5` models.

    For other models the **deprecated** keyword arguments are used, where
    `decimal` (if not None) takes precedence over the given tolerances.

    Args:
        model: The model description.
        wf: The weight format under test.
        m: The id of the output tensor to compare.
        **deprecated: Legacy tolerance arguments, see `DeprecatedKwargs`.

    Returns:
        relative tolerance, absolute tolerance and
        the tolerated number of mismatched elements per million
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        # (an entry without weights formats/output ids applies to all of them)
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        # fixed message: closing backtick was missing and
        # the URL was broken by a space
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

717 

718 

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Check that the model reproduces its test outputs from its test inputs.

    The result is added as a `ValidationDetail` to `model.validation_summary`.
    An output disagrees with its expected values if the share of mismatched
    elements exceeds the tolerated mismatched elements per million
    (see `_get_tolerance`); mismatches below that only produce a warning.

    Args:
        model: The model description to test.
        weight_format: The weight format to run inference with.
        devices: Devices passed on to `create_prediction_pipeline`.
        stop_early: Stop comparing further outputs after the first error.
        **deprecated: Legacy tolerance arguments, passed on to `_get_tolerance`.

    Raises:
        Exception: Errors are re-raised (instead of being added to the summary)
            if the active validation context has `raise_errors` set.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            for m, expected_tensor in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                dims = expected_tensor.dims
                if actual.dims != dims:
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected_tensor.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected_tensor.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected_tensor.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                # compare as float32 and drop the tensor references early
                expected_np = expected_tensor.data.to_numpy().astype(np.float32)
                del expected_tensor
                actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
                del actual

                rtol, atol, mismatched_tol = _get_tolerance(
                    model, wf=weight_format, m=m, **deprecated
                )
                rtol_value = rtol * abs(expected_np)
                abs_diff = abs(actual_np - expected_np)
                # `np.isclose` checks |actual - expected| <= atol + rtol * |expected|
                # like a plain comparison of `abs_diff` would, but it does not
                # let NaN (or differently signed inf) values pass unnoticed:
                # comparing NaN with `>` is always False. NaNs at the same
                # position in both arrays are treated as equal.
                mismatched = ~np.isclose(
                    actual_np, expected_np, rtol=rtol, atol=atol, equal_nan=True
                )
                mismatched_elements = mismatched.sum().item()
                if not mismatched_elements:
                    continue

                mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                # (`argmax` returns the position of a NaN if there is any,
                # i.e. a mismatching NaN is reported as the maximum difference)
                r_max_idx_flat = (
                    r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                ).argmax()
                r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                r_max = r_diff[r_max_idx].item()
                r_actual = actual_np[r_max_idx].item()
                r_expected = expected_np[r_max_idx].item()

                # Calculate the max absolute difference with the relative tolerance subtracted
                abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                a_max_idx = np.unravel_index(
                    abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                )

                a_max = abs_diff[a_max_idx].item()
                a_actual = actual_np[a_max_idx].item()
                a_expected = expected_np[a_max_idx].item()

                msg = (
                    f"Output '{m}' disagrees with {mismatched_elements} of"
                    + f" {expected_np.size} expected values"
                    + f" ({mismatched_ppm:.1f} ppm)."
                    + f"\n Max relative difference: {r_max:.2e}"
                    + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                    + f" at {dict(zip(dims, r_max_idx))}"
                    + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
                    + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                )
                if mismatched_ppm > mismatched_tol:
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    # tolerated share of mismatches: only warn
                    add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )

861 

862 

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference for resized test inputs and check the output shapes.

    The test input is resized for combinations of batch size and size
    parameter `n`; combinations resulting in identical input sizes are only
    tested once. For each tested combination a `ValidationDetail` is added to
    `model.validation_summary`. Only output shapes are checked, not values.

    Args:
        model: The model description to test.
        weight_format: The weight format to run inference with.
        devices: Devices passed on to `create_prediction_pipeline`.
        stop_early: Skip the remaining combinations after the first failed one.

    Raises:
        Exception: Errors are re-raised (instead of being added to the summary)
            if the active validation context has `raise_errors` set.
    """
    has_parameterized_size = any(
        isinstance(axis.size, v0_5.ParameterizedSize)
        for tensor in model.inputs
        for axis in tensor.axes
    )
    if has_parameterized_size:
        ns: Set[v0_5.ParameterizedSize_N] = {0, 1, 2}
    else:
        # no parameterized sizes => set n=0
        ns = {0}

    given_batch_sizes = {
        axis.size
        for tensor in model.inputs
        for axis in tensor.axes
        if isinstance(axis, v0_5.BatchAxis)
    }
    if not given_batch_sizes:
        # no batch axis
        batch_sizes = {1}
    else:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = set(
        product(sorted(batch_sizes), sorted(ns))
    )
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        """Yield n, batch size, resized test input and expected output shapes."""
        already_tested: Set[Hashable] = set()
        # (tensor id, axis id) of all input axes with a parameterized size
        parameterized_axes = [
            (t.id, a.id)
            for t in model.inputs
            for a in t.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        ]

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                {key: n for key in parameterized_axes}, batch_size=batch_size
            )

            # skip combinations that result in already tested input sizes
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in already_tested:
                continue

            already_tested.add(hashable_target_size)

            resized_members = {}
            for t in model.inputs:
                target_sizes = {
                    aid: s
                    for (tid, aid), s in input_target_sizes.items()
                    if tid == t.id
                }
                resized_members[t.id] = test_input.members[t.id].resize_to(
                    target_sizes
                )

            resized_test_inputs = Sample(
                members=resized_members,
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        # (also used inside `generate_test_cases`)
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_shapes in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_shapes):
                    error = (
                        f"Expected {len(expected_shapes)} outputs,"
                        + f" but got {len(result.members)}"
                    )

                else:
                    for m, exp in expected_shapes.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        # collect all axes whose size is not as expected
                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            e_aid = exp[AxisId(a)]
                            if isinstance(e_aid, int):
                                # exact size expected
                                unexpected = s != e_aid
                            else:
                                # size expected within [min, max] (max may be open)
                                unexpected = s < e_aid.min or (
                                    e_aid.max is not None and s > e_aid.max
                                )

                            if unexpected:
                                diff[AxisId(a)] = s

                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                if error is None:
                    status = "passed"
                    errors = []
                else:
                    status = "failed"
                    errors = [
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=error,
                            type="bioimageio.core",
                        )
                    ]

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status=status,
                        errors=errors,
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

1032 

1033 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record in `rd.validation_summary` whether `rd.type` equals `expected_type`.

    A `ValidationDetail` named "Has expected resource type" is appended to the
    summary's `details` list. On a match the detail is "passed" without errors;
    otherwise it is "failed" and carries a single `ErrorEntry` located at
    `("type",)` that names both the expected and the found type.

    Args:
        rd: The (possibly invalid) resource description to check.
        expected_type: The resource type `rd.type` is compared against.
    """
    type_matches = rd.type == expected_type
    details = rd.validation_summary.details

    if type_matches:
        status = "passed"
        errors: List[ErrorEntry] = []
    else:
        status = "failed"
        mismatch = ErrorEntry(
            loc=("type",),
            type="type",
            msg=f"Expected type {expected_type}, found {rd.type}",
        )
        errors = [mismatch]

    # Appended directly to `details` (not via `add_detail`), as before.
    details.append(
        ValidationDetail(
            name="Has expected resource type",
            status=status,
            loc=("type",),
            errors=errors,
        )
    )

1056 

1057 

1058# TODO: Implement `debug_model()` 

1059# def debug_model( 

1060# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

1061# *, 

1062# weight_format: Optional[WeightsFormat] = None, 

1063# devices: Optional[List[str]] = None, 

1064# ): 

1065# """Run the model test and return dict with inputs, results, expected results and intermediates. 

1066 

1067# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

1068# """ 

1069# inputs_raw: Optional = None 

1070# inputs_processed: Optional = None 

1071# outputs_raw: Optional = None 

1072# outputs: Optional = None 

1073# expected: Optional = None 

1074# diff: Optional = None 

1075 

1076# model = load_description( 

1077# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

1078# ) 

1079# if not isinstance(model, Model): 

1080# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

1081 

1082# prediction_pipeline = create_prediction_pipeline( 

1083# bioimageio_model=model, devices=devices, weight_format=weight_format 

1084# ) 

1085# inputs = [ 

1086# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

1087# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

1088# ] 

1089# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

1090 

1091# # keep track of the non-processed inputs 

1092# inputs_raw = [deepcopy(input) for input in inputs] 

1093 

1094# computed_measures = {} 

1095 

1096# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

1097# inputs_processed = list(input_dict.values()) 

1098# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

1099# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

1100# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

1101# outputs = list(output_dict.values()) 

1102 

1103# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

1104# outputs = [outputs] 

1105 

1106# expected = [ 

1107# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

1108# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

1109# ] 

1110# if len(outputs) != len(expected): 

1111# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

1112# print(error) 

1113# else: 

1114# diff = [] 

1115# for res, exp in zip(outputs, expected): 

1116# diff.append(res - exp) 

1117 

1118# return { 

1119# "inputs": inputs_raw, 

1120# "inputs_processed": inputs_processed, 

1121# "outputs_raw": outputs_raw, 

1122# "outputs": outputs, 

1123# "expected": expected, 

1124# "diff": diff, 

1125# }