Coverage for src / bioimageio / core / _resource_tests.py: 73%
426 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 23:26 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from contextlib import nullcontext
8from copy import deepcopy
9from io import StringIO
10from itertools import product
11from pathlib import Path
12from tempfile import TemporaryDirectory
13from typing import (
14 Any,
15 Callable,
16 Dict,
17 Hashable,
18 List,
19 Literal,
20 Optional,
21 Sequence,
22 Set,
23 Tuple,
24 Union,
25 overload,
26)
28import numpy as np
29from loguru import logger
30from numpy.typing import NDArray
31from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
33from bioimageio.core.tensor import Tensor
34from bioimageio.spec import (
35 AnyDatasetDescr,
36 AnyModelDescr,
37 BioimageioCondaEnv,
38 DatasetDescr,
39 InvalidDescr,
40 LatestResourceDescr,
41 ModelDescr,
42 ResourceDescr,
43 ValidationContext,
44 build_description,
45 dump_description,
46 get_conda_env,
47 load_description,
48 save_bioimageio_package,
49)
50from bioimageio.spec._description_impl import DISCOVER
51from bioimageio.spec._internal.common_nodes import ResourceDescrBase
52from bioimageio.spec._internal.io import is_yaml_value
53from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
54from bioimageio.spec._internal.types import (
55 AbsoluteTolerance,
56 FormatVersionPlaceholder,
57 MismatchedElementsPerMillion,
58 RelativeTolerance,
59)
60from bioimageio.spec._internal.validation_context import get_validation_context
61from bioimageio.spec._internal.warning_levels import INFO, WARNING, WarningSeverity
62from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
63from bioimageio.spec.model import v0_4, v0_5
64from bioimageio.spec.model.v0_5 import WeightsFormat
65from bioimageio.spec.summary import (
66 ErrorEntry,
67 InstalledPackage,
68 ValidationDetail,
69 ValidationSummary,
70 WarningEntry,
71)
73from . import __version__
74from ._prediction_pipeline import create_prediction_pipeline
75from ._settings import settings
76from .axis import AxisId, BatchSize
77from .common import MemberId, SupportedWeightsFormat
78from .digest_spec import get_test_input_sample, get_test_output_sample
79from .io import save_tensor
80from .sample import Sample
# Name of the conda executable to invoke in subprocesses.
# On Windows the `.bat` wrapper is used; elsewhere plain `conda`.
CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda"
class DeprecatedKwargs(TypedDict):
    """Deprecated tolerance kwargs accepted by the test functions.

    Kept for backwards compatibility; see `_get_tolerance` for how these
    interact with tolerances specified in a v0.5 model description.
    """

    absolute_tolerance: NotRequired[AbsoluteTolerance]
    relative_tolerance: NotRequired[RelativeTolerance]
    decimal: NotRequired[Optional[int]]
def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' determinsm features (might degrade performance or throw exceptions)
        weight_formats: Limit deep learning importing deep learning frameworks
            based on weight_formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """

    def _requested(*formats: str) -> bool:
        # no `weight_formats` restriction means: seed every framework we can import
        return weight_formats is None or any(f in weight_formats for f in formats)

    def _seed_numpy():
        try:
            import numpy.random
        except ImportError:
            return
        numpy.random.seed(0)

    def _seed_torch():
        try:
            import torch
        except ImportError:
            return
        _ = torch.manual_seed(0)
        torch.use_deterministic_algorithms(mode == "full")

    def _seed_tensorflow():
        # note: the env var has to be set before tensorflow is imported
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
        try:
            import tensorflow as tf
        except ImportError:
            return
        tf.random.set_seed(0)
        if mode == "full":
            tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??

    def _seed_keras():
        try:
            import keras  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            return
        keras.utils.set_random_seed(0)

    # (applicability, seeding action) pairs; any failure is logged, never raised
    seed_actions = (
        (True, _seed_numpy),
        (_requested("pytorch_state_dict", "torchscript"), _seed_torch),
        (_requested("tensorflow_saved_model_bundle", "keras_hdf5"), _seed_tensorflow),
        (_requested("keras_hdf5"), _seed_keras),
    )
    for applicable, seed_action in seed_actions:
        if not applicable:
            continue

        try:
            seed_action()
        except Exception as e:
            logger.debug(str(e))
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference

    Convenience wrapper around `test_description` that additionally asserts
    that **source** describes a model (via `expected_type="model"`).
    """
    summary = test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        working_dir=working_dir,
        **deprecated,
    )
    return summary
def default_run_command(args: Sequence[str]):
    """Execute **args** as a subprocess; raises if the command exits non-zero."""
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    _ = subprocess.check_call(args)
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate **source** against.
            NOTE(review): only forwarded when **runtime_env** is
            `"currently-active"`; the subprocess code path ignores it — confirm
            whether that is intended.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
            Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.
        working_dir: (for debugging) directory to save any temporary files
            (model packages, conda environments, test summaries).
            Defaults to a temporary directory.
    """
    if runtime_env == "currently-active":
        # run everything in this process; no conda/subprocess involved
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            working_dir=working_dir,
            **deprecated,
        )
        return rd.validation_summary

    # resolve the requested runtime env to an optional BioimageioCondaEnv;
    # None means "derive from the weights entry" (done in `_test_in_env`)
    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    if run_command is not default_run_command:
        # sanity check: a custom run_command must raise on a failing command,
        # otherwise later failure detection would silently pass
        try:
            run_command(["thiscommandshouldalwaysfail", "please"])
        except Exception:
            pass
        else:
            raise RuntimeError(
                "given run_command does not raise an exception for a failing command"
            )

    # an explicitly given working_dir also enables verbose (debug) tensor dumps
    verbose = working_dir is not None
    if working_dir is None:
        td_kwargs: Dict[str, Any] = (
            dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
        )
        working_dir_ctxt = TemporaryDirectory(**td_kwargs)
    else:
        working_dir_ctxt = nullcontext(working_dir)

    with working_dir_ctxt as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            descr = load_description(source, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # package in-memory sources to disk so the subprocess can load them
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        # elevate status valid-format to passed and start testing
        descr.validation_summary.status = "passed"
        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            verbose=verbose,
            **deprecated,
        )

    return descr.validation_summary
def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given (possibly newly created) conda environment.

    Adds details to the existing validation summary of **descr**.

    For model descriptions without a specific **weight_format** this function
    recurses once per weight format present in the description, skipping
    unsupported formats (currently 'tensorflow_js').
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # fix: actually skip ignored formats as announced in the log
                    # message above; recursing with e.g. 'tensorflow_js' would
                    # raise a RuntimeError below and abort testing altogether.
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    verbose=verbose,
                    **deprecated,
                )

            return

        # select the weights entry matching the requested weight format
        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "keras_v3":
            if isinstance(descr, v0_4.ModelDescr):
                raise ValueError(
                    "Weight format 'keras_v3' is not supported in v0.4 model descriptions. use format version >= 0.5"
                )

            wf = descr.weights.keras_v3
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            # derive a default conda env from the weights entry
            conda_env = get_conda_env(entry=wf)

        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value
    conda_env.name = None
    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    # hash the serialized env so identical envs are created only once
    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # probe whether the env already exists
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        # env does not exist (or is broken) -> (re)create it from the YAML dump
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return
        else:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name=f"Created conda environment '{env_name}'",
                    status="passed",
                    loc=test_loc,
                )
            )
    else:
        descr.validation_summary.add_detail(
            ValidationDetail(
                name=f"Found existing conda environment '{env_name}'",
                status="passed",
                loc=test_loc,
            )
        )

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error = None
    # try both CLI argument spellings for compatibility with different
    # bioimageio CLI versions installed in the env
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--weight-format={weight_format}"] if weight_format else [])
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        # neither spelling produced a summary file -> record failure
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=(),
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc or detail.status == "failed":
            descr.validation_summary.add_detail(detail)
# The following overloads narrow the return type of `load_description_and_test`
# based on the combination of `format_version` and `expected_type` arguments.


# format_version="latest" + expected_type="model" -> latest ModelDescr
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# format_version="latest" + expected_type="dataset" -> latest DatasetDescr
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# format_version="latest", unspecified type -> any latest resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# any format version + expected_type="model" -> any model description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# any format version + expected_type="dataset" -> any dataset description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# fallback: any format version, unspecified type -> any resource description
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    working_dir: Optional[Union[os.PathLike[str], str]] = None,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        # re-deserialize an already loaded description if it was validated
        # against a different format version or without io checks
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    # record the bioimageio.core version used for testing in the summary
    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        has_expected_type = _test_expected_resource_type(rd, expected_type)
        if not has_expected_type:
            # unexpected type -> invalid format
            rd.validation_summary.status = "failed"
            return rd

    # elevate status valid-format to passed and start testing
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # test all weight formats present in the description
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            passed_recreate_test_outputs = _test_recreate_test_outputs(
                rd,
                w,
                devices,
                stop_early=stop_early,
                working_dir=working_dir,
                verbose=working_dir is not None,
                **deprecated,
            )

            if stop_early and not passed_recreate_test_outputs:
                break

            if not isinstance(rd, v0_4.ModelDescr):
                # parametrized input sizes are only described by format >= 0.5
                passed_parametrized_inference = _test_parametrized_inference(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and not passed_parametrized_inference:
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Resolve comparison tolerances for output tensor **m** tested with weights format **wf**.

    For v0.5 models, tolerances come from the model config (legacy
    `test_kwargs` first, then `reproducibility_tolerance` entries, where the
    first matching entry wins). Otherwise deprecated kwargs are used.

    Returns:
        Tuple of (relative tolerance, absolute tolerance,
        tolerated mismatched elements per million).
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-3),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        # (overrides any legacy test_kwargs match above)
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        # legacy semantics: absolute-only tolerance derived from decimal places
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-3)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
def _test_recreate_test_outputs(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    *,
    working_dir: Optional[Union[os.PathLike[str], str]],
    verbose: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> bool:
    """Run prediction on the model's test inputs and compare against its test outputs.

    Adds a `ValidationDetail` to **model**'s validation summary.

    Returns:
        Whether or not the test passed (i.e. no error entries were recorded).
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        # record an error located at this weights entry
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str, severity: WarningSeverity):
        # record a warning located at this weights entry
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                severity=severity,
            )
        )

    def save_to_working_dir(name: str, tensor: Tensor) -> List[Path]:
        # save a tensor (as .npy and .tiff) for debugging; only active when an
        # explicit working_dir was given (verbose mode)
        saved_paths: List[Path] = []
        if working_dir is not None and verbose:
            for p in [
                Path(working_dir) / f"{name}_{weight_format}{suffix}"
                for suffix in (".npy", ".tiff")
            ]:
                try:
                    save_tensor(p, tensor)
                except Exception as e:
                    logger.error(
                        "Failed to save tensor {}: {}",
                        p,
                        e,
                    )
                else:
                    saved_paths.append(p)

        return saved_paths

    try:
        test_input = get_test_input_sample(model)
        expected = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            # run pre- and postprocessing separately to keep intermediate
            # results around for debugging output
            prediction_pipeline.apply_preprocessing(test_input)
            test_input_preprocessed = deepcopy(test_input)
            results_not_postprocessed = (
                prediction_pipeline.predict_sample_without_blocking(
                    test_input, skip_postprocessing=True, skip_preprocessing=True
                )
            )
            results = deepcopy(results_not_postprocessed)
            prediction_pipeline.apply_postprocessing(results)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )

        else:
            intermediate_paths: List[Path] = []
            for m, t in test_input_preprocessed.members.items():
                intermediate_paths.extend(
                    save_to_working_dir(f"test_input_preprocessed_{m}", t)
                )
            if intermediate_paths:
                logger.debug("Saved preprocessed test inputs to {}", intermediate_paths)

            # note: renamed loop variable (was `expected`, shadowing the Sample)
            for m, expected_member in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected_member.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected_member.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected_member.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected_member.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    output_paths = save_to_working_dir(f"actual_output_{m}", actual)
                    if m in results_not_postprocessed.members:
                        output_paths.extend(
                            save_to_working_dir(
                                f"actual_output_{m}_not_postprocessed",
                                results_not_postprocessed.members[m],
                            )
                        )

                    expected_np = expected_member.data.to_numpy().astype(np.float32)
                    del expected_member  # free memory early
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    # mismatch criterion follows numpy.testing.assert_allclose:
                    # |actual - expected| > atol + rtol * |expected|
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    mismatched = abs_diff > atol + rtol_value
                    mismatched_elements = mismatched.sum().item()

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    # worst relative disagreement (for the report message)
                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    msg = f"Error while checking if '{m}' disagrees with expected values: {e}"
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    if mismatched_elements:
                        msg = (
                            f"Output '{m}': {mismatched_elements} of "
                            + f"{expected_np.size} elements disagree with expected values."
                            + f" ({mismatched_ppm:.1f} ppm). "
                        )
                    else:
                        msg = f"Output `{m}`: all elements agree with expected values. "

                    msg += (
                        f"\nMax relative difference not accounted for by absolute tolerance ({atol:.2e}):\n{r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))} "
                        + f"\nMax absolute difference not accounted for by relative tolerance ({rtol:.2e}):\n{a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                    )
                    if output_paths:
                        msg += f"\n Saved (intermediate) outputs to {output_paths}."

                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(
                            msg, severity=WARNING if mismatched_elements else INFO
                        )

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
    # fix: previously returned `bool(error_entries)` (True when errors exist),
    # which inverted the "passed" semantics expected by the caller
    # (`if stop_early and not passed_recreate_test_outputs: break`).
    return not error_entries
def _test_parametrized_inference(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference with **weight_format** for multiple input sizes.

    For each tested combination of batch size ``b`` and size parameter ``n``
    (used to resolve `v0_5.ParameterizedSize` axes), the model's test input
    sample is resized accordingly, a prediction is run, and the resulting
    output shapes are checked against the sizes resolved from the model
    description. One `ValidationDetail` is added to
    ``model.validation_summary`` per tested combination; if pipeline setup or
    test-case generation itself raises, a single "failed" detail is recorded
    instead (or the exception is re-raised when the validation context has
    ``raise_errors`` set).

    Args:
        model: model description (spec v0.5) to test.
        weight_format: weights entry used to create the prediction pipeline.
        devices: optional devices to run inference on
            (passed through to `create_prediction_pipeline`).
        stop_early: if True, stop after the first failing test case.
    """
    # Choose values for the size parameter n: if no input axis is
    # parameterized, n has no effect and only n=0 is tested.
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    # Collect the (possibly fixed) sizes of all batch axes; a size of None
    # means the batch size is arbitrary.
    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    # Cartesian product of batch sizes and n values to test.
    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        # Yields (n, batch_size, resized_inputs, expected_output_shapes).
        # NOTE: closes over `test_input`, which is assigned in the `try:`
        # block below before this generator is first iterated.
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            # Map every parameterized (tensor_id, axis_id) to the same n.
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            # Deduplicate: different (b, n) combinations may resolve to the
            # same concrete input sizes; test each size combination only once.
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            # Resize each input tensor to the resolved target sizes for its
            # own axes (filtering the global mapping by tensor id).
            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            # Per-output mapping of axis id -> expected size (int or a size
            # descriptor with min/max bounds).
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, exptected_output_shape in generate_test_cases():
                # `error` collects the first problem found for this test case;
                # None means the case passed.
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(exptected_output_shape):
                        error = (
                            f"Expected {len(exptected_output_shape)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in exptected_output_shape.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            # Collect axes whose actual size disagrees with
                            # the expectation (exact int, or min/max bounds).
                            diff: Dict[AxisId, int] = {}
                            for a, s in res.sizes.items():
                                if isinstance((e_aid := exp[AxisId(a)]), int):
                                    if s != e_aid:
                                        diff[AxisId(a)] = s
                                elif (
                                    s < e_aid.min
                                    or e_aid.max is not None
                                    and s > e_aid.max
                                ):
                                    diff[AxisId(a)] = s
                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                # Record the result of this (batch_size, n) combination.
                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        # Setup/generation failure: either re-raise (strict context) or
        # record a single failed detail covering the whole parametrized run.
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> bool:
    """Check that resource description **rd** has type **expected_type**.

    Appends a corresponding "passed"/"failed" `ValidationDetail` to
    ``rd.validation_summary.details``.

    Args:
        rd: resource description (possibly invalid) to check.
        expected_type: the resource type the caller expects, e.g. "model".

    Returns:
        Whether ``rd.type`` equals `expected_type`.
    """
    # Compare strings with `==`, not `is`: identity comparison only "works"
    # by accident of CPython string interning and is not a reliable check.
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )
    return has_expected_type
1231# TODO: Implement `debug_model()`
1232# def debug_model(
1233# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1234# *,
1235# weight_format: Optional[WeightsFormat] = None,
1236# devices: Optional[List[str]] = None,
1237# ):
1238# """Run the model test and return dict with inputs, results, expected results and intermediates.
1240# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1241# """
1242# inputs_raw: Optional = None
1243# inputs_processed: Optional = None
1244# outputs_raw: Optional = None
1245# outputs: Optional = None
1246# expected: Optional = None
1247# diff: Optional = None
1249# model = load_description(
1250# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1251# )
1252# if not isinstance(model, Model):
1253# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1255# prediction_pipeline = create_prediction_pipeline(
1256# bioimageio_model=model, devices=devices, weight_format=weight_format
1257# )
1258# inputs = [
1259# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1260# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1261# ]
1262# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1264# # keep track of the non-processed inputs
1265# inputs_raw = [deepcopy(input) for input in inputs]
1267# computed_measures = {}
1269# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1270# inputs_processed = list(input_dict.values())
1271# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1272# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1273# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1274# outputs = list(output_dict.values())
1276# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1277# outputs = [outputs]
1279# expected = [
1280# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1281# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1282# ]
1283# if len(outputs) != len(expected):
1284# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1285# print(error)
1286# else:
1287# diff = []
1288# for res, exp in zip(outputs, expected):
1289# diff.append(res - exp)
1291# return {
1292# "inputs": inputs_raw,
1293# "inputs_processed": inputs_processed,
1294# "outputs_raw": outputs_raw,
1295# "outputs": outputs,
1296# "expected": expected,
1297# "diff": diff,
1298# }