Coverage for bioimageio/core/_resource_tests.py: 58%
330 statements
coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import warnings
6from io import StringIO
7from itertools import product
8from pathlib import Path
9from tempfile import TemporaryDirectory
10from typing import (
11 Callable,
12 Dict,
13 Hashable,
14 List,
15 Literal,
16 Optional,
17 Sequence,
18 Set,
19 Tuple,
20 Union,
21 overload,
22)
24import xarray as xr
25from loguru import logger
26from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
28from bioimageio.spec import (
29 BioimageioCondaEnv,
30 InvalidDescr,
31 LatestResourceDescr,
32 ResourceDescr,
33 ValidationContext,
34 build_description,
35 dump_description,
36 get_conda_env,
37 load_description,
38 save_bioimageio_package,
39)
40from bioimageio.spec._description_impl import DISCOVER
41from bioimageio.spec._internal.common_nodes import ResourceDescrBase
42from bioimageio.spec._internal.io import is_yaml_value
43from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
44from bioimageio.spec._internal.types import (
45 AbsoluteTolerance,
46 FormatVersionPlaceholder,
47 MismatchedElementsPerMillion,
48 RelativeTolerance,
49)
50from bioimageio.spec._internal.validation_context import get_validation_context
51from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
52from bioimageio.spec.model import v0_4, v0_5
53from bioimageio.spec.model.v0_5 import WeightsFormat
54from bioimageio.spec.summary import (
55 ErrorEntry,
56 InstalledPackage,
57 ValidationDetail,
58 ValidationSummary,
59 WarningEntry,
60)
62from ._prediction_pipeline import create_prediction_pipeline
63from .axis import AxisId, BatchSize
64from .common import MemberId, SupportedWeightsFormat
65from .digest_spec import get_test_inputs, get_test_outputs
66from .sample import Sample
67from .utils import VERSION
70class DeprecatedKwargs(TypedDict):
71 absolute_tolerance: NotRequired[AbsoluteTolerance]
72 relative_tolerance: NotRequired[RelativeTolerance]
73 decimal: NotRequired[Optional[int]]
76def enable_determinism(
77 mode: Literal["seed_only", "full"] = "full",
78 weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
79):
80 """Seed and configure ML frameworks for maximum reproducibility.
81 May degrade performance. Only recommended for testing reproducibility!
83 Seed any random generators and (if **mode**=="full") request ML frameworks to use
84 deterministic algorithms.
86 Args:
87 mode: determinism mode
88 - 'seed_only' -- only set seeds, or
89 - 'full' -- set seeds and enable full determinism features (might degrade performance or throw exceptions)
90 weight_formats: Limit which deep learning frameworks are imported
91 based on the given weight_formats.
92 E.g. this allows avoiding a tensorflow import when testing with pytorch only.
94 Notes:
95 - **mode** == "full" might degrade performance or throw exceptions.
96 - Subsequent inference calls might still differ. Call this before each function
97 (or call sequence) that is expected to be reproducible.
98 - Degraded performance: Use for testing reproducibility only!
99 - Recipes:
100 - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
101 - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
102 - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
103 """
104 try:
105 try:
106 import numpy.random
107 except ImportError:
108 pass
109 else:
110 numpy.random.seed(0)
111 except Exception as e:
112 logger.debug(str(e))
114 if (
115 weight_formats is None
116 or "pytorch_state_dict" in weight_formats
117 or "torchscript" in weight_formats
118 ):
119 try:
120 try:
121 import torch
122 except ImportError:
123 pass
124 else:
125 _ = torch.manual_seed(0)
126 torch.use_deterministic_algorithms(mode == "full")
127 except Exception as e:
128 logger.debug(str(e))
130 if (
131 weight_formats is None
132 or "tensorflow_saved_model_bundle" in weight_formats
133 or "keras_hdf5" in weight_formats
134 ):
135 try:
136 os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
137 try:
138 import tensorflow as tf # pyright: ignore[reportMissingTypeStubs]
139 except ImportError:
140 pass
141 else:
142 tf.random.set_seed(0)
143 if mode == "full":
144 tf.config.experimental.enable_op_determinism()
145 # TODO: find a way to switch op determinism off again?
146 except Exception as e:
147 logger.debug(str(e))
149 if weight_formats is None or "keras_hdf5" in weight_formats:
150 try:
151 try:
152 import keras # pyright: ignore[reportMissingTypeStubs]
153 except ImportError:
154 pass
155 else:
156 keras.utils.set_random_seed(0)
157 except Exception as e:
158 logger.debug(str(e))
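# --- Illustrative usage (editor's sketch, not part of the original module). ---
# `enable_determinism` is typically called once right before a reproducibility-
# sensitive test run; limiting `weight_formats` avoids importing frameworks that
# are not needed. The helper below is hypothetical and never called here.
def _example_enable_determinism():
    # seed numpy/torch only; skip the tensorflow/keras branches
    enable_determinism(mode="seed_only", weight_formats=["pytorch_state_dict"])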
161def test_model(
162 source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
163 weight_format: Optional[SupportedWeightsFormat] = None,
164 devices: Optional[List[str]] = None,
165 *,
166 determinism: Literal["seed_only", "full"] = "seed_only",
167 sha256: Optional[Sha256] = None,
168 stop_early: bool = False,
169 **deprecated: Unpack[DeprecatedKwargs],
170) -> ValidationSummary:
171 """Test model inference"""
172 return test_description(
173 source,
174 weight_format=weight_format,
175 devices=devices,
176 determinism=determinism,
177 expected_type="model",
178 sha256=sha256,
179 stop_early=stop_early,
180 **deprecated,
181 )
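# --- Illustrative usage of `test_model` (editor's sketch; the source path is hypothetical). ---
def _example_test_model():
    # run the reproducibility tests for a single weight format and inspect the summary
    summary = test_model(
        "path/to/rdf.yaml",      # any PermissiveFileSource (path, URL, ...) works
        weight_format="onnx",    # default: test all weight formats present in the model
        stop_early=True,         # abort after the first failing subtest
    )
    print(summary.status)
    for detail in summary.details:
        print(detail.name, detail.status)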
184def default_run_command(args: Sequence[str]):
185 logger.info("running '{}'...", " ".join(args))
186 _ = subprocess.run(args, shell=True, text=True, check=True)
189def test_description(
190 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
191 *,
192 format_version: Union[FormatVersionPlaceholder, str] = "discover",
193 weight_format: Optional[SupportedWeightsFormat] = None,
194 devices: Optional[Sequence[str]] = None,
195 determinism: Literal["seed_only", "full"] = "seed_only",
196 expected_type: Optional[str] = None,
197 sha256: Optional[Sha256] = None,
198 stop_early: bool = False,
199 runtime_env: Union[
200 Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
201 ] = ("currently-active"),
202 run_command: Callable[[Sequence[str]], None] = default_run_command,
203 **deprecated: Unpack[DeprecatedKwargs],
204) -> ValidationSummary:
205 """Test a bioimage.io resource dynamically,
206 for example by running prediction on a model's test input tensors.
208 Args:
209 source: model description source.
210 weight_format: Weight format to test.
211 Default: All weight formats present in **source**.
212 devices: Devices to test with, e.g. 'cpu', 'cuda'.
213 Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
214 determinism: Modes to improve reproducibility of test outputs.
215 expected_type: Assert an expected resource description `type`.
216 sha256: Expected SHA256 value of **source**.
217 (Ignored if **source** already is a loaded `ResourceDescr` object.)
218 stop_early: Do not run further subtests after a failed one.
219 runtime_env: (Experimental feature!) The Python environment to run the tests in
220 - `"currently-active"`: Use active Python interpreter.
221 - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
222 environment YAML file based on the model weights description.
223 - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
224 Note: The `bioimageio.core` dependency will be added automatically if not present.
225 run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess
226 (ignored if **runtime_env** is `"currently-active"`).
227 """
228 if runtime_env == "currently-active":
229 rd = load_description_and_test(
230 source,
231 format_version=format_version,
232 weight_format=weight_format,
233 devices=devices,
234 determinism=determinism,
235 expected_type=expected_type,
236 sha256=sha256,
237 stop_early=stop_early,
238 **deprecated,
239 )
240 return rd.validation_summary
242 if runtime_env == "as-described":
243 conda_env = None
244 elif isinstance(runtime_env, (str, Path)):
245 conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
246 elif isinstance(runtime_env, BioimageioCondaEnv):
247 conda_env = runtime_env
248 else:
249 assert_never(runtime_env)
251 with TemporaryDirectory(ignore_cleanup_errors=True) as _d:
252 working_dir = Path(_d)
253 if isinstance(source, (dict, ResourceDescrBase)):
254 file_source = save_bioimageio_package(
255 source, output_path=working_dir / "package.zip"
256 )
257 else:
258 file_source = source
260 return _test_in_env(
261 file_source,
262 working_dir=working_dir,
263 weight_format=weight_format,
264 conda_env=conda_env,
265 devices=devices,
266 determinism=determinism,
267 expected_type=expected_type,
268 sha256=sha256,
269 stop_early=stop_early,
270 run_command=run_command,
271 **deprecated,
272 )
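# --- Illustrative usage of the experimental `runtime_env` option (editor's sketch). ---
# Instead of testing in the currently active interpreter, `test_description` can
# create a fresh conda environment derived from the model's weight entry.
# The source path below is hypothetical.
def _example_test_description_as_described():
    summary = test_description(
        "path/to/rdf.yaml",
        runtime_env="as-described",  # conda env generated via bioimageio.spec.get_conda_env
        determinism="seed_only",
        stop_early=True,
    )
    return summary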
275def _test_in_env(
276 source: PermissiveFileSource,
277 *,
278 working_dir: Path,
279 weight_format: Optional[SupportedWeightsFormat],
280 conda_env: Optional[BioimageioCondaEnv],
281 devices: Optional[Sequence[str]],
282 determinism: Literal["seed_only", "full"],
283 run_command: Callable[[Sequence[str]], None],
284 stop_early: bool,
285 expected_type: Optional[str],
286 sha256: Optional[Sha256],
287 **deprecated: Unpack[DeprecatedKwargs],
288) -> ValidationSummary:
289 descr = load_description(source)
291 if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
292 raise NotImplementedError("Not yet implemented for non-model resources")
294 if weight_format is None:
295 all_present_wfs = [
296 wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
297 ]
298 ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
299 logger.info(
300 "Found weight formats {}. Start testing all{}...",
301 all_present_wfs,
302 f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
303 )
304 summary = _test_in_env(
305 source,
306 working_dir=working_dir / all_present_wfs[0],
307 weight_format=all_present_wfs[0],
308 devices=devices,
309 determinism=determinism,
310 conda_env=conda_env,
311 run_command=run_command,
312 expected_type=expected_type,
313 sha256=sha256,
314 stop_early=stop_early,
315 **deprecated,
316 )
317 for wf in all_present_wfs[1:]:
318 additional_summary = _test_in_env(
319 source,
320 working_dir=working_dir / wf,
321 weight_format=wf,
322 devices=devices,
323 determinism=determinism,
324 conda_env=conda_env,
325 run_command=run_command,
326 expected_type=expected_type,
327 sha256=sha256,
328 stop_early=stop_early,
329 **deprecated,
330 )
331 for d in additional_summary.details:
332 # TODO: filter redundant details; group details
333 summary.add_detail(d)
334 return summary
336 if weight_format == "pytorch_state_dict":
337 wf = descr.weights.pytorch_state_dict
338 elif weight_format == "torchscript":
339 wf = descr.weights.torchscript
340 elif weight_format == "keras_hdf5":
341 wf = descr.weights.keras_hdf5
342 elif weight_format == "onnx":
343 wf = descr.weights.onnx
344 elif weight_format == "tensorflow_saved_model_bundle":
345 wf = descr.weights.tensorflow_saved_model_bundle
346 elif weight_format == "tensorflow_js":
347 raise RuntimeError(
348 "testing 'tensorflow_js' is not supported by bioimageio.core"
349 )
350 else:
351 assert_never(weight_format)
353 assert wf is not None
354 if conda_env is None:
355 conda_env = get_conda_env(entry=wf)
357 # remove the name as we create a name based on the hash of the env description
358 conda_env.name = None
360 dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
361 if not is_yaml_value(dumped_env):
362 raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")
364 env_io = StringIO()
365 write_yaml(dumped_env, file=env_io)
366 encoded_env = env_io.getvalue().encode()
367 env_name = hashlib.sha256(encoded_env).hexdigest()
369 try:
370 run_command(["where" if platform.system() == "Windows" else "which", "conda"])
371 except Exception as e:
372 raise RuntimeError("Conda not available") from e
374 working_dir.mkdir(parents=True, exist_ok=True)
375 summary_path = working_dir / "summary.json"
376 try:
377 run_command(["conda", "activate", env_name])
378 except Exception:
379 path = working_dir / "env.yaml"
380 try:
381 _ = path.write_bytes(encoded_env)
382 logger.debug("wrote conda env to {}", path)
383 run_command(
384 ["conda", "env", "create", f"--file={path}", f"--name={env_name}"]
385 )
386 run_command(["conda", "activate", env_name])
387 except Exception as e:
388 summary = descr.validation_summary
389 summary.add_detail(
390 ValidationDetail(
391 name="Conda environment creation",
392 status="failed",
393 loc=("weights", weight_format),
394 recommended_env=conda_env,
395 errors=[
396 ErrorEntry(
397 loc=("weights", weight_format),
398 msg=str(e),
399 type="conda",
400 with_traceback=True,
401 )
402 ],
403 )
404 )
405 return summary
407 cmd = []
408 for summary_path_arg_name in ("summary", "summary-path"):
409 run_command(
410 cmd := (
411 [
412 "conda",
413 "run",
414 "-n",
415 env_name,
416 "bioimageio",
417 "test",
418 str(source),
419 f"--{summary_path_arg_name}={summary_path.as_posix()}",
420 f"--determinism={determinism}",
421 ]
422 + ([f"--expected-type={expected_type}"] if expected_type else [])
423 + (["--stop-early"] if stop_early else [])
424 )
425 )
426 if summary_path.exists():
427 break
428 else:
429 return ValidationSummary(
430 name="calling bioimageio test command",
431 source_name=str(source),
432 status="failed",
433 type="unknown",
434 format_version="unknown",
435 details=[
436 ValidationDetail(
437 name="run 'bioimageio test'",
438 errors=[
439 ErrorEntry(
440 loc=(),
441 type="bioimageio cli",
442 msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
443 )
444 ],
445 status="failed",
446 )
447 ],
448 env=set(),
449 )
451 return ValidationSummary.model_validate_json(summary_path.read_bytes())
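# Editor's note: the subprocess calls assembled in `_test_in_env` above are roughly
# equivalent to the following shell session (exact flags depend on the installed
# bioimageio CLI version; both `--summary` and `--summary-path` are attempted):
#
#   conda env create --file=<working_dir>/env.yaml --name=<sha256 of env.yaml>
#   conda run -n <env name> bioimageio test <source> \
#       --summary-path=<working_dir>/summary.json --determinism=seed_only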
454@overload
455def load_description_and_test(
456 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
457 *,
458 format_version: Literal["latest"],
459 weight_format: Optional[SupportedWeightsFormat] = None,
460 devices: Optional[Sequence[str]] = None,
461 determinism: Literal["seed_only", "full"] = "seed_only",
462 expected_type: Optional[str] = None,
463 sha256: Optional[Sha256] = None,
464 stop_early: bool = False,
465 **deprecated: Unpack[DeprecatedKwargs],
466) -> Union[LatestResourceDescr, InvalidDescr]: ...
469@overload
470def load_description_and_test(
471 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
472 *,
473 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
474 weight_format: Optional[SupportedWeightsFormat] = None,
475 devices: Optional[Sequence[str]] = None,
476 determinism: Literal["seed_only", "full"] = "seed_only",
477 expected_type: Optional[str] = None,
478 sha256: Optional[Sha256] = None,
479 stop_early: bool = False,
480 **deprecated: Unpack[DeprecatedKwargs],
481) -> Union[ResourceDescr, InvalidDescr]: ...
484def load_description_and_test(
485 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
486 *,
487 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
488 weight_format: Optional[SupportedWeightsFormat] = None,
489 devices: Optional[Sequence[str]] = None,
490 determinism: Literal["seed_only", "full"] = "seed_only",
491 expected_type: Optional[str] = None,
492 sha256: Optional[Sha256] = None,
493 stop_early: bool = False,
494 **deprecated: Unpack[DeprecatedKwargs],
495) -> Union[ResourceDescr, InvalidDescr]:
496 """Test a bioimage.io resource dynamically,
497 for example by running prediction on a model's test input tensors.
499 See `test_description` for more details.
501 Returns:
502 A (possibly invalid) resource description object
503 with a populated `.validation_summary` attribute.
504 """
505 if isinstance(source, ResourceDescrBase):
506 root = source.root
507 file_name = source.file_name
508 if (
509 (
510 format_version
511 not in (
512 DISCOVER,
513 source.format_version,
514 ".".join(source.format_version.split(".")[:2]),
515 )
516 )
517 or (c := source.validation_summary.details[0].context) is None
518 or not c.perform_io_checks
519 ):
520 logger.debug(
521 "deserializing source to ensure we validate and test using format {} and perform io checks",
522 format_version,
523 )
524 source = dump_description(source)
525 else:
526 root = Path()
527 file_name = None
529 if isinstance(source, ResourceDescrBase):
530 rd = source
531 elif isinstance(source, dict):
532 # check context for a given root; default to root of source
533 context = get_validation_context(
534 ValidationContext(root=root, file_name=file_name)
535 ).replace(
536 perform_io_checks=True # make sure we perform io checks though
537 )
539 rd = build_description(
540 source,
541 format_version=format_version,
542 context=context,
543 )
544 else:
545 rd = load_description(
546 source, format_version=format_version, sha256=sha256, perform_io_checks=True
547 )
549 rd.validation_summary.env.add(
550 InstalledPackage(name="bioimageio.core", version=VERSION)
551 )
553 if expected_type is not None:
554 _test_expected_resource_type(rd, expected_type)
556 if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
557 if weight_format is None:
558 weight_formats: List[SupportedWeightsFormat] = [
559 w for w, we in rd.weights if we is not None
560 ] # pyright: ignore[reportAssignmentType]
561 else:
562 weight_formats = [weight_format]
564 enable_determinism(determinism, weight_formats=weight_formats)
565 for w in weight_formats:
566 _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
567 if stop_early and rd.validation_summary.status == "failed":
568 break
570 if not isinstance(rd, v0_4.ModelDescr):
571 _test_model_inference_parametrized(
572 rd, w, devices, stop_early=stop_early
573 )
574 if stop_early and rd.validation_summary.status == "failed":
575 break
577 # TODO: add execution of jupyter notebooks
578 # TODO: add more tests
580 if rd.validation_summary.status == "valid-format":
581 rd.validation_summary.status = "passed"
583 return rd
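# --- Illustrative usage of `load_description_and_test` (editor's sketch; the path is hypothetical). ---
def _example_load_description_and_test():
    rd = load_description_and_test("path/to/rdf.yaml", format_version="latest")
    if isinstance(rd, InvalidDescr):
        print("invalid description:", rd.validation_summary.status)
    else:
        print(type(rd).__name__, rd.validation_summary.status)
    return rd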
586def _get_tolerance(
587 model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
588 wf: SupportedWeightsFormat,
589 m: MemberId,
590 **deprecated: Unpack[DeprecatedKwargs],
591) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
592 if isinstance(model, v0_5.ModelDescr):
593 applicable = v0_5.ReproducibilityTolerance()
595 # check legacy test kwargs for weight format specific tolerance
596 if model.config.bioimageio.model_extra is not None:
597 for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
598 "test_kwargs", {}
599 ).items():
600 if wf == weights_format:
601 applicable = v0_5.ReproducibilityTolerance(
602 relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
603 absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
604 )
605 break
607 # check for weights format and output tensor specific tolerance
608 for a in model.config.bioimageio.reproducibility_tolerance:
609 if (not a.weights_formats or wf in a.weights_formats) and (
610 not a.output_ids or m in a.output_ids
611 ):
612 applicable = a
613 break
615 rtol = applicable.relative_tolerance
616 atol = applicable.absolute_tolerance
617 mismatched_tol = applicable.mismatched_elements_per_million
618 elif (decimal := deprecated.get("decimal")) is not None:
619 warnings.warn(
620 "The argument `decimal` has been deprecated in favour of"
621 + " `relative_tolerance` and `absolute_tolerance`, with different"
622 + " validation logic, using `numpy.testing.assert_allclose, see"
623 + " 'https://numpy.org/doc/stable/reference/generated/"
624 + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
625 + " will cause validation to revert to the old behaviour."
626 )
627 atol = 1.5 * 10 ** (-decimal)
628 rtol = 0
629 mismatched_tol = 0
630 else:
631 # use given (deprecated) test kwargs
632 atol = deprecated.get("absolute_tolerance", 1e-5)
633 rtol = deprecated.get("relative_tolerance", 1e-3)
634 mismatched_tol = 0
636 return rtol, atol, mismatched_tol
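# Editor's note (worked example): with the deprecated `decimal=4` argument,
# `_get_tolerance` returns atol = 1.5 * 10**-4 = 1.5e-4, rtol = 0 and
# mismatched_tol = 0, i.e. every element must satisfy
# |actual - expected| <= 1.5e-4, mirroring `numpy.testing.assert_allclose`
# with `atol=1.5e-4, rtol=0`.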
639def _test_model_inference(
640 model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
641 weight_format: SupportedWeightsFormat,
642 devices: Optional[Sequence[str]],
643 stop_early: bool,
644 **deprecated: Unpack[DeprecatedKwargs],
645) -> None:
646 test_name = f"Reproduce test outputs from test inputs ({weight_format})"
647 logger.debug("starting '{}'", test_name)
648 error_entries: List[ErrorEntry] = []
649 warning_entries: List[WarningEntry] = []
651 def add_error_entry(msg: str, with_traceback: bool = False):
652 error_entries.append(
653 ErrorEntry(
654 loc=("weights", weight_format),
655 msg=msg,
656 type="bioimageio.core",
657 with_traceback=with_traceback,
658 )
659 )
661 def add_warning_entry(msg: str):
662 warning_entries.append(
663 WarningEntry(
664 loc=("weights", weight_format),
665 msg=msg,
666 type="bioimageio.core",
667 )
668 )
670 try:
671 inputs = get_test_inputs(model)
672 expected = get_test_outputs(model)
674 with create_prediction_pipeline(
675 bioimageio_model=model, devices=devices, weight_format=weight_format
676 ) as prediction_pipeline:
677 results = prediction_pipeline.predict_sample_without_blocking(inputs)
679 if len(results.members) != len(expected.members):
680 add_error_entry(
681 f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
682 )
684 else:
685 for m, expected in expected.members.items():
686 actual = results.members.get(m)
687 if actual is None:
688 add_error_entry("Output tensors for test case may not be None")
689 if stop_early:
690 break
691 else:
692 continue
694 rtol, atol, mismatched_tol = _get_tolerance(
695 model, wf=weight_format, m=m, **deprecated
696 )
697 rtol_value = rtol * abs(expected)
698 abs_diff = abs(actual - expected)
699 mismatched = abs_diff > atol + rtol_value
700 mismatched_elements = mismatched.sum().item()
701 if not mismatched_elements:
702 continue
704 mismatched_ppm = mismatched_elements / expected.size * 1e6
705 abs_diff[~mismatched] = 0 # ignore non-mismatched elements
707 r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
708 r_max = r_diff[r_max_idx].item()
709 r_actual = actual[r_max_idx].item()
710 r_expected = expected[r_max_idx].item()
712 # Calculate the max absolute difference with the relative tolerance subtracted
713 abs_diff_wo_rtol: xr.DataArray = xr.ufuncs.maximum(
714 (abs_diff - rtol_value).data, 0
715 )
716 a_max_idx = {
717 AxisId(k): int(v) for k, v in abs_diff_wo_rtol.argmax().items()
718 }
720 a_max = abs_diff[a_max_idx].item()
721 a_actual = actual[a_max_idx].item()
722 a_expected = expected[a_max_idx].item()
724 msg = (
725 f"Output '{m}' disagrees with {mismatched_elements} of"
726 + f" {expected.size} expected values"
727 + f" ({mismatched_ppm:.1f} ppm)."
728 + f"\n Max relative difference: {r_max:.2e}"
729 + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
730 + f" at {r_max_idx}"
731 + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
732 + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
733 )
734 if mismatched_ppm > mismatched_tol:
735 add_error_entry(msg)
736 if stop_early:
737 break
738 else:
739 add_warning_entry(msg)
741 except Exception as e:
742 if get_validation_context().raise_errors:
743 raise e
745 add_error_entry(str(e), with_traceback=True)
747 model.validation_summary.add_detail(
748 ValidationDetail(
749 name=test_name,
750 loc=("weights", weight_format),
751 status="failed" if error_entries else "passed",
752 recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
753 errors=error_entries,
754 warnings=warning_entries,
755 )
756 )
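# --- Minimal sketch of the comparison rule used above (editor's example). ---
# An element counts as mismatched when |actual - expected| > atol + rtol * |expected|;
# the test detail only fails once the mismatched fraction exceeds the allowed ppm.
def _example_mismatch_rule():
    import numpy as np

    expected = np.array([1.0, 2.0, 3.0])
    actual = np.array([1.0, 2.002, 3.5])
    atol, rtol = 1e-5, 1e-3
    mismatched = np.abs(actual - expected) > atol + rtol * np.abs(expected)
    # -> only the last element is mismatched: 1/3 of the values, ~333333 ppm
    return mismatched.sum() / expected.size * 1e6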
759def _test_model_inference_parametrized(
760 model: v0_5.ModelDescr,
761 weight_format: SupportedWeightsFormat,
762 devices: Optional[Sequence[str]],
763 *,
764 stop_early: bool,
765) -> None:
766 if not any(
767 isinstance(a.size, v0_5.ParameterizedSize)
768 for ipt in model.inputs
769 for a in ipt.axes
770 ):
771 # no parameterized sizes => set n=0
772 ns: Set[v0_5.ParameterizedSize_N] = {0}
773 else:
774 ns = {0, 1, 2}
776 given_batch_sizes = {
777 a.size
778 for ipt in model.inputs
779 for a in ipt.axes
780 if isinstance(a, v0_5.BatchAxis)
781 }
782 if given_batch_sizes:
783 batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
784 if not batch_sizes:
785 # only arbitrary batch sizes
786 batch_sizes = {1, 2}
787 else:
788 # no batch axis
789 batch_sizes = {1}
791 test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
792 (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
793 }
794 logger.info(
795 "Testing inference with {} different inputs (B, N): {}",
796 len(test_cases),
797 test_cases,
798 )
800 def generate_test_cases():
801 tested: Set[Hashable] = set()
803 def get_ns(n: int):
804 return {
805 (t.id, a.id): n
806 for t in model.inputs
807 for a in t.axes
808 if isinstance(a.size, v0_5.ParameterizedSize)
809 }
811 for batch_size, n in sorted(test_cases):
812 input_target_sizes, expected_output_sizes = model.get_axis_sizes(
813 get_ns(n), batch_size=batch_size
814 )
815 hashable_target_size = tuple(
816 (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
817 )
818 if hashable_target_size in tested:
819 continue
820 else:
821 tested.add(hashable_target_size)
823 resized_test_inputs = Sample(
824 members={
825 t.id: (
826 test_inputs.members[t.id].resize_to(
827 {
828 aid: s
829 for (tid, aid), s in input_target_sizes.items()
830 if tid == t.id
831 },
832 )
833 )
834 for t in model.inputs
835 },
836 stat=test_inputs.stat,
837 id=test_inputs.id,
838 )
839 expected_output_shapes = {
840 t.id: {
841 aid: s
842 for (tid, aid), s in expected_output_sizes.items()
843 if tid == t.id
844 }
845 for t in model.outputs
846 }
847 yield n, batch_size, resized_test_inputs, expected_output_shapes
849 try:
850 test_inputs = get_test_inputs(model)
852 with create_prediction_pipeline(
853 bioimageio_model=model, devices=devices, weight_format=weight_format
854 ) as prediction_pipeline:
855 for n, batch_size, inputs, expected_output_shape in generate_test_cases():
856 error: Optional[str] = None
857 result = prediction_pipeline.predict_sample_without_blocking(inputs)
858 if len(result.members) != len(expected_output_shape):
859 error = (
860 f"Expected {len(expected_output_shape)} outputs,"
861 + f" but got {len(result.members)}"
862 )
864 else:
865 for m, exp in expected_output_shape.items():
866 res = result.members.get(m)
867 if res is None:
868 error = "Output tensors may not be None for test case"
869 break
871 diff: Dict[AxisId, int] = {}
872 for a, s in res.sizes.items():
873 if isinstance((e_aid := exp[AxisId(a)]), int):
874 if s != e_aid:
875 diff[AxisId(a)] = s
876 elif (
877 s < e_aid.min or e_aid.max is not None and s > e_aid.max
878 ):
879 diff[AxisId(a)] = s
880 if diff:
881 error = (
882 f"(n={n}) Expected output shape {exp},"
883 + f" but got {res.sizes} (diff: {diff})"
884 )
885 break
887 model.validation_summary.add_detail(
888 ValidationDetail(
889 name=f"Run {weight_format} inference for inputs with"
890 + f" batch_size: {batch_size} and size parameter n: {n}",
891 loc=("weights", weight_format),
892 status="passed" if error is None else "failed",
893 errors=(
894 []
895 if error is None
896 else [
897 ErrorEntry(
898 loc=("weights", weight_format),
899 msg=error,
900 type="bioimageio.core",
901 )
902 ]
903 ),
904 )
905 )
906 if stop_early and error is not None:
907 break
908 except Exception as e:
909 if get_validation_context().raise_errors:
910 raise e
912 model.validation_summary.add_detail(
913 ValidationDetail(
914 name=f"Run {weight_format} inference for parametrized inputs",
915 status="failed",
916 loc=("weights", weight_format),
917 errors=[
918 ErrorEntry(
919 loc=("weights", weight_format),
920 msg=str(e),
921 type="bioimageio.core",
922 with_traceback=True,
923 )
924 ],
925 )
926 )
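# Editor's note (example): for a model with one ParameterizedSize axis and an
# arbitrary-size batch axis, `_test_model_inference_parametrized` builds the grid
# {1, 2} x {0, 1, 2}, i.e. up to six (batch_size, n) cases, which are then
# deduplicated by the resulting input sizes before running inference.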
929def _test_expected_resource_type(
930 rd: Union[InvalidDescr, ResourceDescr], expected_type: str
931):
932 has_expected_type = rd.type == expected_type
933 rd.validation_summary.details.append(
934 ValidationDetail(
935 name="Has expected resource type",
936 status="passed" if has_expected_type else "failed",
937 loc=("type",),
938 errors=(
939 []
940 if has_expected_type
941 else [
942 ErrorEntry(
943 loc=("type",),
944 type="type",
945 msg=f"Expected type {expected_type}, found {rd.type}",
946 )
947 ]
948 ),
949 )
950 )
953# TODO: Implement `debug_model()`
954# def debug_model(
955# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
956# *,
957# weight_format: Optional[WeightsFormat] = None,
958# devices: Optional[List[str]] = None,
959# ):
960# """Run the model test and return dict with inputs, results, expected results and intermediates.
962# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
963# """
964# inputs_raw: Optional = None
965# inputs_processed: Optional = None
966# outputs_raw: Optional = None
967# outputs: Optional = None
968# expected: Optional = None
969# diff: Optional = None
971# model = load_description(
972# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
973# )
974# if not isinstance(model, Model):
975# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
977# prediction_pipeline = create_prediction_pipeline(
978# bioimageio_model=model, devices=devices, weight_format=weight_format
979# )
980# inputs = [
981# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
982# for in_path, input_spec in zip(model.test_inputs, model.inputs)
983# ]
984# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
986# # keep track of the non-processed inputs
987# inputs_raw = [deepcopy(input) for input in inputs]
989# computed_measures = {}
991# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
992# inputs_processed = list(input_dict.values())
993# outputs_raw = prediction_pipeline.predict(*inputs_processed)
994# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
995# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
996# outputs = list(output_dict.values())
998# if isinstance(outputs, (np.ndarray, xr.DataArray)):
999# outputs = [outputs]
1001# expected = [
1002# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1003# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1004# ]
1005# if len(outputs) != len(expected):
1006# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1007# print(error)
1008# else:
1009# diff = []
1010# for res, exp in zip(outputs, expected):
1011# diff.append(res - exp)
1013# return {
1014# "inputs": inputs_raw,
1015# "inputs_processed": inputs_processed,
1016# "outputs_raw": outputs_raw,
1017# "outputs": outputs,
1018# "expected": expected,
1019# "diff": diff,
1020# }