Coverage for src/bioimageio/core/_resource_tests.py: 52%
370 statements
coverage.py v7.10.7, created at 2025-10-14 08:35 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from io import StringIO
8from itertools import product
9from pathlib import Path
10from tempfile import TemporaryDirectory
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Hashable,
16 List,
17 Literal,
18 Optional,
19 Sequence,
20 Set,
21 Tuple,
22 Union,
23 overload,
24)
26import numpy as np
27from bioimageio.spec import (
28 AnyDatasetDescr,
29 AnyModelDescr,
30 BioimageioCondaEnv,
31 DatasetDescr,
32 InvalidDescr,
33 LatestResourceDescr,
34 ModelDescr,
35 ResourceDescr,
36 ValidationContext,
37 build_description,
38 dump_description,
39 get_conda_env,
40 load_description,
41 save_bioimageio_package,
42)
43from bioimageio.spec._description_impl import DISCOVER
44from bioimageio.spec._internal.common_nodes import ResourceDescrBase
45from bioimageio.spec._internal.io import is_yaml_value
46from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
47from bioimageio.spec._internal.types import (
48 AbsoluteTolerance,
49 FormatVersionPlaceholder,
50 MismatchedElementsPerMillion,
51 RelativeTolerance,
52)
53from bioimageio.spec._internal.validation_context import get_validation_context
54from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
55from bioimageio.spec.model import v0_4, v0_5
56from bioimageio.spec.model.v0_5 import WeightsFormat
57from bioimageio.spec.summary import (
58 ErrorEntry,
59 InstalledPackage,
60 ValidationDetail,
61 ValidationSummary,
62 WarningEntry,
63)
64from loguru import logger
65from numpy.typing import NDArray
66from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
68from bioimageio.core import __version__
69from bioimageio.core.io import save_tensor
71from ._prediction_pipeline import create_prediction_pipeline
72from ._settings import settings
73from .axis import AxisId, BatchSize
74from .common import MemberId, SupportedWeightsFormat
75from .digest_spec import get_test_input_sample, get_test_output_sample
76from .sample import Sample
78CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda"
81class DeprecatedKwargs(TypedDict):
82 absolute_tolerance: NotRequired[AbsoluteTolerance]
83 relative_tolerance: NotRequired[RelativeTolerance]
84 decimal: NotRequired[Optional[int]]
87def enable_determinism(
88 mode: Literal["seed_only", "full"] = "full",
89 weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
90):
91 """Seed and configure ML frameworks for maximum reproducibility.
92 May degrade performance. Only recommended for testing reproducibility!
94 Seed any random generators and (if **mode**=="full") request ML frameworks to use
95 deterministic algorithms.
97 Args:
98 mode: determinism mode
99 - 'seed_only': only set seeds, or
100 - 'full': additionally enable framework determinism features (might degrade performance or throw exceptions).
101 weight_formats: Limit which deep learning frameworks are imported and configured
102 based on the given weight formats.
103 E.g. this avoids importing tensorflow when testing only pytorch weights.
105 Notes:
106 - **mode** == "full" might degrade performance or throw exceptions.
107 - Subsequent inference calls might still differ. Call before each function
108 (sequence) that is expected to be reproducible.
109 - Degraded performance: Use for testing reproducibility only!
110 - Recipes:
111 - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
112 - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
113 - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
114 """
115 try:
116 try:
117 import numpy.random
118 except ImportError:
119 pass
120 else:
121 numpy.random.seed(0)
122 except Exception as e:
123 logger.debug(str(e))
125 if (
126 weight_formats is None
127 or "pytorch_state_dict" in weight_formats
128 or "torchscript" in weight_formats
129 ):
130 try:
131 try:
132 import torch
133 except ImportError:
134 pass
135 else:
136 _ = torch.manual_seed(0)
137 torch.use_deterministic_algorithms(mode == "full")
138 except Exception as e:
139 logger.debug(str(e))
141 if (
142 weight_formats is None
143 or "tensorflow_saved_model_bundle" in weight_formats
144 or "keras_hdf5" in weight_formats
145 ):
146 try:
147 os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
148 try:
149 import tensorflow as tf # pyright: ignore[reportMissingTypeStubs]
150 except ImportError:
151 pass
152 else:
153 tf.random.set_seed(0)
154 if mode == "full":
155 tf.config.experimental.enable_op_determinism()
156 # TODO: find a way to switch op determinism off again
157 except Exception as e:
158 logger.debug(str(e))
160 if weight_formats is None or "keras_hdf5" in weight_formats:
161 try:
162 try:
163 import keras # pyright: ignore[reportMissingTypeStubs]
164 except ImportError:
165 pass
166 else:
167 keras.utils.set_random_seed(0)
168 except Exception as e:
169 logger.debug(str(e))
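# Example (illustrative sketch, not part of the module API): call right before a
# reproducibility-sensitive test run; the weight format list is an assumption
# chosen for illustration.
#
# enable_determinism(mode="seed_only", weight_formats=["pytorch_state_dict"])
# # ... run the inference/tests that are expected to be reproducible ...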
172def test_model(
173 source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
174 weight_format: Optional[SupportedWeightsFormat] = None,
175 devices: Optional[List[str]] = None,
176 *,
177 determinism: Literal["seed_only", "full"] = "seed_only",
178 sha256: Optional[Sha256] = None,
179 stop_early: bool = True,
180 **deprecated: Unpack[DeprecatedKwargs],
181) -> ValidationSummary:
182 """Test model inference"""
183 return test_description(
184 source,
185 weight_format=weight_format,
186 devices=devices,
187 determinism=determinism,
188 expected_type="model",
189 sha256=sha256,
190 stop_early=stop_early,
191 **deprecated,
192 )
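# Example (illustrative sketch): testing a single weight format of a packaged
# model. The file path below is a placeholder, not a real resource.
#
# summary = test_model(
#     "my_model_package.zip",  # hypothetical path or URL to a model package
#     weight_format="onnx",
#     devices=["cpu"],
# )
# print(summary.status)
# for detail in summary.details:
#     print(detail.name, detail.status)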
195def default_run_command(args: Sequence[str]):
196 logger.info("running '{}'...", " ".join(args))
197 _ = subprocess.check_call(args)
200def test_description(
201 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
202 *,
203 format_version: Union[FormatVersionPlaceholder, str] = "discover",
204 weight_format: Optional[SupportedWeightsFormat] = None,
205 devices: Optional[Sequence[str]] = None,
206 determinism: Literal["seed_only", "full"] = "seed_only",
207 expected_type: Optional[str] = None,
208 sha256: Optional[Sha256] = None,
209 stop_early: bool = True,
210 runtime_env: Union[
211 Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
212 ] = ("currently-active"),
213 run_command: Callable[[Sequence[str]], None] = default_run_command,
214 **deprecated: Unpack[DeprecatedKwargs],
215) -> ValidationSummary:
216 """Test a bioimage.io resource dynamically,
217 e.g. run prediction on the test tensors of a model.
219 Args:
220 source: Source of the resource description (e.g. a model) to test.
221 weight_format: Weight format to test.
222 Default: All weight formats present in **source**.
223 devices: Devices to test with, e.g. 'cpu', 'cuda'.
224 Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
225 determinism: Modes to improve reproducibility of test outputs.
226 expected_type: Assert an expected resource description `type`.
227 sha256: Expected SHA256 value of **source**.
228 (Ignored if **source** already is a loaded `ResourceDescr` object.)
229 stop_early: Do not run further subtests after a failed one.
230 runtime_env: (Experimental feature!) The Python environment to run the tests in
231 - `"currently-active"`: Use active Python interpreter.
232 - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
233 environment YAML file based on the model weights description.
234 - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
235 Note: The `bioimageio.core` dependency will be added automatically if not present.
236 run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
237 The function should raise an exception if the command fails.
238 **run_command** is ignored if **runtime_env** is `"currently-active"`.
239 """
240 if runtime_env == "currently-active":
241 rd = load_description_and_test(
242 source,
243 format_version=format_version,
244 weight_format=weight_format,
245 devices=devices,
246 determinism=determinism,
247 expected_type=expected_type,
248 sha256=sha256,
249 stop_early=stop_early,
250 **deprecated,
251 )
252 return rd.validation_summary
254 if runtime_env == "as-described":
255 conda_env = None
256 elif isinstance(runtime_env, (str, Path)):
257 conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
258 elif isinstance(runtime_env, BioimageioCondaEnv):
259 conda_env = runtime_env
260 else:
261 assert_never(runtime_env)
263 try:
264 run_command(["thiscommandshouldalwaysfail", "please"])
265 except Exception:
266 pass
267 else:
268 raise RuntimeError(
269 "given run_command does not raise an exception for a failing command"
270 )
272 td_kwargs: Dict[str, Any] = (
273 dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
274 )
275 with TemporaryDirectory(**td_kwargs) as _d:
276 working_dir = Path(_d)
277 if isinstance(source, (dict, ResourceDescrBase)):
278 file_source = save_bioimageio_package(
279 source, output_path=working_dir / "package.zip"
280 )
281 else:
282 file_source = source
284 return _test_in_env(
285 file_source,
286 working_dir=working_dir,
287 weight_format=weight_format,
288 conda_env=conda_env,
289 devices=devices,
290 determinism=determinism,
291 expected_type=expected_type,
292 sha256=sha256,
293 stop_early=stop_early,
294 run_command=run_command,
295 **deprecated,
296 )
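# Example (illustrative sketch): running the test in a fresh conda environment
# generated from the model's weights description ("as-described"). The source
# path is a placeholder; `run_command` falls back to `default_run_command`.
#
# summary = test_description(
#     "my_model_package.zip",
#     runtime_env="as-described",
#     stop_early=False,
# )
# print(summary.status)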
299def _test_in_env(
300 source: PermissiveFileSource,
301 *,
302 working_dir: Path,
303 weight_format: Optional[SupportedWeightsFormat],
304 conda_env: Optional[BioimageioCondaEnv],
305 devices: Optional[Sequence[str]],
306 determinism: Literal["seed_only", "full"],
307 run_command: Callable[[Sequence[str]], None],
308 stop_early: bool,
309 expected_type: Optional[str],
310 sha256: Optional[Sha256],
311 **deprecated: Unpack[DeprecatedKwargs],
312) -> ValidationSummary:
313 descr = load_description(source)
315 if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
316 raise NotImplementedError("Not yet implemented for non-model resources")
318 if weight_format is None:
319 all_present_wfs = [
320 wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
321 ]
322 ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
323 logger.info(
324 "Found weight formats {}. Start testing all{}...",
325 all_present_wfs,
326 f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
327 )
328 summary = _test_in_env(
329 source,
330 working_dir=working_dir / all_present_wfs[0],
331 weight_format=all_present_wfs[0],
332 devices=devices,
333 determinism=determinism,
334 conda_env=conda_env,
335 run_command=run_command,
336 expected_type=expected_type,
337 sha256=sha256,
338 stop_early=stop_early,
339 **deprecated,
340 )
341 for wf in all_present_wfs[1:]:
342 additional_summary = _test_in_env(
343 source,
344 working_dir=working_dir / wf,
345 weight_format=wf,
346 devices=devices,
347 determinism=determinism,
348 conda_env=conda_env,
349 run_command=run_command,
350 expected_type=expected_type,
351 sha256=sha256,
352 stop_early=stop_early,
353 **deprecated,
354 )
355 for d in additional_summary.details:
356 # TODO: filter redundant details; group details
357 summary.add_detail(d)
358 return summary
360 if weight_format == "pytorch_state_dict":
361 wf = descr.weights.pytorch_state_dict
362 elif weight_format == "torchscript":
363 wf = descr.weights.torchscript
364 elif weight_format == "keras_hdf5":
365 wf = descr.weights.keras_hdf5
366 elif weight_format == "onnx":
367 wf = descr.weights.onnx
368 elif weight_format == "tensorflow_saved_model_bundle":
369 wf = descr.weights.tensorflow_saved_model_bundle
370 elif weight_format == "tensorflow_js":
371 raise RuntimeError(
372 "testing 'tensorflow_js' is not supported by bioimageio.core"
373 )
374 else:
375 assert_never(weight_format)
377 assert wf is not None
378 if conda_env is None:
379 conda_env = get_conda_env(entry=wf)
381 # remove name as we create a name based on the env description hash value
382 conda_env.name = None
384 dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
385 if not is_yaml_value(dumped_env):
386 raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")
388 env_io = StringIO()
389 write_yaml(dumped_env, file=env_io)
390 encoded_env = env_io.getvalue().encode()
391 env_name = hashlib.sha256(encoded_env).hexdigest()
393 try:
394 run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
395 except Exception as e:
396 raise RuntimeError("Conda not available") from e
398 try:
399 run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
400 except Exception as e:
401 working_dir.mkdir(parents=True, exist_ok=True)
402 path = working_dir / "env.yaml"
403 try:
404 _ = path.write_bytes(encoded_env)
405 logger.debug("written conda env to {}", path)
406 run_command(
407 [
408 CONDA_CMD,
409 "env",
410 "create",
411 "--yes",
412 f"--file={path}",
413 f"--name={env_name}",
414 ]
415 + (["--quiet"] if settings.CI else [])
416 )
417 # double check that environment was created successfully
418 run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
419 except Exception as e:
420 summary = descr.validation_summary
421 summary.add_detail(
422 ValidationDetail(
423 name="Conda environment creation",
424 status="failed",
425 loc=("weights", weight_format),
426 recommended_env=conda_env,
427 errors=[
428 ErrorEntry(
429 loc=("weights", weight_format),
430 msg=str(e),
431 type="conda",
432 with_traceback=True,
433 )
434 ],
435 )
436 )
437 return summary
439 working_dir.mkdir(parents=True, exist_ok=True)
440 summary_path = working_dir / "summary.json"
441 assert not summary_path.exists(), "Summary file already exists"
442 cmd = []
443 cmd_error = None
444 for summary_path_arg_name in ("summary", "summary-path"):
445 try:
446 run_command(
447 cmd := (
448 [
449 CONDA_CMD,
450 "run",
451 "-n",
452 env_name,
453 "bioimageio",
454 "test",
455 str(source),
456 f"--{summary_path_arg_name}={summary_path.as_posix()}",
457 f"--determinism={determinism}",
458 ]
459 + ([f"--expected-type={expected_type}"] if expected_type else [])
460 + (["--stop-early"] if stop_early else [])
461 )
462 )
463 except Exception as e:
464 cmd_error = f"Failed to run command '{' '.join(cmd)}': {e}."
466 if summary_path.exists():
467 break
468 else:
469 if cmd_error is not None:
470 logger.warning(cmd_error)
472 return ValidationSummary(
473 name="calling bioimageio test command",
474 source_name=str(source),
475 status="failed",
476 type="unknown",
477 format_version="unknown",
478 details=[
479 ValidationDetail(
480 name="run 'bioimageio test'",
481 errors=[
482 ErrorEntry(
483 loc=(),
484 type="bioimageio cli",
485 msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
486 )
487 ],
488 status="failed",
489 )
490 ],
491 env=set(),
492 )
494 return ValidationSummary.load_json(summary_path)
497@overload
498def load_description_and_test(
499 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
500 *,
501 format_version: Literal["latest"],
502 weight_format: Optional[SupportedWeightsFormat] = None,
503 devices: Optional[Sequence[str]] = None,
504 determinism: Literal["seed_only", "full"] = "seed_only",
505 expected_type: Literal["model"],
506 sha256: Optional[Sha256] = None,
507 stop_early: bool = True,
508 **deprecated: Unpack[DeprecatedKwargs],
509) -> Union[ModelDescr, InvalidDescr]: ...
512@overload
513def load_description_and_test(
514 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
515 *,
516 format_version: Literal["latest"],
517 weight_format: Optional[SupportedWeightsFormat] = None,
518 devices: Optional[Sequence[str]] = None,
519 determinism: Literal["seed_only", "full"] = "seed_only",
520 expected_type: Literal["dataset"],
521 sha256: Optional[Sha256] = None,
522 stop_early: bool = True,
523 **deprecated: Unpack[DeprecatedKwargs],
524) -> Union[DatasetDescr, InvalidDescr]: ...
527@overload
528def load_description_and_test(
529 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
530 *,
531 format_version: Literal["latest"],
532 weight_format: Optional[SupportedWeightsFormat] = None,
533 devices: Optional[Sequence[str]] = None,
534 determinism: Literal["seed_only", "full"] = "seed_only",
535 expected_type: Optional[str] = None,
536 sha256: Optional[Sha256] = None,
537 stop_early: bool = True,
538 **deprecated: Unpack[DeprecatedKwargs],
539) -> Union[LatestResourceDescr, InvalidDescr]: ...
542@overload
543def load_description_and_test(
544 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
545 *,
546 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
547 weight_format: Optional[SupportedWeightsFormat] = None,
548 devices: Optional[Sequence[str]] = None,
549 determinism: Literal["seed_only", "full"] = "seed_only",
550 expected_type: Literal["model"],
551 sha256: Optional[Sha256] = None,
552 stop_early: bool = True,
553 **deprecated: Unpack[DeprecatedKwargs],
554) -> Union[AnyModelDescr, InvalidDescr]: ...
557@overload
558def load_description_and_test(
559 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
560 *,
561 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
562 weight_format: Optional[SupportedWeightsFormat] = None,
563 devices: Optional[Sequence[str]] = None,
564 determinism: Literal["seed_only", "full"] = "seed_only",
565 expected_type: Literal["dataset"],
566 sha256: Optional[Sha256] = None,
567 stop_early: bool = True,
568 **deprecated: Unpack[DeprecatedKwargs],
569) -> Union[AnyDatasetDescr, InvalidDescr]: ...
572@overload
573def load_description_and_test(
574 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
575 *,
576 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
577 weight_format: Optional[SupportedWeightsFormat] = None,
578 devices: Optional[Sequence[str]] = None,
579 determinism: Literal["seed_only", "full"] = "seed_only",
580 expected_type: Optional[str] = None,
581 sha256: Optional[Sha256] = None,
582 stop_early: bool = True,
583 **deprecated: Unpack[DeprecatedKwargs],
584) -> Union[ResourceDescr, InvalidDescr]: ...
587def load_description_and_test(
588 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
589 *,
590 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
591 weight_format: Optional[SupportedWeightsFormat] = None,
592 devices: Optional[Sequence[str]] = None,
593 determinism: Literal["seed_only", "full"] = "seed_only",
594 expected_type: Optional[str] = None,
595 sha256: Optional[Sha256] = None,
596 stop_early: bool = True,
597 **deprecated: Unpack[DeprecatedKwargs],
598) -> Union[ResourceDescr, InvalidDescr]:
599 """Test a bioimage.io resource dynamically,
600 e.g. run prediction on the test tensors of a model.
602 See `test_description` for more details.
604 Returns:
605 A (possibly invalid) resource description object
606 with a populated `.validation_summary` attribute.
607 """
608 if isinstance(source, ResourceDescrBase):
609 root = source.root
610 file_name = source.file_name
611 if (
612 (
613 format_version
614 not in (
615 DISCOVER,
616 source.format_version,
617 ".".join(source.format_version.split(".")[:2]),
618 )
619 )
620 or (c := source.validation_summary.details[0].context) is None
621 or not c.perform_io_checks
622 ):
623 logger.debug(
624 "deserializing source to ensure we validate and test using format {} and perform io checks",
625 format_version,
626 )
627 source = dump_description(source)
628 else:
629 root = Path()
630 file_name = None
632 if isinstance(source, ResourceDescrBase):
633 rd = source
634 elif isinstance(source, dict):
635 # check context for a given root; default to root of source
636 context = get_validation_context(
637 ValidationContext(root=root, file_name=file_name)
638 ).replace(
639 perform_io_checks=True # make sure we perform io checks though
640 )
642 rd = build_description(
643 source,
644 format_version=format_version,
645 context=context,
646 )
647 else:
648 rd = load_description(
649 source, format_version=format_version, sha256=sha256, perform_io_checks=True
650 )
652 rd.validation_summary.env.add(
653 InstalledPackage(name="bioimageio.core", version=__version__)
654 )
656 if expected_type is not None:
657 _test_expected_resource_type(rd, expected_type)
659 if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
660 if weight_format is None:
661 weight_formats: List[SupportedWeightsFormat] = [
662 w for w, we in rd.weights if we is not None
663 ] # pyright: ignore[reportAssignmentType]
664 else:
665 weight_formats = [weight_format]
667 enable_determinism(determinism, weight_formats=weight_formats)
668 for w in weight_formats:
669 _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
670 if stop_early and rd.validation_summary.status == "failed":
671 break
673 if not isinstance(rd, v0_4.ModelDescr):
674 _test_model_inference_parametrized(
675 rd, w, devices, stop_early=stop_early
676 )
677 if stop_early and rd.validation_summary.status == "failed":
678 break
680 # TODO: add execution of jupyter notebooks
681 # TODO: add more tests
683 if rd.validation_summary.status == "valid-format":
684 rd.validation_summary.status = "passed"
686 return rd
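# Example (illustrative sketch): load and test in one call, keeping the loaded
# description object for further use. The source is a placeholder.
#
# rd = load_description_and_test("my_model_package.zip", expected_type="model")
# if isinstance(rd, InvalidDescr):
#     print("invalid description:", rd.validation_summary.status)
# else:
#     print(rd.validation_summary.status)  # typically "passed" or "failed" after testing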
689def _get_tolerance(
690 model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
691 wf: SupportedWeightsFormat,
692 m: MemberId,
693 **deprecated: Unpack[DeprecatedKwargs],
694) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
695 if isinstance(model, v0_5.ModelDescr):
696 applicable = v0_5.ReproducibilityTolerance()
698 # check legacy test kwargs for weight format specific tolerance
699 if model.config.bioimageio.model_extra is not None:
700 for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
701 "test_kwargs", {}
702 ).items():
703 if wf == weights_format:
704 applicable = v0_5.ReproducibilityTolerance(
705 relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
706 absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
707 )
708 break
710 # check for weights format and output tensor specific tolerance
711 for a in model.config.bioimageio.reproducibility_tolerance:
712 if (not a.weights_formats or wf in a.weights_formats) and (
713 not a.output_ids or m in a.output_ids
714 ):
715 applicable = a
716 break
718 rtol = applicable.relative_tolerance
719 atol = applicable.absolute_tolerance
720 mismatched_tol = applicable.mismatched_elements_per_million
721 elif (decimal := deprecated.get("decimal")) is not None:
722 warnings.warn(
723 "The argument `decimal` has been deprecated in favour of"
724 + " `relative_tolerance` and `absolute_tolerance`, with different"
725 + " validation logic, using `numpy.testing.assert_allclose, see"
726 + " 'https://numpy.org/doc/stable/reference/generated/"
727 + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
728 + " will cause validation to revert to the old behaviour."
729 )
730 atol = 1.5 * 10 ** (-decimal)
731 rtol = 0
732 mismatched_tol = 0
733 else:
734 # use given (deprecated) test kwargs
735 atol = deprecated.get("absolute_tolerance", 1e-5)
736 rtol = deprecated.get("relative_tolerance", 1e-3)
737 mismatched_tol = 0
739 return rtol, atol, mismatched_tol
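# Example (illustrative sketch): a model description can narrow the tolerances
# read by `_get_tolerance` via its config. Field values and the output id below
# are assumptions for illustration only.
#
# tol = v0_5.ReproducibilityTolerance(
#     relative_tolerance=1e-3,
#     absolute_tolerance=1e-4,
#     mismatched_elements_per_million=10,
#     weights_formats=["onnx"],    # apply only to the onnx weights entry
#     output_ids=["prediction"],   # hypothetical output tensor id
# )
# # such entries are read from `model.config.bioimageio.reproducibility_tolerance`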
742def _test_model_inference(
743 model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
744 weight_format: SupportedWeightsFormat,
745 devices: Optional[Sequence[str]],
746 stop_early: bool,
747 **deprecated: Unpack[DeprecatedKwargs],
748) -> None:
749 test_name = f"Reproduce test outputs from test inputs ({weight_format})"
750 logger.debug("starting '{}'", test_name)
751 error_entries: List[ErrorEntry] = []
752 warning_entries: List[WarningEntry] = []
754 def add_error_entry(msg: str, with_traceback: bool = False):
755 error_entries.append(
756 ErrorEntry(
757 loc=("weights", weight_format),
758 msg=msg,
759 type="bioimageio.core",
760 with_traceback=with_traceback,
761 )
762 )
764 def add_warning_entry(msg: str):
765 warning_entries.append(
766 WarningEntry(
767 loc=("weights", weight_format),
768 msg=msg,
769 type="bioimageio.core",
770 )
771 )
773 try:
774 test_input = get_test_input_sample(model)
775 expected = get_test_output_sample(model)
777 with create_prediction_pipeline(
778 bioimageio_model=model, devices=devices, weight_format=weight_format
779 ) as prediction_pipeline:
780 results = prediction_pipeline.predict_sample_without_blocking(test_input)
782 if len(results.members) != len(expected.members):
783 add_error_entry(
784 f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
785 )
787 else:
788 for m, expected in expected.members.items():
789 actual = results.members.get(m)
790 if actual is None:
791 add_error_entry("Output tensors for test case may not be None")
792 if stop_early:
793 break
794 else:
795 continue
797 if actual.dims != (dims := expected.dims):
798 add_error_entry(
799 f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
800 )
801 if stop_early:
802 break
803 else:
804 continue
806 if actual.tagged_shape != expected.tagged_shape:
807 add_error_entry(
808 f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
809 )
810 if stop_early:
811 break
812 else:
813 continue
815 try:
816 expected_np = expected.data.to_numpy().astype(np.float32)
817 del expected
818 actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
820 rtol, atol, mismatched_tol = _get_tolerance(
821 model, wf=weight_format, m=m, **deprecated
822 )
823 rtol_value = rtol * abs(expected_np)
824 abs_diff = abs(actual_np - expected_np)
825 mismatched = abs_diff > atol + rtol_value
826 mismatched_elements = mismatched.sum().item()
827 if not mismatched_elements:
828 continue
830 actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
831 try:
832 save_tensor(actual_output_path, actual)
833 except Exception as e:
834 logger.error(
835 "Failed to save actual output tensor to {}: {}",
836 actual_output_path,
837 e,
838 )
840 mismatched_ppm = mismatched_elements / expected_np.size * 1e6
841 abs_diff[~mismatched] = 0 # ignore non-mismatched elements
843 r_max_idx_flat = (
844 r_diff := (abs_diff / (abs(expected_np) + 1e-6))
845 ).argmax()
846 r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
847 r_max = r_diff[r_max_idx].item()
848 r_actual = actual_np[r_max_idx].item()
849 r_expected = expected_np[r_max_idx].item()
851 # Calculate the max absolute difference with the relative tolerance subtracted
852 abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
853 a_max_idx = np.unravel_index(
854 abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
855 )
857 a_max = abs_diff[a_max_idx].item()
858 a_actual = actual_np[a_max_idx].item()
859 a_expected = expected_np[a_max_idx].item()
860 except Exception as e:
861 msg = f"Output '{m}' disagrees with expected values."
862 add_error_entry(msg)
863 if stop_early:
864 break
865 else:
866 msg = (
867 f"Output '{m}' disagrees with {mismatched_elements} of"
868 + f" {expected_np.size} expected values"
869 + f" ({mismatched_ppm:.1f} ppm)."
870 + f"\n Max relative difference: {r_max:.2e}"
871 + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
872 + f" at {dict(zip(dims, r_max_idx))}"
873 + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
874 + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
875 + f"\n Saved actual output to {actual_output_path}."
876 )
877 if mismatched_ppm > mismatched_tol:
878 add_error_entry(msg)
879 if stop_early:
880 break
881 else:
882 add_warning_entry(msg)
884 except Exception as e:
885 if get_validation_context().raise_errors:
886 raise e
888 add_error_entry(str(e), with_traceback=True)
890 model.validation_summary.add_detail(
891 ValidationDetail(
892 name=test_name,
893 loc=("weights", weight_format),
894 status="failed" if error_entries else "passed",
895 recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
896 errors=error_entries,
897 warnings=warning_entries,
898 )
899 )
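# Standalone sketch of the element-wise comparison performed above: an element
# is mismatched when |actual - expected| > atol + rtol * |expected|, and the
# test fails when the mismatched parts-per-million exceed the tolerance.
# The arrays and tolerance values here are assumed for illustration.
#
# expected_np = np.array([1.0, 2.0, 3.0], dtype=np.float32)
# actual_np = np.array([1.0, 2.001, 3.1], dtype=np.float32)
# rtol, atol, mismatched_tol = 1e-3, 1e-4, 100.0
# mismatched = np.abs(actual_np - expected_np) > atol + rtol * np.abs(expected_np)
# mismatched_ppm = mismatched.sum() / expected_np.size * 1e6
# passed = mismatched_ppm <= mismatched_tol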
902def _test_model_inference_parametrized(
903 model: v0_5.ModelDescr,
904 weight_format: SupportedWeightsFormat,
905 devices: Optional[Sequence[str]],
906 *,
907 stop_early: bool,
908) -> None:
909 if not any(
910 isinstance(a.size, v0_5.ParameterizedSize)
911 for ipt in model.inputs
912 for a in ipt.axes
913 ):
914 # no parameterized sizes => set n=0
915 ns: Set[v0_5.ParameterizedSize_N] = {0}
916 else:
917 ns = {0, 1, 2}
919 given_batch_sizes = {
920 a.size
921 for ipt in model.inputs
922 for a in ipt.axes
923 if isinstance(a, v0_5.BatchAxis)
924 }
925 if given_batch_sizes:
926 batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
927 if not batch_sizes:
928 # only arbitrary batch sizes
929 batch_sizes = {1, 2}
930 else:
931 # no batch axis
932 batch_sizes = {1}
934 test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
935 (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
936 }
937 logger.info(
938 "Testing inference with '{}' for {} different inputs (B, N): {}",
939 weight_format,
940 len(test_cases),
941 test_cases,
942 )
944 def generate_test_cases():
945 tested: Set[Hashable] = set()
947 def get_ns(n: int):
948 return {
949 (t.id, a.id): n
950 for t in model.inputs
951 for a in t.axes
952 if isinstance(a.size, v0_5.ParameterizedSize)
953 }
955 for batch_size, n in sorted(test_cases):
956 input_target_sizes, expected_output_sizes = model.get_axis_sizes(
957 get_ns(n), batch_size=batch_size
958 )
959 hashable_target_size = tuple(
960 (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
961 )
962 if hashable_target_size in tested:
963 continue
964 else:
965 tested.add(hashable_target_size)
967 resized_test_inputs = Sample(
968 members={
969 t.id: (
970 test_input.members[t.id].resize_to(
971 {
972 aid: s
973 for (tid, aid), s in input_target_sizes.items()
974 if tid == t.id
975 },
976 )
977 )
978 for t in model.inputs
979 },
980 stat=test_input.stat,
981 id=test_input.id,
982 )
983 expected_output_shapes = {
984 t.id: {
985 aid: s
986 for (tid, aid), s in expected_output_sizes.items()
987 if tid == t.id
988 }
989 for t in model.outputs
990 }
991 yield n, batch_size, resized_test_inputs, expected_output_shapes
993 try:
994 test_input = get_test_input_sample(model)
996 with create_prediction_pipeline(
997 bioimageio_model=model, devices=devices, weight_format=weight_format
998 ) as prediction_pipeline:
999 for n, batch_size, inputs, expected_output_shape in generate_test_cases():
1000 error: Optional[str] = None
1001 result = prediction_pipeline.predict_sample_without_blocking(inputs)
1002 if len(result.members) != len(expected_output_shape):
1003 error = (
1004 f"Expected {len(expected_output_shape)} outputs,"
1005 + f" but got {len(result.members)}"
1006 )
1008 else:
1009 for m, exp in expected_output_shape.items():
1010 res = result.members.get(m)
1011 if res is None:
1012 error = "Output tensors may not be None for test case"
1013 break
1015 diff: Dict[AxisId, int] = {}
1016 for a, s in res.sizes.items():
1017 if isinstance((e_aid := exp[AxisId(a)]), int):
1018 if s != e_aid:
1019 diff[AxisId(a)] = s
1020 elif (
1021 s < e_aid.min or e_aid.max is not None and s > e_aid.max
1022 ):
1023 diff[AxisId(a)] = s
1024 if diff:
1025 error = (
1026 f"(n={n}) Expected output shape {exp},"
1027 + f" but got {res.sizes} (diff: {diff})"
1028 )
1029 break
1031 model.validation_summary.add_detail(
1032 ValidationDetail(
1033 name=f"Run {weight_format} inference for inputs with"
1034 + f" batch_size: {batch_size} and size parameter n: {n}",
1035 loc=("weights", weight_format),
1036 status="passed" if error is None else "failed",
1037 errors=(
1038 []
1039 if error is None
1040 else [
1041 ErrorEntry(
1042 loc=("weights", weight_format),
1043 msg=error,
1044 type="bioimageio.core",
1045 )
1046 ]
1047 ),
1048 )
1049 )
1050 if stop_early and error is not None:
1051 break
1052 except Exception as e:
1053 if get_validation_context().raise_errors:
1054 raise e
1056 model.validation_summary.add_detail(
1057 ValidationDetail(
1058 name=f"Run {weight_format} inference for parametrized inputs",
1059 status="failed",
1060 loc=("weights", weight_format),
1061 errors=[
1062 ErrorEntry(
1063 loc=("weights", weight_format),
1064 msg=str(e),
1065 type="bioimageio.core",
1066 with_traceback=True,
1067 )
1068 ],
1069 )
1070 )
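# Sketch of the (batch_size, n) test matrix built above, assuming a model with
# one parameterized axis and an arbitrary batch axis (values for illustration):
#
# batch_sizes = {1, 2}
# ns = {0, 1, 2}
# test_cases = {(b, n) for b, n in product(sorted(batch_sizes), sorted(ns))}
# # -> 6 cases; cases resolving to identical input sizes are skipped in
# #    generate_test_cases() via the `tested` set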
1073def _test_expected_resource_type(
1074 rd: Union[InvalidDescr, ResourceDescr], expected_type: str
1075):
1076 has_expected_type = rd.type == expected_type
1077 rd.validation_summary.details.append(
1078 ValidationDetail(
1079 name="Has expected resource type",
1080 status="passed" if has_expected_type else "failed",
1081 loc=("type",),
1082 errors=(
1083 []
1084 if has_expected_type
1085 else [
1086 ErrorEntry(
1087 loc=("type",),
1088 type="type",
1089 msg=f"Expected type {expected_type}, found {rd.type}",
1090 )
1091 ]
1092 ),
1093 )
1094 )
1097# TODO: Implement `debug_model()`
1098# def debug_model(
1099# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1100# *,
1101# weight_format: Optional[WeightsFormat] = None,
1102# devices: Optional[List[str]] = None,
1103# ):
1104# """Run the model test and return dict with inputs, results, expected results and intermediates.
1106# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1107# """
1108# inputs_raw: Optional = None
1109# inputs_processed: Optional = None
1110# outputs_raw: Optional = None
1111# outputs: Optional = None
1112# expected: Optional = None
1113# diff: Optional = None
1115# model = load_description(
1116# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1117# )
1118# if not isinstance(model, Model):
1119# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1121# prediction_pipeline = create_prediction_pipeline(
1122# bioimageio_model=model, devices=devices, weight_format=weight_format
1123# )
1124# inputs = [
1125# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1126# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1127# ]
1128# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1130# # keep track of the non-processed inputs
1131# inputs_raw = [deepcopy(input) for input in inputs]
1133# computed_measures = {}
1135# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1136# inputs_processed = list(input_dict.values())
1137# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1138# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1139# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1140# outputs = list(output_dict.values())
1142# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1143# outputs = [outputs]
1145# expected = [
1146# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1147# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1148# ]
1149# if len(outputs) != len(expected):
1150# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1151# print(error)
1152# else:
1153# diff = []
1154# for res, exp in zip(outputs, expected):
1155# diff.append(res - exp)
1157# return {
1158# "inputs": inputs_raw,
1159# "inputs_processed": inputs_processed,
1160# "outputs_raw": outputs_raw,
1161# "outputs": outputs,
1162# "expected": expected,
1163# "diff": diff,
1164# }