Coverage for bioimageio/core/_resource_tests.py: 61%

304 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import hashlib 

2import os 

3import platform 

4import subprocess 

5import warnings 

6from io import StringIO 

7from itertools import product 

8from pathlib import Path 

9from tempfile import TemporaryDirectory 

10from typing import ( 

11 Callable, 

12 Dict, 

13 Hashable, 

14 List, 

15 Literal, 

16 Optional, 

17 Sequence, 

18 Set, 

19 Tuple, 

20 Union, 

21 overload, 

22) 

23 

24from loguru import logger 

25from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args 

26 

27from bioimageio.spec import ( 

28 BioimageioCondaEnv, 

29 InvalidDescr, 

30 LatestResourceDescr, 

31 ResourceDescr, 

32 ValidationContext, 

33 build_description, 

34 dump_description, 

35 get_conda_env, 

36 load_description, 

37 save_bioimageio_package, 

38) 

39from bioimageio.spec._description_impl import DISCOVER 

40from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

41from bioimageio.spec._internal.io import is_yaml_value 

42from bioimageio.spec._internal.io_utils import read_yaml, write_yaml 

43from bioimageio.spec._internal.types import ( 

44 AbsoluteTolerance, 

45 FormatVersionPlaceholder, 

46 MismatchedElementsPerMillion, 

47 RelativeTolerance, 

48) 

49from bioimageio.spec._internal.validation_context import get_validation_context 

50from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 

51from bioimageio.spec.model import v0_4, v0_5 

52from bioimageio.spec.model.v0_5 import WeightsFormat 

53from bioimageio.spec.summary import ( 

54 ErrorEntry, 

55 InstalledPackage, 

56 ValidationDetail, 

57 ValidationSummary, 

58) 

59 

60from ._prediction_pipeline import create_prediction_pipeline 

61from .axis import AxisId, BatchSize 

62from .common import MemberId, SupportedWeightsFormat 

63from .digest_spec import get_test_inputs, get_test_outputs 

64from .sample import Sample 

65from .utils import VERSION 

66 

67 

class DeprecatedKwargs(TypedDict):
    """Legacy keyword arguments that the test functions still accept.

    Every key is optional. Within this module they are only read by
    `_get_tolerance`, and only for models that are not `v0_5.ModelDescr`.
    """

    # absolute tolerance for comparing actual with expected outputs
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for comparing actual with expected outputs
    relative_tolerance: NotRequired[RelativeTolerance]
    # number of decimals to agree on; a non-None value selects the old
    # comparison behaviour and triggers a deprecation warning
    decimal: NotRequired[Optional[int]]

72 

73 

def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seeds the random generators of the available frameworks and
    (if **mode**=="full") requests them to use deterministic algorithms.
    Frameworks that are not installed are skipped silently; any other failure
    is logged at debug level and otherwise ignored.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks get imported
            based on the given weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            `None` considers all frameworks.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """

    def is_requested(*candidates: str) -> bool:
        # without an explicit selection every framework is considered
        return weight_formats is None or any(c in weight_formats for c in candidates)

    def attempt(configure: Callable[[], None]) -> None:
        # best effort: a framework that fails to configure must not abort the caller
        try:
            configure()
        except Exception as e:
            logger.debug(str(e))

    def seed_numpy() -> None:
        try:
            import numpy.random
        except ImportError:
            return

        numpy.random.seed(0)

    def seed_torch() -> None:
        try:
            import torch
        except ImportError:
            return

        _ = torch.manual_seed(0)
        torch.use_deterministic_algorithms(mode == "full")

    def seed_tensorflow() -> None:
        # set before tensorflow gets imported
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
        try:
            import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            return

        tf.random.set_seed(0)
        if mode == "full":
            tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??

    def seed_keras() -> None:
        try:
            import keras  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            return

        keras.utils.set_random_seed(0)

    # numpy is seeded regardless of the requested weight formats
    attempt(seed_numpy)

    if is_requested("pytorch_state_dict", "torchscript"):
        attempt(seed_torch)

    if is_requested("tensorflow_saved_model_bundle", "keras_hdf5"):
        attempt(seed_tensorflow)

    if is_requested("keras_hdf5"):
        attempt(seed_keras)

157 

158 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` that always expects the
    resource type "model"; all arguments are forwarded unchanged.

    Returns:
        The validation summary produced by `test_description`.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
    return summary

180 

181 

def default_run_command(args: Sequence[str]):
    """Log and execute a terminal command in a subprocess.

    Args:
        args: The command and its arguments, one item per argument.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero code.
    """
    logger.info("running '{}'...", " ".join(args))
    # On POSIX `shell=True` combined with a sequence of arguments executes only
    # the first item; the remaining items become arguments of the shell itself.
    # The shell is therefore only used on Windows, where the sequence is joined
    # into a single command line.
    use_shell = platform.system() == "Windows"
    _ = subprocess.run(list(args), shell=use_shell, text=True, check=True)

185 

186 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate and test with
            (only forwarded if **runtime_env** is `"currently-active"`).
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess
            (ignored if **runtime_env** is `"currently-active"`).

    Returns:
        The validation summary of the tested resource.
    """
    if runtime_env == "currently-active":
        # test right here, in the running interpreter
        return load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        ).validation_summary

    # resolve the conda environment to test in;
    # `None` lets `_test_in_env` derive it from the weights description
    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        env_file = Path(runtime_env)
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(env_file))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmp_dir:
        working_dir = Path(tmp_dir)
        # the tests run in a subprocess, which needs a file based source
        file_source = (
            save_bioimageio_package(source, output_path=working_dir / "package.zip")
            if isinstance(source, (dict, ResourceDescrBase))
            else source
        )

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            conda_env=conda_env,
            run_command=run_command,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )

271 

272 

def _test_in_env(
    source: PermissiveFileSource,
    *,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a model description in a dedicated conda environment.

    If **weight_format** is `None` this function calls itself once per testable
    weight format found in the description (each in its own sub folder of
    **working_dir**) and merges the details of all resulting summaries into the
    first one. 'tensorflow_js' weights are skipped in that case.

    For a single weight format, a conda environment named after the SHA256 of
    its YAML description is created if it cannot be activated, and
    `bioimageio test` is run in it via `conda run`.

    Args:
        source: File source of the resource description to test.
        working_dir: Folder for the environment file and the test summary.
        weight_format: Weight format to test, or `None` for all testable ones.
        conda_env: Environment to test in. If `None`, it is derived from the
            weights entry with `get_conda_env`.
            Note: the `name` of a given environment object is reset to `None`.
        devices: Only forwarded to the per weight format calls;
            not part of the executed test command.
        determinism: Passed on as `--determinism`.
        run_command: Function executing a terminal command.
        stop_early: Passed on as `--stop-early`.
        expected_type: Passed on as `--expected-type` if given.
        sha256: Only forwarded to the per weight format calls;
            not part of the executed test command.
        **deprecated: Only forwarded to the per weight format calls.

    Returns:
        The validation summary read from the summary file written by the test command.

    Raises:
        NotImplementedError: If **source** does not describe a model.
        RuntimeError: If conda is not available, if 'tensorflow_js' is requested
            explicitly, or if no testable weight format is present.
        ValueError: If the conda environment cannot be dumped to valid YAML.
    """
    descr = load_description(source)

    if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise NotImplementedError("Not yet implemented for non-model resources")

    if weight_format is None:
        all_present_wfs = [
            wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
        ]
        ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
        # actually skip the ignored formats as announced in the log message below
        testable_wfs = [wf for wf in all_present_wfs if wf not in ignore_wfs]
        logger.info(
            "Found weight formats {}. Start testing all{}...",
            all_present_wfs,
            f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
        )
        if not testable_wfs:
            raise RuntimeError(
                f"No testable weight formats found (present: {all_present_wfs})"
            )

        def test_single_format(wf: SupportedWeightsFormat) -> ValidationSummary:
            # each weight format gets its own working directory
            return _test_in_env(
                source,
                working_dir=working_dir / wf,
                weight_format=wf,
                devices=devices,
                determinism=determinism,
                conda_env=conda_env,
                run_command=run_command,
                expected_type=expected_type,
                sha256=sha256,
                stop_early=stop_early,
                **deprecated,
            )

        summary = test_single_format(testable_wfs[0])
        for wf in testable_wfs[1:]:
            for d in test_single_format(wf).details:
                # TODO: filter redundant details; group details
                summary.add_detail(d)

        return summary

    if weight_format == "pytorch_state_dict":
        wf = descr.weights.pytorch_state_dict
    elif weight_format == "torchscript":
        wf = descr.weights.torchscript
    elif weight_format == "keras_hdf5":
        wf = descr.weights.keras_hdf5
    elif weight_format == "onnx":
        wf = descr.weights.onnx
    elif weight_format == "tensorflow_saved_model_bundle":
        wf = descr.weights.tensorflow_saved_model_bundle
    elif weight_format == "tensorflow_js":
        raise RuntimeError(
            "testing 'tensorflow_js' is not supported by bioimageio.core"
        )
    else:
        assert_never(weight_format)

    assert wf is not None, f"weights entry '{weight_format}' is missing"
    if conda_env is None:
        conda_env = get_conda_env(entry=wf)

    # remove name as we create a name based on the env description hash value
    # NOTE: this mutates a `conda_env` object passed in by the caller
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    # identical environment descriptions map to the same environment name
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", "conda"])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    working_dir.mkdir(parents=True, exist_ok=True)
    try:
        # a failing activation is taken as "environment does not exist yet"
        run_command(["conda", "activate", env_name])
    except Exception:
        path = working_dir / "env.yaml"
        _ = path.write_bytes(encoded_env)
        logger.debug("written conda env to {}", path)
        run_command(["conda", "env", "create", f"--file={path}", f"--name={env_name}"])
        run_command(["conda", "activate", env_name])

    summary_path = working_dir / "summary.json"
    run_command(
        [
            "conda",
            "run",
            "-n",
            env_name,
            "bioimageio",
            "test",
            str(source),
            f"--summary-path={summary_path}",
            f"--determinism={determinism}",
        ]
        + ([f"--expected-type={expected_type}"] if expected_type else [])
        + (["--stop-early"] if stop_early else [])
    )
    return ValidationSummary.model_validate_json(summary_path.read_bytes())

399 

400 

@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Load (if necessary) and dynamically test a bioimage.io resource,
    for example run prediction of test tensors for models.

    See `test_description` for more details on the arguments.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if not isinstance(source, ResourceDescrBase):
        root = Path()
        file_name = None
    else:
        root = source.root
        file_name = source.file_name

        accepted_versions = (
            DISCOVER,
            source.format_version,
            ".".join(source.format_version.split(".")[:2]),  # major.minor only
        )
        if format_version not in accepted_versions:
            needs_revalidation = True
        else:
            # an already loaded description is only reused
            # if it was validated with io checks
            ctx = source.validation_summary.details[0].context
            needs_revalidation = ctx is None or not ctx.perform_io_checks

        if needs_revalidation:
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        base_context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        )
        # make sure we perform io checks though
        context = base_context.replace(perform_io_checks=True)
        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    summary = rd.validation_summary
    summary.env.add(InstalledPackage(name="bioimageio.core", version=VERSION))

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    def failed_and_stop_early() -> bool:
        return stop_early and rd.validation_summary.status == "failed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is not None:
            weight_formats: List[SupportedWeightsFormat] = [weight_format]
        else:
            # all weight formats that have an entry
            weight_formats = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, **deprecated)
            if failed_and_stop_early():
                break

            if isinstance(rd, v0_4.ModelDescr):
                # parametrized inference is not tested for v0_4 models
                continue

            _test_model_inference_parametrized(rd, w, devices, stop_early=stop_early)
            if failed_and_stop_early():
                break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    return rd

531 

532 

def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output **m** of weight format **wf**.

    For `v0_5` models the tolerance is taken (with increasing priority) from
    the `ReproducibilityTolerance` defaults, legacy `test_kwargs` found in the
    extra fields of `model.config.bioimageio` and the first matching entry of
    `model.config.bioimageio.reproducibility_tolerance`. The **deprecated**
    keyword arguments are not considered for `v0_5` models.

    For other models the **deprecated** keyword arguments are used; a given
    `decimal` takes precedence and emits a warning.

    Returns:
        relative tolerance, absolute tolerance, mismatched elements per million
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            legacy_test_kwargs = model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            )
            if wf in legacy_test_kwargs:
                test_kwargs = legacy_test_kwargs[wf]
                applicable = v0_5.ReproducibilityTolerance(
                    relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                    absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                )

        # check for weights format and output tensor specific tolerance
        # (an entry without weights formats/output ids applies to all of them)
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol

584 

585 

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Check that the model reproduces its test outputs from its test inputs.

    The outcome is added as one `ValidationDetail` to `model.validation_summary`.
    Only the first disagreeing output is reported. Exceptions are recorded as
    errors unless the validation context requests raising them.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    errors: List[ErrorEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        entry = ErrorEntry(
            loc=("weights", weight_format),
            msg=msg,
            type="bioimageio.core",
            with_traceback=with_traceback,
        )
        errors.append(entry)

    try:
        inputs = get_test_inputs(model)
        expected = get_test_outputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(inputs)

        n_expected = len(expected.members)
        n_results = len(results.members)
        if n_results != n_expected:
            add_error_entry(f"Expected {n_expected} outputs, but got {n_results}")
        else:
            for m, expected_tensor in expected.members.items():
                actual_tensor = results.members.get(m)
                if actual_tensor is None:
                    add_error_entry("Output tensors for test case may not be None")
                    break

                rtol, atol, mismatched_tol = _get_tolerance(
                    model, wf=weight_format, m=m, **deprecated
                )
                abs_diff = abs(actual_tensor - expected_tensor)
                mismatched = abs_diff > atol + rtol * abs(expected_tensor)
                mismatched_elements = mismatched.sum().item()
                if not (
                    mismatched_elements / expected_tensor.size > mismatched_tol / 1e6
                ):
                    # within tolerance
                    continue

                # report the locations of the largest relative and absolute deviation
                r_diff = abs_diff / (abs(expected_tensor) + 1e-6)
                r_max_idx = r_diff.argmax()
                r_max = r_diff[r_max_idx].item()
                r_actual = actual_tensor[r_max_idx].item()
                r_expected = expected_tensor[r_max_idx].item()
                a_max_idx = abs_diff.argmax()
                a_max = abs_diff[a_max_idx].item()
                a_actual = actual_tensor[a_max_idx].item()
                a_expected = expected_tensor[a_max_idx].item()
                add_error_entry(
                    f"Output '{m}' disagrees with {mismatched_elements} of"
                    + f" {expected_tensor.size} expected values."
                    + f"\n Max relative difference: {r_max:.2e}"
                    + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                    + f" at {r_max_idx}"
                    + f"\n Max absolute difference: {a_max:.2e}"
                    + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
                )
                break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    detail = ValidationDetail(
        name=test_name,
        loc=("weights", weight_format),
        status="failed" if errors else "passed",
        recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
        errors=errors,
    )
    model.validation_summary.add_detail(detail)

668 

669 

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference for several batch sizes and size parameters `n`.

    The test inputs are resized to the input sizes obtained from
    `model.get_axis_sizes` and the resulting output shapes are compared with
    the expected ones. One `ValidationDetail` per tested input is added to
    `model.validation_summary`; an exception results in a single failed detail
    unless the validation context requests raising it.
    """
    has_parameterized_size = any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    )
    # no parameterized sizes => set n=0
    ns: Set[v0_5.ParameterizedSize_N] = {0, 1, 2} if has_parameterized_size else {0}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if not given_batch_sizes:
        # no batch axis
        batch_sizes = {1}
    else:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = set(
        product(sorted(batch_sizes), sorted(ns))
    )
    logger.info(
        "Testing inference with {} different inputs (B, N): {}",
        len(test_cases),
        test_cases,
    )

    def generate_test_cases(test_sample: Sample):
        seen_input_sizes: Set[Hashable] = set()

        for batch_size, n in sorted(test_cases):
            # the same n is used for every parameterized input axis
            size_params = {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                size_params, batch_size=batch_size
            )
            size_key = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if size_key in seen_input_sizes:
                # different (batch_size, n) may lead to the same input sizes
                continue

            seen_input_sizes.add(size_key)

            resized_members = {}
            for t in model.inputs:
                target_size = {
                    aid: s
                    for (tid, aid), s in input_target_sizes.items()
                    if tid == t.id
                }
                resized_members[t.id] = test_sample.members[t.id].resize_to(
                    target_size
                )

            resized_test_inputs = Sample(
                members=resized_members,
                stat=test_sample.stat,
                id=test_sample.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_inputs = get_test_inputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shapes in generate_test_cases(
                test_inputs
            ):
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_output_shapes):
                    error = (
                        f"Expected {len(expected_output_shapes)} outputs,"
                        + f" but got {len(result.members)}"
                    )
                else:
                    for m, exp in expected_output_shapes.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        # collect the axes whose actual size is unexpected
                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            axis_id = AxisId(a)
                            expected_size = exp[axis_id]
                            if isinstance(expected_size, int):
                                unexpected = s != expected_size
                            else:
                                # a size range with an optional upper bound
                                unexpected = s < expected_size.min or (
                                    expected_size.max is not None
                                    and s > expected_size.max
                                )

                            if unexpected:
                                diff[axis_id] = s

                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                if error is None:
                    status = "passed"
                    error_entries = []
                else:
                    status = "failed"
                    error_entries = [
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=error,
                            type="bioimageio.core",
                        )
                    ]

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status=status,
                        errors=error_entries,
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )

838 

839 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record whether **rd** has the resource type **expected_type**.

    A "Has expected resource type" detail (passed or failed) is added to
    `rd.validation_summary`.

    Args:
        rd: The (possibly invalid) resource description to check.
        expected_type: The resource type `rd.type` is compared with.
    """
    has_expected_type = rd.type == expected_type
    if has_expected_type:
        errors = []
    else:
        errors = [
            ErrorEntry(
                loc=("type",),
                type="type",
                msg=f"Expected type {expected_type}, found {rd.type}",
            )
        ]

    # use `add_detail` like all other tests in this module instead of appending
    # to `details` directly, so any bookkeeping done by `add_detail` is applied
    rd.validation_summary.add_detail(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=errors,
        )
    )

862 

863 

864# TODO: Implement `debug_model()` 

865# def debug_model( 

866# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

867# *, 

868# weight_format: Optional[WeightsFormat] = None, 

869# devices: Optional[List[str]] = None, 

870# ): 

871# """Run the model test and return dict with inputs, results, expected results and intermediates. 

872 

873# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

874# """ 

875# inputs_raw: Optional = None 

876# inputs_processed: Optional = None 

877# outputs_raw: Optional = None 

878# outputs: Optional = None 

879# expected: Optional = None 

880# diff: Optional = None 

881 

882# model = load_description( 

883# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

884# ) 

885# if not isinstance(model, Model): 

886# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

887 

888# prediction_pipeline = create_prediction_pipeline( 

889# bioimageio_model=model, devices=devices, weight_format=weight_format 

890# ) 

891# inputs = [ 

892# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

893# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

894# ] 

895# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

896 

897# # keep track of the non-processed inputs 

898# inputs_raw = [deepcopy(input) for input in inputs] 

899 

900# computed_measures = {} 

901 

902# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

903# inputs_processed = list(input_dict.values()) 

904# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

905# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

906# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

907# outputs = list(output_dict.values()) 

908 

909# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

910# outputs = [outputs] 

911 

912# expected = [ 

913# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

914# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

915# ] 

916# if len(outputs) != len(expected): 

917# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

918# print(error) 

919# else: 

920# diff = [] 

921# for res, exp in zip(outputs, expected): 

922# diff.append(res - exp) 

923 

924# return { 

925# "inputs": inputs_raw, 

926# "inputs_processed": inputs_processed, 

927# "outputs_raw": outputs_raw, 

928# "outputs": outputs, 

929# "expected": expected, 

930# "diff": diff, 

931# }