Coverage for src / bioimageio / core / _resource_tests.py: 73%
418 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from contextlib import nullcontext
8from copy import deepcopy
9from io import StringIO
10from itertools import product
11from pathlib import Path
12from tempfile import TemporaryDirectory
13from typing import (
14 Any,
15 Callable,
16 Dict,
17 Hashable,
18 List,
19 Literal,
20 Optional,
21 Sequence,
22 Set,
23 Tuple,
24 Union,
25 overload,
26)
28import numpy as np
29from loguru import logger
30from numpy.typing import NDArray
31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
33from bioimageio.core.tensor import Tensor
34from bioimageio.spec import (
35 AnyDatasetDescr,
36 AnyModelDescr,
37 BioimageioCondaEnv,
38 DatasetDescr,
39 InvalidDescr,
40 LatestResourceDescr,
41 ModelDescr,
42 ResourceDescr,
43 ValidationContext,
44 build_description,
45 dump_description,
46 get_conda_env,
47 load_description,
48 save_bioimageio_package,
49)
50from bioimageio.spec._description_impl import DISCOVER
51from bioimageio.spec._internal.common_nodes import ResourceDescrBase
52from bioimageio.spec._internal.io import is_yaml_value
53from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
54from bioimageio.spec._internal.types import (
55 AbsoluteTolerance,
56 FormatVersionPlaceholder,
57 MismatchedElementsPerMillion,
58 RelativeTolerance,
59)
60from bioimageio.spec._internal.validation_context import get_validation_context
61from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity
62from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
63from bioimageio.spec.model import v0_4, v0_5
64from bioimageio.spec.model.v0_5 import WeightsFormat
65from bioimageio.spec.summary import (
66 ErrorEntry,
67 InstalledPackage,
68 ValidationDetail,
69 ValidationSummary,
70 WarningEntry,
71)
73from . import __version__
74from ._prediction_pipeline import create_prediction_pipeline
75from ._settings import settings
76from .axis import AxisId, BatchSize
77from .common import MemberId, SupportedWeightsFormat
78from .digest_spec import get_test_input_sample, get_test_output_sample
79from .io import save_tensor
80from .sample import Sample
82CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda"
class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments accepted by the test functions for backwards
    compatibility (consumed by `_get_tolerance`)."""

    # absolute tolerance used when comparing actual vs expected test outputs
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance used when comparing actual vs expected test outputs
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy decimal precision; if given, output comparison reverts to the old
    # `decimal`-based behaviour (see `_get_tolerance`)
    decimal: NotRequired[Optional[int]]
def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit importing of deep learning frameworks
            based on **weight_formats**.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    # Each framework is handled best-effort: a missing import is silently skipped
    # and any other failure is only logged, so this never raises.
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    # torch backs both 'pytorch_state_dict' and 'torchscript' weights
    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    # tensorflow backs 'tensorflow_saved_model_bundle' and (older) 'keras_hdf5'
    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            # disable oneDNN optimizations, which may introduce non-determinism;
            # must be set before tensorflow is imported
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                    # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Thin wrapper around `test_description` that pins `expected_type="model"`.
    """
    forwarded: Dict[str, Any] = dict(
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
    )
    forwarded.update(deprecated)
    return test_description(source, **forwarded)
def default_run_command(args: Sequence[str]):
    """Log and execute *args* as a subprocess command.

    Raises `subprocess.CalledProcessError` if the command exits non-zero.
    """
    pretty = " ".join(args)
    logger.info("running '{}'...", pretty)
    _ = subprocess.check_call(args)
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
            Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.
    """
    # Fast path: test in this interpreter without creating any conda environment.
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    # Normalize **runtime_env** to an Optional[BioimageioCondaEnv];
    # None means "derive env from the weights description" (handled in _test_in_env).
    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # Sanity-check a custom run_command: it must raise for failing commands,
    # otherwise env creation/test failures below would go unnoticed.
    if run_command is not default_run_command:
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    # With an explicit working_dir we keep intermediate files (verbose mode);
    # otherwise everything goes into a throwaway temporary directory.
    verbose = working_dir is not None
    if working_dir is None:
        # `ignore_cleanup_errors` only exists on Python >= 3.10
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        # Load/validate the description (with io checks) from whatever form it came in.
        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            descr = load_description(source, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            # static validation already failed; no point running dynamic tests
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # in-memory sources need to be packaged so the subprocess can read them
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # run the actual tests in (a) dedicated conda environment(s);
        # adds details to descr.validation_summary
        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

        return descr.validation_summary
def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.

    Adds details to the existing validation summary of **descr**.

    For models without a specific **weight_format** this recurses once per
    present weight format (each into its own sub-directory of **working_dir**).
    The test itself is executed by invoking the `bioimageio test` CLI inside a
    conda environment named after the SHA256 of its serialized description,
    then merging the summary the CLI wrote to disk back into **descr**.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        # resolve the weights entry for the requested format
        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            wf = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            # derive a default env from the weights entry
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    # hash the serialized env so identical envs are reused across runs
    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    # probe for an existing env of that name; create it if the probe fails
    try:
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error: Optional[str] = None
    # try both CLI spellings of the summary-path argument (older/newer CLI versions)
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        # success criterion is the summary file, not the exit code
        if summary_path.exists():
            break
    else:
        # for-else: no attempt produced a summary file
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        # keep details under the tested location plus any failure anywhere
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)
# Overloads narrow the return type by (format_version, expected_type) combination.
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    # An already-loaded description is re-serialized if it was loaded with a
    # different format version or without io checks, so tests run under the
    # requested conditions.
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    # also accept a "major.minor" prefix of the full version
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    # record the bioimageio.core version used for testing in the summary
    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # test every weight format present in the description
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )
            if stop_early and rd.validation_summary.status != "passed":
                break

            # parametrized (resized input) tests require a v0.5 description
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status != "passed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Resolve the tolerances for comparing output tensor *m* (produced with
    weights format *wf*) against the expected test output of *model*.

    Returns:
        (relative tolerance, absolute tolerance, tolerated mismatched elements per million)
    """
    if isinstance(model, v0_5.ModelDescr):
        chosen = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        extra = model.config.bioimageio.model_extra
        if extra is not None:
            for legacy_wf, legacy_kwargs in extra.get("test_kwargs", {}).items():
                if wf != legacy_wf:
                    continue
                chosen = v0_5.ReproducibilityTolerance(
                    relative_tolerance=legacy_kwargs.get("relative_tolerance", 1e-3),
                    absolute_tolerance=legacy_kwargs.get("absolute_tolerance", 1e-3),
                )
                break

        # a matching entry per weights format and/or output tensor takes
        # precedence over the legacy kwargs above
        for entry in model.config.bioimageio.reproducibility_tolerance:
            wf_matches = not entry.weights_formats or wf in entry.weights_formats
            out_matches = not entry.output_ids or m in entry.output_ids
            if wf_matches and out_matches:
                chosen = entry
                break

        return (
            chosen.relative_tolerance,
            chosen.absolute_tolerance,
            chosen.mismatched_elements_per_million,
        )

    decimal = deprecated.get("decimal")
    if decimal is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        # legacy decimal semantics: pure absolute tolerance, no relative part
        return 0, 1.5 * 10 ** (-decimal), 0

    # use given (deprecated) test kwargs
    return (
        deprecated.get("relative_tolerance", 1e-3),
        deprecated.get("absolute_tolerance", 1e-3),
        0,
    )
def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    *,
    working_dir: Optional[Union[os.PathLike[str], str]],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Run the model on its test inputs and compare outputs against the
    expected test outputs within the tolerances from `_get_tolerance`.

    Appends one ValidationDetail (passed/failed with errors/warnings) to
    ``model.validation_summary``. With **working_dir** and **verbose** set,
    intermediate tensors are saved as .npy/.tiff for debugging.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        # record an error at the tested weights entry
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str, severity: WarningSeverity):
        # record a warning at the tested weights entry
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                severity=severity,
            )
        )

    def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]:
        # best-effort save of *tensor* as .npy and .tiff (verbose mode only);
        # returns the paths actually written
        saved_paths: List[Path] = []
        if working_dir is not None and verbose:
            for p in [
                Path(working_dir) / f"{name}_{weight_format}{suffix}"
                for suffix in (".npy", ".tiff")
            ]:
                try:
                    save_tensor(p, tensor)
                except Exception as e:
                    logger.error(
                        "Failed to save tensor {}: {}",
                        p,
                        e,
                    )
                else:
                    saved_paths.append(p)

        return saved_paths

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            # run pre-/inference/post- steps separately so intermediates can be saved
            prediction_pipeline.apply_preprocessing(test_input)
            test_input_preprocessed = deepcopy(test_input)
            results_not_postprocessed = (
                prediction_pipeline.predict_sample_without_blocking(
                    test_input, skip_postprocessing=True, skip_preprocessing=True
                )
            )
            results = deepcopy(results_not_postprocessed)
            prediction_pipeline.apply_postprocessing(results)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            intermediate_paths: List[Path] = []
            for m, t in test_input_preprocessed.members.items():
                intermediate_paths.extend(
                    save_to_working_dir(f"test_input_preprocessed_{m}", t)
                )
            if intermediate_paths:
                logger.debug("Saved preprocessed test inputs to {}", intermediate_paths)

            # NOTE(review): the loop variable deliberately rebinds `expected`
            # (sample -> per-member tensor); the iterator over
            # `expected.members.items()` was created before the first rebind,
            # so iteration is unaffected, but `expected` no longer names the
            # sample inside the loop body.
            for m, expected in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    output_paths = save_to_working_dir(f"actual_output_{m}", actual)
                    if m in results_not_postprocessed.members:
                        output_paths.extend(
                            save_to_working_dir(
                                f"actual_output_{m}_not_postprocessed",
                                results_not_postprocessed.members[m],
                            )
                        )

                    # compare as float32 numpy arrays
                    expected_np = expected.data.to_numpy().astype(np.float32)
                    del expected
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    # element-wise check analogous to numpy.allclose:
                    # |actual - expected| > atol + rtol * |expected|
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    # locate the worst relative difference (for the report)
                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    msg = f"Error while checking if '{m}' disagrees with expected values: {e}"
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    if mismatched_elements:
                        msg = (
                            f"Output '{m}': {mismatched_elements} of "
                            + f"{expected_np.size} elements disagree with expected values."
                            + f" ({mismatched_ppm:.1f} ppm)."
                        )
                    else:
                        msg = f"Output `{m}`: all elements agree with expected values."

                    msg += (
                        f"\n Max relative difference not accounted for by absolute tolerance ({atol:.2e}): {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance ({rtol:.2e}): {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                    )
                    if output_paths:
                        msg += f"\n Saved (intermediate) outputs to {output_paths}."

                    # only more than `mismatched_tol` ppm mismatches fail the test
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(
                            msg, severity=WARNING if mismatched_elements else INFO
                        )

    except Exception as e:
        # in raise_errors mode (debugging) propagate instead of recording
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference for several input sizes derived from the model's axis
    descriptions and add one `ValidationDetail` per tested
    (batch size, size parameter n) combination to `model.validation_summary`.

    Args:
        model: model description whose test inputs are resized and fed through
            the prediction pipeline.
        weight_format: the weights entry to test.
        devices: devices to create the prediction pipeline on
            (`None` lets the pipeline choose).
        stop_early: stop after the first failing test case.
    """
    # Determine which values of the size parameter n to test: if no input axis
    # is parameterized, n has no effect and a single n=0 case suffices.
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    # Collect explicitly given batch sizes (a BatchAxis' size may be None,
    # meaning an arbitrary batch size is allowed).
    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes => test two representative ones
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        """Yield (n, batch_size, resized inputs, expected output sizes) for
        every test case with a distinct resolved input size.

        NOTE: closes over `test_input`, which is assigned in the enclosing
        try block before this generator is first consumed.
        """
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            # map every parameterized (tensor id, axis id) to the same n
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            # skip (batch_size, n) combinations that resolve to an input size
            # we have already tested
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shape in generate_test_cases():
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(expected_output_shape):
                        error = (
                            f"Expected {len(expected_output_shape)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in expected_output_shape.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            # collect axes whose actual size disagrees with the
                            # expected size (an int for fixed sizes, otherwise
                            # an object with a min/max range)
                            diff: Dict[AxisId, int] = {}
                            for a, s in res.sizes.items():
                                if isinstance((e_aid := exp[AxisId(a)]), int):
                                    if s != e_aid:
                                        diff[AxisId(a)] = s
                                elif s < e_aid.min or (
                                    e_aid.max is not None and s > e_aid.max
                                ):
                                    diff[AxisId(a)] = s
                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Append a validation detail to *rd*'s summary recording whether the
    resource's `type` field matches *expected_type*."""
    matches = rd.type == expected_type
    if matches:
        errors: List[ErrorEntry] = []
    else:
        errors = [
            ErrorEntry(
                loc=("type",),
                type="type",
                msg=f"Expected type {expected_type}, found {rd.type}",
            )
        ]

    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if matches else "failed",
            loc=("type",),
            errors=errors,
        )
    )
1218# TODO: Implement `debug_model()`
1219# def debug_model(
1220# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1221# *,
1222# weight_format: Optional[WeightsFormat] = None,
1223# devices: Optional[List[str]] = None,
1224# ):
1225# """Run the model test and return dict with inputs, results, expected results and intermediates.
1227# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1228# """
1229# inputs_raw: Optional = None
1230# inputs_processed: Optional = None
1231# outputs_raw: Optional = None
1232# outputs: Optional = None
1233# expected: Optional = None
1234# diff: Optional = None
1236# model = load_description(
1237# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1238# )
1239# if not isinstance(model, Model):
1240# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1242# prediction_pipeline = create_prediction_pipeline(
1243# bioimageio_model=model, devices=devices, weight_format=weight_format
1244# )
1245# inputs = [
1246# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1247# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1248# ]
1249# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1251# # keep track of the non-processed inputs
1252# inputs_raw = [deepcopy(input) for input in inputs]
1254# computed_measures = {}
1256# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1257# inputs_processed = list(input_dict.values())
1258# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1259# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1260# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1261# outputs = list(output_dict.values())
1263# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1264# outputs = [outputs]
1266# expected = [
1267# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1268# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1269# ]
1270# if len(outputs) != len(expected):
1271# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1272# print(error)
1273# else:
1274# diff = []
1275# for res, exp in zip(outputs, expected):
1276# diff.append(res - exp)
1278# return {
1279# "inputs": inputs_raw,
1280# "inputs_processed": inputs_processed,
1281# "outputs_raw": outputs_raw,
1282# "outputs": outputs,
1283# "expected": expected,
1284# "diff": diff,
1285# }