Coverage for bioimageio/core/_resource_tests.py: 58%

330 statements  

coverage.py v7.9.2, created at 2025-07-16 15:20 +0000

import hashlib
import os
import platform
import subprocess
import warnings
from io import StringIO
from itertools import product
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Callable,
    Dict,
    Hashable,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    overload,
)

import xarray as xr
from loguru import logger
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args

from bioimageio.spec import (
    BioimageioCondaEnv,
    InvalidDescr,
    LatestResourceDescr,
    ResourceDescr,
    ValidationContext,
    build_description,
    dump_description,
    get_conda_env,
    load_description,
    save_bioimageio_package,
)
from bioimageio.spec._description_impl import DISCOVER
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
from bioimageio.spec._internal.io import is_yaml_value
from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
from bioimageio.spec._internal.types import (
    AbsoluteTolerance,
    FormatVersionPlaceholder,
    MismatchedElementsPerMillion,
    RelativeTolerance,
)
from bioimageio.spec._internal.validation_context import get_validation_context
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
    ErrorEntry,
    InstalledPackage,
    ValidationDetail,
    ValidationSummary,
    WarningEntry,
)

from ._prediction_pipeline import create_prediction_pipeline
from .axis import AxisId, BatchSize
from .common import MemberId, SupportedWeightsFormat
from .digest_spec import get_test_inputs, get_test_outputs
from .sample import Sample
from .utils import VERSION


class DeprecatedKwargs(TypedDict):
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    relative_tolerance: NotRequired[RelativeTolerance]
    decimal: NotRequired[Optional[int]]


def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally request deterministic algorithms (might degrade
              performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported based on
            the given weight formats.
            E.g. this allows avoiding the tensorflow import when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                    # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))

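# Hedged usage sketch (added for illustration, not part of the original module):
# seed only the frameworks relevant to the weights being tested, so that e.g.
# tensorflow is not imported for a pure PyTorch reproducibility check.
def _example_enable_determinism():  # pragma: no cover - illustrative only
    enable_determinism(mode="seed_only", weight_formats=["pytorch_state_dict"])
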

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference"""
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )

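# Hedged usage sketch (added for illustration; the file path is hypothetical):
# run the reproducibility test for a single weight format and inspect the summary.
def _example_test_model():  # pragma: no cover - illustrative only
    summary = test_model(
        "path/to/my_model/rdf.yaml",  # hypothetical model description source
        weight_format="onnx",
        stop_early=True,
    )
    print(summary.status)  # e.g. "passed" or "failed"
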

def default_run_command(args: Sequence[str]):
    logger.info("running '{}'...", " ".join(args))
    _ = subprocess.run(args, shell=True, text=True, check=True)


def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess
            (ignored if **runtime_env** is `"currently-active"`).
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    with TemporaryDirectory(ignore_cleanup_errors=True) as _d:
        working_dir = Path(_d)
        if isinstance(source, (dict, ResourceDescrBase)):
            file_source = save_bioimageio_package(
                source, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )

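# Hedged usage sketch (added for illustration; the source path is hypothetical):
# test a packaged resource in a fresh conda environment derived from the weights
# description instead of the currently active interpreter.
def _example_test_description_in_env():  # pragma: no cover - illustrative only
    summary = test_description(
        "path/to/package.zip",  # hypothetical packaged resource
        weight_format="pytorch_state_dict",
        runtime_env="as-described",
        stop_early=True,
    )
    print(summary.status)
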

def _test_in_env(
    source: PermissiveFileSource,
    *,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    descr = load_description(source)

    if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise NotImplementedError("Not yet implemented for non-model resources")

    if weight_format is None:
        all_present_wfs = [
            wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
        ]
        ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
        logger.info(
            "Found weight formats {}. Start testing all{}...",
            all_present_wfs,
            f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
        )
        summary = _test_in_env(
            source,
            working_dir=working_dir / all_present_wfs[0],
            weight_format=all_present_wfs[0],
            devices=devices,
            determinism=determinism,
            conda_env=conda_env,
            run_command=run_command,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        for wf in all_present_wfs[1:]:
            additional_summary = _test_in_env(
                source,
                working_dir=working_dir / wf,
                weight_format=wf,
                devices=devices,
                determinism=determinism,
                conda_env=conda_env,
                run_command=run_command,
                expected_type=expected_type,
                sha256=sha256,
                stop_early=stop_early,
                **deprecated,
            )
            for d in additional_summary.details:
                # TODO: filter redundant details; group details
                summary.add_detail(d)
        return summary

    if weight_format == "pytorch_state_dict":
        wf = descr.weights.pytorch_state_dict
    elif weight_format == "torchscript":
        wf = descr.weights.torchscript
    elif weight_format == "keras_hdf5":
        wf = descr.weights.keras_hdf5
    elif weight_format == "onnx":
        wf = descr.weights.onnx
    elif weight_format == "tensorflow_saved_model_bundle":
        wf = descr.weights.tensorflow_saved_model_bundle
    elif weight_format == "tensorflow_js":
        raise RuntimeError(
            "testing 'tensorflow_js' is not supported by bioimageio.core"
        )
    else:
        assert_never(weight_format)

    assert wf is not None
    if conda_env is None:
        conda_env = get_conda_env(entry=wf)

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", "conda"])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    try:
        run_command(["conda", "activate", env_name])
    except Exception:
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                ["conda", "env", "create", f"--file={path}", f"--name={env_name}"]
            )
            run_command(["conda", "activate", env_name])
        except Exception as e:
            summary = descr.validation_summary
            summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=("weights", weight_format),
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return summary

    cmd = []
    for summary_path_arg_name in ("summary", "summary-path"):
        run_command(
            cmd := (
                [
                    "conda",
                    "run",
                    "-n",
                    env_name,
                    "bioimageio",
                    "test",
                    str(source),
                    f"--{summary_path_arg_name}={summary_path.as_posix()}",
                    f"--determinism={determinism}",
                ]
                + ([f"--expected-type={expected_type}"] if expected_type else [])
                + (["--stop-early"] if stop_early else [])
            )
        )
        if summary_path.exists():
            break
    else:
        return ValidationSummary(
            name="calling bioimageio test command",
            source_name=str(source),
            status="failed",
            type="unknown",
            format_version="unknown",
            details=[
                ValidationDetail(
                    name="run 'bioimageio test'",
                    errors=[
                        ErrorEntry(
                            loc=(),
                            type="bioimageio cli",
                            msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                        )
                    ],
                    status="failed",
                )
            ],
            env=set(),
        )

    return ValidationSummary.model_validate_json(summary_path.read_bytes())


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=VERSION)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if stop_early and rd.validation_summary.status == "failed":
                break

            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status == "failed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    return rd

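# Hedged usage sketch (added for illustration; the source path is hypothetical):
# load and test in one go to keep the (possibly invalid) description object around,
# e.g. to inspect individual validation details afterwards.
def _example_load_description_and_test():  # pragma: no cover - illustrative only
    rd = load_description_and_test(
        "path/to/rdf.yaml",  # hypothetical description source
        format_version="latest",
        determinism="seed_only",
    )
    for detail in rd.validation_summary.details:
        print(detail.status, detail.name)
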

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

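# Hedged sketch (added for illustration, not part of the original module): how the
# tolerances returned above are combined into the mismatch criterion used by
# `_test_model_inference` below; the arrays and tolerance values are made up.
def _example_tolerance_check():  # pragma: no cover - illustrative only
    import numpy as np

    expected = np.array([1.0, 2.0, 4.0])
    actual = np.array([1.0, 2.001, 4.5])
    rtol, atol, mismatched_tol = 1e-3, 1e-4, 0.0  # mismatched_tol in ppm
    # an element counts as mismatched if it violates atol + rtol * |expected|
    mismatched = np.abs(actual - expected) > atol + rtol * np.abs(expected)
    mismatched_ppm = mismatched.sum() / expected.size * 1e6
    return mismatched_ppm <= mismatched_tol  # test passes if within tolerance
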

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        inputs = get_test_inputs(model)
        expected = get_test_outputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(inputs)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            for m, expected in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                rtol, atol, mismatched_tol = _get_tolerance(
                    model, wf=weight_format, m=m, **deprecated
                )
                rtol_value = rtol * abs(expected)
                abs_diff = abs(actual - expected)
                mismatched = abs_diff > atol + rtol_value
                mismatched_elements = mismatched.sum().item()
                if not mismatched_elements:
                    continue

                mismatched_ppm = mismatched_elements / expected.size * 1e6
                abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
                r_max = r_diff[r_max_idx].item()
                r_actual = actual[r_max_idx].item()
                r_expected = expected[r_max_idx].item()

                # Calculate the max absolute difference with the relative tolerance subtracted
                abs_diff_wo_rtol: xr.DataArray = xr.ufuncs.maximum(
                    (abs_diff - rtol_value).data, 0
                )
                a_max_idx = {
                    AxisId(k): int(v) for k, v in abs_diff_wo_rtol.argmax().items()
                }

                a_max = abs_diff[a_max_idx].item()
                a_actual = actual[a_max_idx].item()
                a_expected = expected[a_max_idx].item()

                msg = (
                    f"Output '{m}' disagrees with {mismatched_elements} of"
                    + f" {expected.size} expected values"
                    + f" ({mismatched_ppm:.1f} ppm)."
                    + f"\n Max relative difference: {r_max:.2e}"
                    + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                    + f" at {r_max_idx}"
                    + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
                    + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
                )
                if mismatched_ppm > mismatched_tol:
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )


def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with {} different inputs (B, N): {}",
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_inputs.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_inputs.stat,
                id=test_inputs.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_inputs = get_test_inputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shape in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_output_shape):
                    error = (
                        f"Expected {len(expected_output_shape)} outputs,"
                        + f" but got {len(result.members)}"
                    )

                else:
                    for m, exp in expected_output_shape.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            if isinstance((e_aid := exp[AxisId(a)]), int):
                                if s != e_aid:
                                    diff[AxisId(a)] = s
                            elif (
                                s < e_aid.min or e_aid.max is not None and s > e_aid.max
                            ):
                                diff[AxisId(a)] = s
                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )


def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )


# TODO: Implement `debug_model()`
# def debug_model(
#     model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
#     *,
#     weight_format: Optional[WeightsFormat] = None,
#     devices: Optional[List[str]] = None,
# ):
#     """Run the model test and return dict with inputs, results, expected results and intermediates.

#     Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
#     """
#     inputs_raw: Optional = None
#     inputs_processed: Optional = None
#     outputs_raw: Optional = None
#     outputs: Optional = None
#     expected: Optional = None
#     diff: Optional = None

#     model = load_description(
#         model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
#     )
#     if not isinstance(model, Model):
#         raise ValueError(f"Not a bioimageio.model: {model_rdf}")

#     prediction_pipeline = create_prediction_pipeline(
#         bioimageio_model=model, devices=devices, weight_format=weight_format
#     )
#     inputs = [
#         xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
#         for in_path, input_spec in zip(model.test_inputs, model.inputs)
#     ]
#     input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}

#     # keep track of the non-processed inputs
#     inputs_raw = [deepcopy(input) for input in inputs]

#     computed_measures = {}

#     prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
#     inputs_processed = list(input_dict.values())
#     outputs_raw = prediction_pipeline.predict(*inputs_processed)
#     output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
#     prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
#     outputs = list(output_dict.values())

#     if isinstance(outputs, (np.ndarray, xr.DataArray)):
#         outputs = [outputs]

#     expected = [
#         xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
#         for out_path, output_spec in zip(model.test_outputs, model.outputs)
#     ]
#     if len(outputs) != len(expected):
#         error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
#         print(error)
#     else:
#         diff = []
#         for res, exp in zip(outputs, expected):
#             diff.append(res - exp)

#     return {
#         "inputs": inputs_raw,
#         "inputs_processed": inputs_processed,
#         "outputs_raw": outputs_raw,
#         "outputs": outputs,
#         "expected": expected,
#         "diff": diff,
#     }