Coverage for src / bioimageio / core / _resource_tests.py: 73%
390 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from contextlib import nullcontext
8from io import StringIO
9from itertools import product
10from pathlib import Path
11from tempfile import TemporaryDirectory
12from typing import (
13 Any,
14 Callable,
15 Dict,
16 Hashable,
17 List,
18 Literal,
19 Optional,
20 Sequence,
21 Set,
22 Tuple,
23 Union,
24 overload,
25)
27import numpy as np
28from loguru import logger
29from numpy.typing import NDArray
30from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
32from bioimageio.spec import (
33 AnyDatasetDescr,
34 AnyModelDescr,
35 BioimageioCondaEnv,
36 DatasetDescr,
37 InvalidDescr,
38 LatestResourceDescr,
39 ModelDescr,
40 ResourceDescr,
41 ValidationContext,
42 build_description,
43 dump_description,
44 get_conda_env,
45 load_description,
46 save_bioimageio_package,
47)
48from bioimageio.spec._description_impl import DISCOVER
49from bioimageio.spec._internal.common_nodes import ResourceDescrBase
50from bioimageio.spec._internal.io import is_yaml_value
51from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
52from bioimageio.spec._internal.types import (
53 AbsoluteTolerance,
54 FormatVersionPlaceholder,
55 MismatchedElementsPerMillion,
56 RelativeTolerance,
57)
58from bioimageio.spec._internal.validation_context import get_validation_context
59from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
60from bioimageio.spec.model import v0_4, v0_5
61from bioimageio.spec.model.v0_5 import WeightsFormat
62from bioimageio.spec.summary import (
63 ErrorEntry,
64 InstalledPackage,
65 ValidationDetail,
66 ValidationSummary,
67 WarningEntry,
68)
70from . import __version__
71from ._prediction_pipeline import create_prediction_pipeline
72from ._settings import settings
73from .axis import AxisId, BatchSize
74from .common import MemberId, SupportedWeightsFormat
75from .digest_spec import get_test_input_sample, get_test_output_sample
76from .io import save_tensor
77from .sample import Sample
# Name of the conda executable passed to subprocess calls:
# "conda.bat" on Windows, "conda" on every other platform.
if platform.system() == "Windows":
    CONDA_CMD = "conda.bat"
else:
    CONDA_CMD = "conda"
class DeprecatedKwargs(TypedDict):
    """Optional legacy keyword arguments accepted by the test functions.

    All keys may be omitted. They are only consulted when tolerances are
    determined for model descriptions that are not `v0_5.ModelDescr`
    (see `_get_tolerance`).
    """

    # absolute tolerance used when comparing actual and expected outputs
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance used when comparing actual and expected outputs
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy precision argument; if not None it takes precedence over both
    # tolerances and is translated to an absolute tolerance of 1.5 * 10**(-decimal)
    decimal: NotRequired[Optional[int]]
def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed random generators and configure ML frameworks for reproducibility.

    May degrade performance. Only recommended for testing reproducibility!

    Every supported random generator is seeded with 0. If **mode** is "full",
    the ML frameworks are additionally requested to use deterministic
    algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- also enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported
            (and configured) to those matching the given weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            With `None` (default) all frameworks are considered.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Frameworks that cannot be imported are skipped silently; any other
          exception raised while seeding/configuring is only logged (debug level).
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    use_deterministic_ops = mode == "full"

    def _is_requested(*formats: str) -> bool:
        """True if no filter was given or any of **formats** is in the filter."""
        return weight_formats is None or any(
            fmt in weight_formats for fmt in formats
        )

    # NumPy is seeded independently of the requested weight formats.
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as err:
        logger.debug(str(err))

    if _is_requested("pytorch_state_dict", "torchscript"):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                # explicitly switched off again when mode is 'seed_only'
                torch.use_deterministic_algorithms(use_deterministic_ops)
        except Exception as err:
            logger.debug(str(err))

    if _is_requested("tensorflow_saved_model_bundle", "keras_hdf5"):
        try:
            # set before tensorflow gets imported
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if use_deterministic_ops:
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as err:
            logger.debug(str(err))

    if _is_requested("keras_hdf5"):
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as err:
            logger.debug(str(err))
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Thin wrapper around `test_description` that fixes **expected_type** to
    "model" and forwards all other arguments unchanged.

    Returns:
        The validation summary produced by `test_description`.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
    return summary
def default_run_command(args: Sequence[str]):
    """Log and run **args** as a subprocess.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            return code (behaviour of `subprocess.check_call`).
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    subprocess.check_call(args)
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate and test with.
            (Only forwarded if **runtime_env** is `"currently-active"`.)
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.

    Returns:
        The validation summary of the tested resource description.

    Raises:
        RuntimeError: if a custom **run_command** does not raise for a command
            that is expected to fail.
    """
    # Fast path: test within the running interpreter.
    if runtime_env == "currently-active":
        return load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        ).validation_summary

    # Resolve the conda environment description (None: derive it later from
    # the weights entry).
    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    else:
        assert_never(runtime_env)

    # Sanity check for custom command runners: a bogus command has to raise.
    if run_command is not default_run_command:
        bogus_command_raised = False
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            bogus_command_raised = True

        if not bogus_command_raised:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    # Use the given working directory or a temporary one.
    if working_dir is not None:
        working_dir_ctxt = nullcontext(working_dir)
    else:
        # `ignore_cleanup_errors` is only passed for Python >= 3.10
        tmp_dir_options: Dict[str, Any] = {}
        if sys.version_info >= (3, 10):
            tmp_dir_options["ignore_cleanup_errors"] = True

        working_dir_ctxt = TemporaryDirectory(**tmp_dir_options)

    with working_dir_ctxt as entered_dir:
        working_dir = Path(entered_dir)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            # make sure we perform io checks though
            io_checking_context = get_validation_context().replace(
                perform_io_checks=True
            )
            descr = build_description(source, context=io_checking_context)
        else:
            descr = load_description(source, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary

        if isinstance(source, (dict, ResourceDescrBase)):
            # in-memory sources are packaged so the subprocess can load them
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )

    return descr.validation_summary
def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    For model descriptions without a given **weight_format** this function
    calls itself once per present weight format (each with its own
    sub-directory of **working_dir**). Weight formats that cannot be tested
    ('tensorflow_js') are skipped in that case.

    Otherwise the conda environment (given or derived from the weights entry)
    is looked up by a name derived from the SHA256 of its YAML dump, created
    if missing, and `bioimageio test` is run inside it. Relevant details of
    the summary written by that command are added to
    `descr.validation_summary`.

    Args:
        source: file source handed to the `bioimageio test` command.
        descr: loaded description whose validation summary is extended.
        working_dir: directory for the environment file and the summary file.
        weight_format: weight format to test; `None` tests all present formats.
        conda_env: environment to test in; `None` derives it from the weights
            entry (models only). Note: its `name` is reset to `None` in place.
        devices: only passed on to recursive calls (not part of the command).
        determinism: passed to the command as `--determinism`.
        run_command: function executing a command; has to raise on failure.
        stop_early: adds `--stop-early` to the command if true.
        expected_type: adds `--expected-type` to the command if given.
        sha256: only passed on to recursive calls (not part of the command).
        **deprecated: only passed on to recursive calls.

    Raises:
        RuntimeError: if testing 'tensorflow_js' is requested explicitly or
            if conda is not available.
        ValueError: if the conda environment cannot be dumped to valid YAML.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # unsupported format: testing it explicitly would raise
                    # and abort the tests of all remaining formats
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe for an existing environment of that name
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        # environment not usable (yet): create it from the dumped description
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd = []
    cmd_error = None
    # try both spellings of the summary path argument of `bioimageio test`
    for summary_path_arg_name in ("summary", "summary-path"):
        cmd = (
            [
                CONDA_CMD,
                "run",
                "-n",
                env_name,
                "bioimageio",
                "test",
                str(source),
                f"--{summary_path_arg_name}={summary_path.as_posix()}",
                f"--determinism={determinism}",
            ]
            + ([f"--weight-format={weight_format}"] if weight_format else [])
            + ([f"--expected-type={expected_type}"] if expected_type else [])
            + (["--stop-early"] if stop_early else [])
        )
        try:
            run_command(cmd)
        except Exception as e:
            # a failing command may still have written a summary file
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        # no attempt produced a summary file
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)
# Typing overloads of `load_description_and_test`:
# the declared return type depends on `format_version` and `expected_type`.


# format_version="latest", expected_type="model"
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest", expected_type="dataset"
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest", any (or no) expected type
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# any format version, expected_type="model"
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# any format version, expected_type="dataset"
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# any format version, any (or no) expected type
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    An already loaded description is re-validated from its dumped content if
    it does not match the requested **format_version** or was not validated
    with io checks.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if not isinstance(source, ResourceDescrBase):
        root = Path()
        file_name = None
    else:
        root = source.root
        file_name = source.file_name

        # full version or its "major.minor" prefix are both accepted
        major_minor = ".".join(source.format_version.split(".")[:2])
        format_matches = format_version in (
            DISCOVER,
            source.format_version,
            major_minor,
        )
        if format_matches:
            # context of the first validation detail tells if io checks ran
            first_context = source.validation_summary.details[0].context
            io_checks_done = first_context is not None and bool(
                first_context.perform_io_checks
            )
        else:
            io_checks_done = False

        if not io_checks_done:
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        base_context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        )
        # make sure we perform io checks though
        context = base_context.replace(perform_io_checks=True)

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    def _abort_requested() -> bool:
        """True if **stop_early** is set and the summary is not 'passed'."""
        return stop_early and rd.validation_summary.status != "passed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        weight_formats: List[SupportedWeightsFormat]
        if weight_format is not None:
            weight_formats = [weight_format]
        else:
            # all weight formats with an entry
            weight_formats = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if _abort_requested():
                break

            # parametrized inference is not run for v0_4 model descriptions
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if _abort_requested():
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing an output with its expected values.

    For `v0_5.ModelDescr` the tolerances are taken (in increasing priority) from
    the `ReproducibilityTolerance` defaults, legacy `test_kwargs` found in the
    extra fields of `model.config.bioimageio` for weights format **wf**, and the
    first `reproducibility_tolerance` entry applying to **wf** and output **m**.
    The **deprecated** keyword arguments are not consulted in that case.

    For other models the **deprecated** keyword arguments are used:
    a `decimal` that is not None yields `atol = 1.5 * 10 ** (-decimal)` and
    `rtol = 0` (a warning is issued); otherwise `absolute_tolerance` and
    `relative_tolerance` are used (default 1e-3 each). No mismatched elements
    are tolerated in either case.

    Args:
        model: model description under test.
        wf: weights format under test.
        m: id of the output tensor to compare.
        **deprecated: legacy tolerance arguments.

    Returns:
        Tuple of relative tolerance, absolute tolerance and tolerated
        mismatched elements per million.
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        # (an empty `weights_formats`/`output_ids` applies to all)
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        # NOTE: string pieces are joined without a separator; the URL must
        # therefore not be split with a leading space.
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Run the model on its test inputs and compare with its test outputs.

    Adds one `ValidationDetail` (located at `("weights", weight_format)`) to
    `model.validation_summary`. It is "failed" if any error was recorded,
    "passed" otherwise. Mismatches within the tolerated number of mismatched
    elements are recorded as warnings.

    Args:
        model: model description to test.
        weight_format: weights format to run inference with.
        devices: devices handed to `create_prediction_pipeline`.
        stop_early: stop comparing further outputs after the first error.
        **deprecated: legacy tolerance arguments, see `_get_tolerance`.

    Raises:
        Exception: any exception raised while preparing or running inference
            is re-raised if the active validation context has `raise_errors`
            set; otherwise it is recorded as an error entry.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            # a dedicated loop variable avoids shadowing the `expected` sample
            for m, expected_member in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected_member.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected_member.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected_member.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected_member.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    expected_np = expected_member.data.to_numpy().astype(np.float32)
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()
                    if not mismatched_elements:
                        continue

                    actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
                    try:
                        save_tensor(actual_output_path, actual)
                    except Exception as e:
                        logger.error(
                            "Failed to save actual output tensor to {}: {}",
                            actual_output_path,
                            e,
                        )
                        # do not claim the output was saved if it was not
                        saved_note = (
                            f"\n Failed to save actual output to {actual_output_path}."
                        )
                    else:
                        saved_note = f"\n Saved actual output to {actual_output_path}."

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    # the comparison itself failed; keep the original message
                    # but report the cause instead of discarding it
                    msg = (
                        f"Output '{m}' disagrees with expected values."
                        + f" (Comparison raised {type(e).__name__}: {e})"
                    )
                    add_error_entry(msg, with_traceback=True)
                    if stop_early:
                        break
                else:
                    msg = (
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected_np.size} expected values"
                        + f" ({mismatched_ppm:.1f} ppm)."
                        + f"\n Max relative difference not accounted for by absolute tolerance ({atol:.2e}): {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance ({rtol:.2e}): {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                        + saved_note
                    )
                    # only too many mismatched elements make the test fail
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
937def _test_model_inference_parametrized(
938 model: v0_5.ModelDescr,
939 weight_format: SupportedWeightsFormat,
940 devices: Optional[Sequence[str]],
941 *,
942 stop_early: bool,
943) -> None:
944 if not any(
945 isinstance(a.size, v0_5.ParameterizedSize)
946 for ipt in model.inputs
947 for a in ipt.axes
948 ):
949 # no parameterized sizes => set n=0
950 ns: Set[v0_5.ParameterizedSize_N] = {0}
951 else:
952 ns = {0, 1, 2}
954 given_batch_sizes = {
955 a.size
956 for ipt in model.inputs
957 for a in ipt.axes
958 if isinstance(a, v0_5.BatchAxis)
959 }
960 if given_batch_sizes:
961 batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
962 if not batch_sizes:
963 # only arbitrary batch sizes
964 batch_sizes = {1, 2}
965 else:
966 # no batch axis
967 batch_sizes = {1}
969 test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
970 (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
971 }
972 logger.info(
973 "Testing inference with '{}' for {} different inputs (B, N): {}",
974 weight_format,
975 len(test_cases),
976 test_cases,
977 )
979 def generate_test_cases():
980 tested: Set[Hashable] = set()
982 def get_ns(n: int):
983 return {
984 (t.id, a.id): n
985 for t in model.inputs
986 for a in t.axes
987 if isinstance(a.size, v0_5.ParameterizedSize)
988 }
990 for batch_size, n in sorted(test_cases):
991 input_target_sizes, expected_output_sizes = model.get_axis_sizes(
992 get_ns(n), batch_size=batch_size
993 )
994 hashable_target_size = tuple(
995 (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
996 )
997 if hashable_target_size in tested:
998 continue
999 else:
1000 tested.add(hashable_target_size)
1002 resized_test_inputs = Sample(
1003 members={
1004 t.id: (
1005 test_input.members[t.id].resize_to(
1006 {
1007 aid: s
1008 for (tid, aid), s in input_target_sizes.items()
1009 if tid == t.id
1010 },
1011 )
1012 )
1013 for t in model.inputs
1014 },
1015 stat=test_input.stat,
1016 id=test_input.id,
1017 )
1018 expected_output_shapes = {
1019 t.id: {
1020 aid: s
1021 for (tid, aid), s in expected_output_sizes.items()
1022 if tid == t.id
1023 }
1024 for t in model.outputs
1025 }
1026 yield n, batch_size, resized_test_inputs, expected_output_shapes
1028 try:
1029 test_input = get_test_input_sample(model)
1031 with create_prediction_pipeline(
1032 bioimageio_model=model, devices=devices, weight_format=weight_format
1033 ) as prediction_pipeline:
1034 for n, batch_size, inputs, exptected_output_shape in generate_test_cases():
1035 error: Optional[str] = None
1036 try:
1037 result = prediction_pipeline.predict_sample_without_blocking(inputs)
1038 except Exception as e:
1039 error = str(e)
1040 else:
1041 if len(result.members) != len(exptected_output_shape):
1042 error = (
1043 f"Expected {len(exptected_output_shape)} outputs,"
1044 + f" but got {len(result.members)}"
1045 )
1047 else:
1048 for m, exp in exptected_output_shape.items():
1049 res = result.members.get(m)
1050 if res is None:
1051 error = "Output tensors may not be None for test case"
1052 break
1054 diff: Dict[AxisId, int] = {}
1055 for a, s in res.sizes.items():
1056 if isinstance((e_aid := exp[AxisId(a)]), int):
1057 if s != e_aid:
1058 diff[AxisId(a)] = s
1059 elif (
1060 s < e_aid.min
1061 or e_aid.max is not None
1062 and s > e_aid.max
1063 ):
1064 diff[AxisId(a)] = s
1065 if diff:
1066 error = (
1067 f"(n={n}) Expected output shape {exp},"
1068 + f" but got {res.sizes} (diff: {diff})"
1069 )
1070 break
1072 model.validation_summary.add_detail(
1073 ValidationDetail(
1074 name=f"Run {weight_format} inference for inputs with"
1075 + f" batch_size: {batch_size} and size parameter n: {n}",
1076 loc=("weights", weight_format),
1077 status="passed" if error is None else "failed",
1078 errors=(
1079 []
1080 if error is None
1081 else [
1082 ErrorEntry(
1083 loc=("weights", weight_format),
1084 msg=error,
1085 type="bioimageio.core",
1086 )
1087 ]
1088 ),
1089 )
1090 )
1091 if stop_early and error is not None:
1092 break
1093 except Exception as e:
1094 if get_validation_context().raise_errors:
1095 raise e
1097 model.validation_summary.add_detail(
1098 ValidationDetail(
1099 name=f"Run {weight_format} inference for parametrized inputs",
1100 status="failed",
1101 loc=("weights", weight_format),
1102 errors=[
1103 ErrorEntry(
1104 loc=("weights", weight_format),
1105 msg=str(e),
1106 type="bioimageio.core",
1107 with_traceback=True,
1108 )
1109 ],
1110 )
1111 )
def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record in `rd`'s validation summary whether it has the expected type.

    Compares `rd.type` with `expected_type` and appends exactly one
    `ValidationDetail` named "Has expected resource type" to
    `rd.validation_summary.details`.

    Args:
        rd: The (possibly invalid) resource description to check.
        expected_type: The resource type `rd.type` is compared against.

    The appended detail is "passed" with no errors on a match; otherwise it is
    "failed" and carries a single `ErrorEntry` naming both types.
    """
    type_matches = rd.type == expected_type
    # Appended directly to the details list, as the original does.
    summary_details = rd.validation_summary.details

    status: Literal["passed", "failed"]
    errors: List[ErrorEntry]
    if type_matches:
        status = "passed"
        errors = []
    else:
        status = "failed"
        errors = [
            ErrorEntry(
                loc=("type",),
                type="type",
                msg=f"Expected type {expected_type}, found {rd.type}",
            )
        ]

    summary_details.append(
        ValidationDetail(
            name="Has expected resource type",
            status=status,
            loc=("type",),
            errors=errors,
        )
    )
1138# TODO: Implement `debug_model()`
1139# def debug_model(
1140# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1141# *,
1142# weight_format: Optional[WeightsFormat] = None,
1143# devices: Optional[List[str]] = None,
1144# ):
1145# """Run the model test and return dict with inputs, results, expected results and intermediates.
1147# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1148# """
1149# inputs_raw: Optional = None
1150# inputs_processed: Optional = None
1151# outputs_raw: Optional = None
1152# outputs: Optional = None
1153# expected: Optional = None
1154# diff: Optional = None
1156# model = load_description(
1157# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1158# )
1159# if not isinstance(model, Model):
1160# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1162# prediction_pipeline = create_prediction_pipeline(
1163# bioimageio_model=model, devices=devices, weight_format=weight_format
1164# )
1165# inputs = [
1166# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1167# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1168# ]
1169# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1171# # keep track of the non-processed inputs
1172# inputs_raw = [deepcopy(input) for input in inputs]
1174# computed_measures = {}
1176# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1177# inputs_processed = list(input_dict.values())
1178# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1179# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1180# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1181# outputs = list(output_dict.values())
1183# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1184# outputs = [outputs]
1186# expected = [
1187# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1188# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1189# ]
1190# if len(outputs) != len(expected):
1191# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1192# print(error)
1193# else:
1194# diff = []
1195# for res, exp in zip(outputs, expected):
1196# diff.append(res - exp)
1198# return {
1199# "inputs": inputs_raw,
1200# "inputs_processed": inputs_processed,
1201# "outputs_raw": outputs_raw,
1202# "outputs": outputs,
1203# "expected": expected,
1204# "diff": diff,
1205# }