Coverage for src / bioimageio / core / _resource_tests.py: 74%
426 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from contextlib import nullcontext
8from copy import deepcopy
9from io import StringIO
10from itertools import product
11from pathlib import Path
12from tempfile import TemporaryDirectory
13from typing import (
14 Any,
15 Callable,
16 Dict,
17 Hashable,
18 List,
19 Literal,
20 Optional,
21 Sequence,
22 Set,
23 Tuple,
24 Union,
25 overload,
26)
28import numpy as np
29from loguru import logger
30from numpy.typing import NDArray
31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
33from bioimageio.spec import (
34 AnyDatasetDescr,
35 AnyModelDescr,
36 BioimageioCondaEnv,
37 DatasetDescr,
38 InvalidDescr,
39 LatestResourceDescr,
40 ModelDescr,
41 ResourceDescr,
42 ValidationContext,
43 build_description,
44 dump_description,
45 get_conda_env,
46 load_description,
47 save_bioimageio_package,
48)
49from bioimageio.spec._description_impl import DISCOVER
50from bioimageio.spec._internal.common_nodes import ResourceDescrBase
51from bioimageio.spec._internal.io import is_yaml_value
52from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
53from bioimageio.spec._internal.types import (
54 AbsoluteTolerance,
55 FormatVersionPlaceholder,
56 MismatchedElementsPerMillion,
57 RelativeTolerance,
58)
59from bioimageio.spec._internal.validation_context import get_validation_context
60from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity
61from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
62from bioimageio.spec.model import v0_4, v0_5
63from bioimageio.spec.model.v0_5 import WeightsFormat
64from bioimageio.spec.summary import (
65 ErrorEntry,
66 InstalledPackage,
67 ValidationDetail,
68 ValidationSummary,
69 WarningEntry,
70)
72from . import __version__
73from ._prediction_pipeline import create_prediction_pipeline
74from ._settings import settings
75from .axis import AxisId, BatchSize
76from .common import MemberId, SupportedWeightsFormat
77from .digest_spec import get_test_input_sample, get_test_output_sample
78from .io import save_tensor
79from .sample import Sample
80from .tensor import Tensor
# Executable used to invoke conda; on Windows conda is called via its ".bat" wrapper.
CONDA_CMD = "conda" if platform.system() != "Windows" else "conda.bat"
class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments still accepted by the test functions.

    In this module they are read by `_get_tolerance`, and there only for v0.4
    model descriptions; v0.5 descriptions specify reproducibility tolerances in
    their `config.bioimageio` section instead.
    """

    # absolute tolerance for comparing actual with expected test outputs
    # (`_get_tolerance` falls back to 1e-3)
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for comparing actual with expected test outputs
    # (`_get_tolerance` falls back to 1e-3)
    relative_tolerance: NotRequired[RelativeTolerance]
    # deprecated in favour of the two tolerances above; if given (not None) it
    # takes precedence and `_get_tolerance` uses an absolute tolerance of
    # 1.5 * 10**-decimal and a relative tolerance of 0
    decimal: NotRequired[Optional[int]]
def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- also enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported and
            configured based on weight_formats (`None`: all of them).
            E.g. this allows to avoid importing tensorflow when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Best effort: missing frameworks are skipped and any other error raised
          while configuring a framework is only logged at debug level.
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
            - [Python](https://docs.python.org/3/library/random.html#random.seed)
    """
    import random

    # Python's builtin generator is independent of any ML framework,
    # so it is always seeded (as recommended by the PyTorch recipe).
    random.seed(0)

    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    # keras is used for both keras weight formats ('keras_hdf5' and 'keras_v3')
    if (
        weight_formats is None
        or "keras_hdf5" in weight_formats
        or "keras_v3" in weight_formats
    ):
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Thin wrapper around `test_description` that requires the resource to be
    of type "model"; all other arguments are forwarded unchanged.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
        **deprecated,
    )
    return summary
def default_run_command(args: Sequence[str]):
    """Log and run the command **args** in a subprocess.

    Raises (via `subprocess.check_call`) if the command exits with a
    non-zero return code.
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    _ = subprocess.check_call(args)
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version passed on to `load_description_and_test`.
            Currently only used if **runtime_env** is `"currently-active"`.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object
            or raw YAML content.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.

    Returns:
        The validation summary of the loaded description, including the test details.

    Raises:
        RuntimeError: if a custom **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        # sanity check: failures are detected by exceptions, so a custom
        # runner must raise for a failing command
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    verbose = working_dir is not None
    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            # verify the expected hash of the source (as in the "currently-active" path)
            descr = load_description(source, sha256=sha256, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # the test command needs a file source: package the description
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # elevate status valid-format to passed and start testing
        # (never turn an already failed description into a passed one)
        if descr.validation_summary.status == "valid-format":
            descr.validation_summary.status = "passed"

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

        return descr.validation_summary
def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.
    Adds details to the existing validation summary of **descr**.

    For model descriptions without a given **weight_format** this function calls
    itself once per present weight format (skipping formats bioimageio.core
    cannot test), using a per-format subdirectory of **working_dir**.

    The conda environment is named after the SHA256 hash of its YAML dump
    (with the name removed) and is created via **run_command** if it does not
    exist yet. The actual test runs `bioimageio test` inside that environment;
    relevant details of the resulting summary file are copied over.

    Note: **devices**, **sha256**, **verbose** and **deprecated** are currently
    not forwarded to the `bioimageio test` command.

    Raises:
        ValueError: if **weight_format** is not present in the model description
            or is 'keras_v3' for a v0.4 model description.
        RuntimeError: if testing 'tensorflow_js' is requested explicitly or
            conda is not available.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            # (`getattr` default: not every description version has every
            # weight format field, e.g. 'keras_v3' is not supported by v0.4)
            all_present_wfs = [
                wf
                for wf in get_args(WeightsFormat)
                if getattr(descr.weights, wf, None)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # testing these formats is not supported (see below);
                    # skip them as announced instead of aborting all tests
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            wf = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        # explicit check (not `assert`): **weight_format** is user input
        if wf is None:
            raise ValueError(
                f"Weight format '{weight_format}' is not present in the model description"
            )

        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # work on a copy: **conda_env** may be user-provided and is shared between
    # the recursive calls for multiple weight formats
    conda_env = deepcopy(conda_env)
    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe for an existing environment with this (hash) name
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error = None
    # try both spellings of the summary path option of the `bioimageio` CLI
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)
# Typing overloads of `load_description_and_test`: the return type is narrowed
# by `format_version` ("latest" vs. discovered) and by `expected_type`.


# format_version="latest", expected_type="model" -> latest `ModelDescr`
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest", expected_type="dataset" -> latest `DatasetDescr`
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest", any/no expected type -> any latest resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# discovered/specific format version, expected_type="model" -> any model description version
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# discovered/specific format version, expected_type="dataset" -> any dataset description version
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# general case -> any resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    With **stop_early**, testing further weight formats stops as soon as a
    test adds a failed detail to the validation summary.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        source_details = source.validation_summary.details
        # re-validate from the serialized content if another format version is
        # requested or if io checks were not (known to be) performed
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or not source_details  # no context to inspect
            or (c := source_details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        has_expected_type = _test_expected_resource_type(rd, expected_type)
        if not has_expected_type:
            # unexpected type -> invalid format
            rd.validation_summary.status = "failed"
            return rd

    # elevate status valid-format to passed and start testing
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        for w in weight_formats:
            # (re)seed before each weight format so that its results do not
            # depend on which weight formats were tested before
            enable_determinism(determinism, weight_formats=[w])

            # Judge failure by the details each test adds to the summary rather
            # than by the test functions' return values.
            n_details = len(rd.validation_summary.details)
            _ = _test_recreate_test_outputs(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )
            if stop_early and any(
                d.status == "failed"
                for d in rd.validation_summary.details[n_details:]
            ):
                break

            if not isinstance(rd, v0_4.ModelDescr):
                n_details = len(rd.validation_summary.details)
                _ = _test_parametrized_inference(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and any(
                    d.status == "failed"
                    for d in rd.validation_summary.details[n_details:]
                ):
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output **m** computed with weight format **wf**.

    v0.5 models (deprecated arguments are ignored), in increasing precedence:
        1. `v0_5.ReproducibilityTolerance()` defaults,
        2. legacy `test_kwargs` for **wf** in the extra fields of
           `config.bioimageio` (missing tolerances default to 1e-3),
        3. the first `config.bioimageio.reproducibility_tolerance` entry matching
           **wf** and **m** (empty `weights_formats`/`output_ids` match any).

    v0.4 models:
        - deprecated `decimal` (if not None): atol = 1.5 * 10**-decimal, rtol = 0,
        - otherwise the deprecated `absolute_tolerance`/`relative_tolerance`
          (default 1e-3 each).
        No mismatched elements are tolerated for v0.4 models.

    Returns:
        (relative tolerance, absolute tolerance, mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
def _test_recreate_test_outputs(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    *,
    working_dir: Optional[Union[os.PathLike[str], str]],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> bool:
    """Run **model** on its test inputs and compare against its test outputs.

    A `ValidationDetail` with all collected errors and warnings is added to
    `model.validation_summary`.

    Args:
        model: model description to test.
        weight_format: weight format used for the prediction pipeline.
        devices: devices passed on to `create_prediction_pipeline`.
        stop_early: stop checking further outputs after the first error.
        working_dir: directory to save (intermediate) tensors to (only if **verbose**).
        verbose: save (intermediate) tensors to **working_dir**.
        **deprecated: deprecated tolerance arguments (see `_get_tolerance`).

    Returns:
        `True` if the test passed (no errors recorded), `False` otherwise.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str, severity: WarningSeverity):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                severity=severity,
            )
        )

    def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]:
        # best effort: failures to save are logged, not recorded as test errors
        saved_paths: List[Path] = []
        if working_dir is not None and verbose:
            for p in [
                Path(working_dir) / f"{name}_{weight_format}{suffix}"
                for suffix in (".npy", ".tiff")
            ]:
                try:
                    save_tensor(p, tensor)
                except Exception as e:
                    logger.error(
                        "Failed to save tensor {}: {}",
                        p,
                        e,
                    )
                else:
                    saved_paths.append(p)

        return saved_paths

    try:
        test_input = get_test_input_sample(model)
        expected_sample = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            prediction_pipeline.apply_preprocessing(test_input)
            test_input_preprocessed = deepcopy(test_input)
            results_not_postprocessed = (
                prediction_pipeline.predict_sample_without_blocking(
                    test_input,
                    skip_postprocessing=True,
                    skip_preprocessing=True,
                    skip_input_padding=True,
                    skip_output_cropping=True,
                )
            )
            results = deepcopy(results_not_postprocessed)
            prediction_pipeline.apply_postprocessing(results)

        if len(results.members) != len(expected_sample.members):
            add_error_entry(
                f"Expected {len(expected_sample.members)} outputs, but got {len(results.members)}"
            )

        else:
            intermediate_paths: List[Path] = []
            for m, t in test_input_preprocessed.members.items():
                intermediate_paths.extend(
                    save_to_working_dir(f"test_input_preprocessed_{m}", t)
                )
            if intermediate_paths:
                logger.debug("Saved preprocessed test inputs to {}", intermediate_paths)

            for m, expected in expected_sample.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    output_paths = save_to_working_dir(f"actual_output_{m}", actual)
                    if m in results_not_postprocessed.members:
                        output_paths.extend(
                            save_to_working_dir(
                                f"actual_output_{m}_not_postprocessed",
                                results_not_postprocessed.members[m],
                            )
                        )

                    expected_np = expected.data.to_numpy().astype(np.float32)
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    # an element mismatches if |actual - expected| > atol + rtol * |expected|
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    msg = f"Error while checking if '{m}' disagrees with expected values: {e}"
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    if mismatched_elements:
                        msg = (
                            f"Output '{m}': {mismatched_elements} of "
                            + f"{expected_np.size} elements disagree with expected values."
                            + f" ({mismatched_ppm:.1f} ppm). "
                        )
                    else:
                        msg = f"Output `{m}`: all elements agree with expected values. "

                    msg += (
                        f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))} "
                        + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                    )
                    if output_paths:
                        msg += f"\n Saved (intermediate) outputs to {output_paths}."

                    # up to `mismatched_tol` ppm mismatching elements are only reported as a warning
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(
                            msg, severity=WARNING if mismatched_elements else INFO
                        )

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
    # the caller treats the return value as "passed"
    return not error_entries
def _test_parametrized_inference(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Test inference on resized test inputs for several batch sizes and size parameters.

    For every combination of batch size and parameterized-size ``n`` the model's
    test input sample is resized to the input axis sizes returned by
    ``model.get_axis_sizes``. Each resized sample is run through a prediction
    pipeline (skipping input padding and output cropping), and the resulting
    output sizes are compared to the expected output sizes.

    One ``ValidationDetail`` per distinct input shape is added to
    ``model.validation_summary``. If setting up or running the tests fails
    outside the per-case handling, a single failed detail is added instead.

    Args:
        model: model description (format 0.5) to test.
        weight_format: weights format used to create the prediction pipeline.
        devices: devices passed on to ``create_prediction_pipeline``.
        stop_early: stop after the first failing test case.

    Raises:
        Exception: errors not handled per test case are re-raised when the
            current validation context has ``raise_errors`` set; otherwise they
            are recorded as a failed validation detail.
    """
    # Only vary the size parameter n if any input axis has a parameterized size.
    has_parameterized_input = any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    )
    ns: Set[v0_5.ParameterizedSize_N] = {0, 1, 2} if has_parameterized_input else {0}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        # fixed batch sizes take precedence; if there are only arbitrary
        # (None) batch sizes, test with batch sizes 1 and 2
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None} or {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = set(
        product(batch_sizes, ns)
    )
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def get_ns(n: int):
        # the same size parameter n is used for every parameterized input axis
        return {
            (t.id, a.id): n
            for t in model.inputs
            for a in t.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        }

    def generate_test_cases(test_input):
        """Yield ``(n, batch_size, resized_inputs, expected_output_shapes)`` per distinct input shape."""
        tested: Set[Hashable] = set()
        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            # different (batch_size, n) combinations may yield identical input
            # sizes; test each distinct input shape only once
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue

            tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: test_input.members[t.id].resize_to(
                        {
                            aid: s
                            for (tid, aid), s in input_target_sizes.items()
                            if tid == t.id
                        },
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    def check_output_shapes(n, result, expected_output_shapes) -> Optional[str]:
        """Return an error message if ``result`` does not match the expected output shapes, else None."""
        if len(result.members) != len(expected_output_shapes):
            return (
                f"Expected {len(expected_output_shapes)} outputs,"
                + f" but got {len(result.members)}"
            )

        for m, exp in expected_output_shapes.items():
            res = result.members.get(m)
            if res is None:
                return "Output tensors may not be None for test case"

            diff: Dict[AxisId, int] = {}
            for a, s in res.sizes.items():
                expected_size = exp.get(AxisId(a))
                if expected_size is None:
                    # axis not among the expected output axes: report it as a
                    # mismatch instead of aborting all test cases with a KeyError
                    diff[AxisId(a)] = s
                elif isinstance(expected_size, int):
                    if s != expected_size:
                        diff[AxisId(a)] = s
                elif s < expected_size.min or (
                    expected_size.max is not None and s > expected_size.max
                ):
                    # size outside the expected [min, max] range (max may be open)
                    diff[AxisId(a)] = s

            if diff:
                return (
                    f"(n={n}) Expected output shape {exp},"
                    + f" but got {res.sizes} (diff: {diff})"
                )

        return None

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shapes in generate_test_cases(
                test_input
            ):
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(
                        inputs, skip_input_padding=True, skip_output_cropping=True
                    )
                except Exception as e:
                    error = str(e)
                else:
                    error = check_output_shapes(n, result, expected_output_shapes)

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> bool:
    """Check that a resource description has the expected ``type`` value.

    Appends a "Has expected resource type" ``ValidationDetail`` (passed or
    failed) to ``rd.validation_summary.details``.

    Args:
        rd: (possibly invalid) resource description to check.
        expected_type: expected value of ``rd.type``.

    Returns:
        True if ``rd.type`` equals ``expected_type``, False otherwise.
    """
    # Compare by value: `is` tests object identity, which is not guaranteed
    # for equal strings (e.g. a string parsed at runtime vs. a literal).
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )
    return has_expected_type
1237# TODO: Implement `debug_model()`
1238# def debug_model(
1239# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1240# *,
1241# weight_format: Optional[WeightsFormat] = None,
1242# devices: Optional[List[str]] = None,
1243# ):
1244# """Run the model test and return dict with inputs, results, expected results and intermediates.
1246# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1247# """
1248# inputs_raw: Optional = None
1249# inputs_processed: Optional = None
1250# outputs_raw: Optional = None
1251# outputs: Optional = None
1252# expected: Optional = None
1253# diff: Optional = None
1255# model = load_description(
1256# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1257# )
1258# if not isinstance(model, Model):
1259# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1261# prediction_pipeline = create_prediction_pipeline(
1262# bioimageio_model=model, devices=devices, weight_format=weight_format
1263# )
1264# inputs = [
1265# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1266# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1267# ]
1268# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1270# # keep track of the non-processed inputs
1271# inputs_raw = [deepcopy(input) for input in inputs]
1273# computed_measures = {}
1275# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1276# inputs_processed = list(input_dict.values())
1277# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1278# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1279# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1280# outputs = list(output_dict.values())
1282# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1283# outputs = [outputs]
1285# expected = [
1286# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1287# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1288# ]
1289# if len(outputs) != len(expected):
1290# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1291# print(error)
1292# else:
1293# diff = []
1294# for res, exp in zip(outputs, expected):
1295# diff.append(res - exp)
1297# return {
1298# "inputs": inputs_raw,
1299# "inputs_processed": inputs_processed,
1300# "outputs_raw": outputs_raw,
1301# "outputs": outputs,
1302# "expected": expected,
1303# "diff": diff,
1304# }