Coverage for src/bioimageio/core/_resource_tests.py: 50%
383 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from io import StringIO
8from itertools import product
9from pathlib import Path
10from tempfile import TemporaryDirectory
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Hashable,
16 List,
17 Literal,
18 Optional,
19 Sequence,
20 Set,
21 Tuple,
22 Union,
23 overload,
24)
26import numpy as np
27from bioimageio.spec import (
28 AnyDatasetDescr,
29 AnyModelDescr,
30 BioimageioCondaEnv,
31 DatasetDescr,
32 InvalidDescr,
33 LatestResourceDescr,
34 ModelDescr,
35 ResourceDescr,
36 ValidationContext,
37 build_description,
38 dump_description,
39 get_conda_env,
40 load_description,
41 save_bioimageio_package,
42)
43from bioimageio.spec._description_impl import DISCOVER
44from bioimageio.spec._internal.common_nodes import ResourceDescrBase
45from bioimageio.spec._internal.io import is_yaml_value
46from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
47from bioimageio.spec._internal.types import (
48 AbsoluteTolerance,
49 FormatVersionPlaceholder,
50 MismatchedElementsPerMillion,
51 RelativeTolerance,
52)
53from bioimageio.spec._internal.validation_context import get_validation_context
54from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
55from bioimageio.spec.model import v0_4, v0_5
56from bioimageio.spec.model.v0_5 import WeightsFormat
57from bioimageio.spec.summary import (
58 ErrorEntry,
59 InstalledPackage,
60 ValidationDetail,
61 ValidationSummary,
62 WarningEntry,
63)
64from loguru import logger
65from numpy.typing import NDArray
66from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
68from bioimageio.core import __version__
69from bioimageio.core.io import save_tensor
71from ._prediction_pipeline import create_prediction_pipeline
72from ._settings import settings
73from .axis import AxisId, BatchSize
74from .common import MemberId, SupportedWeightsFormat
75from .digest_spec import get_test_input_sample, get_test_output_sample
76from .sample import Sample
# Name of the conda executable used for all subprocess calls
# (the `.bat` wrapper on Windows, plain `conda` everywhere else).
CONDA_CMD = {"Windows": "conda.bat"}.get(platform.system(), "conda")
class DeprecatedKwargs(TypedDict):
    """Deprecated keyword arguments accepted via `**deprecated` by the test functions.

    They are forwarded down to `_get_tolerance`, which only considers them for
    v0_4 model descriptions.
    """

    # absolute tolerance for comparing outputs to the expected test outputs
    # (used for v0_4 models when `decimal` is not given)
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for comparing outputs to the expected test outputs
    # (used for v0_4 models when `decimal` is not given)
    relative_tolerance: NotRequired[RelativeTolerance]
    # legacy precision setting; a non-None value (for v0_4 models) replaces both
    # tolerances with atol = 1.5 * 10**-decimal and rtol = 0
    decimal: NotRequired[Optional[int]]
def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.

    May degrade performance. Only recommended for testing reproducibility!

    Seeds the random generators of NumPy and (depending on **weight_formats**)
    PyTorch, TensorFlow and Keras. If **mode** is "full", the frameworks are
    additionally asked to use deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Restrict which deep learning frameworks are imported
            and configured to those needed for these weight formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.
            `None` configures every framework that can be imported.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Missing frameworks are skipped silently; any other failure is only
          logged at debug level.
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    full_determinism = mode == "full"

    def requested(*formats: str) -> bool:
        # no explicit selection means: configure every framework
        if weight_formats is None:
            return True

        return any(fmt in weight_formats for fmt in formats)

    def best_effort(configure: Callable[[], None]) -> None:
        # configuring determinism must never break the caller
        try:
            configure()
        except Exception as e:
            logger.debug(str(e))

    def configure_numpy() -> None:
        try:
            import numpy.random
        except ImportError:
            return

        numpy.random.seed(0)

    def configure_torch() -> None:
        try:
            import torch
        except ImportError:
            return

        _ = torch.manual_seed(0)
        torch.use_deterministic_algorithms(full_determinism)

    def configure_tensorflow() -> None:
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
        try:
            import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            return

        tf.random.set_seed(0)
        if full_determinism:
            tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??

    def configure_keras() -> None:
        try:
            import keras  # pyright: ignore[reportMissingTypeStubs]
        except ImportError:
            return

        keras.utils.set_random_seed(0)

    best_effort(configure_numpy)

    if requested("pytorch_state_dict", "torchscript"):
        best_effort(configure_torch)

    if requested("tensorflow_saved_model_bundle", "keras_hdf5"):
        best_effort(configure_tensorflow)

    if requested("keras_hdf5"):
        best_effort(configure_keras)
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Shortcut for `test_description` that additionally asserts the resource
    `type` to be "model". All other arguments are forwarded unchanged.

    Returns:
        The validation summary of the tested model description.
    """
    return test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
def default_run_command(args: Sequence[str]):
    """Run **args** as a subprocess (no shell) and wait for it to finish.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero code.
        OSError: if the executable cannot be started (e.g. it does not exist).
    """
    command_line = " ".join(args)
    logger.info("running '{}'...", command_line)
    _ = subprocess.check_call(args)
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: Resource description source: a loaded description object,
            a file source (path/URL) or raw YAML content (dict).
        format_version: Format version to validate (and convert) the description
            with. Only used if **runtime_env** is `"currently-active"`.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object
            or raw YAML content.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
            Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.

    Returns:
        The validation summary of the (tested) resource description.

    Raises:
        RuntimeError: if **run_command** does not raise for a failing command.
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None  # derived per weight format in `_test_in_env`
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # sanity check: failures are only detected if run_command raises on them
    try:
        run_command(["thiscommandshouldalwaysfail", "please"])
    except Exception:
        pass
    else:
        raise RuntimeError(
            "given run_command does not raise an exception for a failing command"
        )

    td_kwargs: Dict[str, Any] = (
        dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
    )
    with TemporaryDirectory(**td_kwargs) as _d:
        working_dir = Path(_d)

        if isinstance(source, ResourceDescrBase):
            descr = source
        elif isinstance(source, dict):
            context = get_validation_context().replace(
                perform_io_checks=True  # make sure we perform io checks though
            )

            descr = build_description(source, context=context)
        else:
            # check the expected hash here as well (as in the currently-active path)
            descr = load_description(source, sha256=sha256, perform_io_checks=True)

        if isinstance(descr, InvalidDescr):
            return descr.validation_summary
        elif isinstance(source, (dict, ResourceDescrBase)):
            # in-memory sources are packaged so the CLI in the env can read them
            file_source = save_bioimageio_package(
                descr, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        _test_in_env(
            file_source,
            descr=descr,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )

    return descr.validation_summary
def _test_in_env(
    source: PermissiveFileSource,
    *,
    descr: ResourceDescr,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
):
    """Test a bioimage.io resource in a given conda environment.

    Adds details to the existing validation summary of **descr**.

    For models without an explicit **weight_format**, this recurses once per
    present weight format (each in its own subdirectory of **working_dir**);
    'tensorflow_js' weights are skipped as they cannot be tested.

    The conda environment is named after the SHA-256 of its YAML dump, so an
    existing environment with identical content is reused; otherwise it is
    created via **run_command**. The `bioimageio test` CLI is then run in it and
    the details of its summary (restricted to `("weights", <format>)` for
    models) are added to `descr.validation_summary`.

    Note: **devices**, **sha256** and **deprecated** are currently not forwarded
    to the `bioimageio test` command.
    """
    if isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # run tests for all present weight formats
            all_present_wfs = [
                wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
            ]
            ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
            logger.info(
                "Found weight formats {}. Start testing all{}...",
                all_present_wfs,
                f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
            )
            for wf in all_present_wfs:
                if wf in ignore_wfs:
                    # announced as excepted above; testing it would only raise
                    continue

                _test_in_env(
                    source,
                    descr=descr,
                    working_dir=working_dir / wf,
                    weight_format=wf,
                    devices=devices,
                    determinism=determinism,
                    conda_env=conda_env,
                    run_command=run_command,
                    expected_type=expected_type,
                    sha256=sha256,
                    stop_early=stop_early,
                    **deprecated,
                )

            return

        if weight_format == "pytorch_state_dict":
            wf = descr.weights.pytorch_state_dict
        elif weight_format == "torchscript":
            wf = descr.weights.torchscript
        elif weight_format == "keras_hdf5":
            wf = descr.weights.keras_hdf5
        elif weight_format == "onnx":
            wf = descr.weights.onnx
        elif weight_format == "tensorflow_saved_model_bundle":
            wf = descr.weights.tensorflow_saved_model_bundle
        elif weight_format == "tensorflow_js":
            raise RuntimeError(
                "testing 'tensorflow_js' is not supported by bioimageio.core"
            )
        else:
            assert_never(weight_format)

        assert wf is not None
        if conda_env is None:
            conda_env = get_conda_env(entry=wf)

        # only details of this weight format are taken over from the CLI summary
        test_loc = ("weights", weight_format)
    else:
        if conda_env is None:
            warnings.warn(
                "No conda environment description given for testing (And no default conda envs available for non-model descriptions)."
            )
            return

        test_loc = ()

    # remove name as we create a name based on the env description hash value;
    # work on a copy to not modify a conda env object passed in by the caller
    conda_env = conda_env.model_copy()
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    try:
        # reuse the environment if it already exists
        run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
    except Exception:
        working_dir.mkdir(parents=True, exist_ok=True)
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                [
                    CONDA_CMD,
                    "env",
                    "create",
                    "--yes",
                    f"--file={path}",
                    f"--name={env_name}",
                ]
                + (["--quiet"] if settings.CI else [])
            )
            # double check that environment was created successfully
            run_command([CONDA_CMD, "run", "-n", env_name, "python", "--version"])
        except Exception as e:
            descr.validation_summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=test_loc,
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=test_loc,
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return

    working_dir.mkdir(parents=True, exist_ok=True)
    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error = None
    # different bioimageio.core CLI versions name the summary argument differently
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        CONDA_CMD,
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Command '{' '.join(cmd)}' returned with error: {e}."

        if summary_path.exists():
            break
    else:
        if cmd_error is not None:
            logger.warning(cmd_error)

        descr.validation_summary.add_detail(
            ValidationDetail(
                name="run 'bioimageio test' command",
                loc=test_loc,
                recommended_env=conda_env,
                errors=[
                    ErrorEntry(
                        loc=test_loc,
                        type="bioimageio cli",
                        msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                    )
                ],
                status="failed",
            )
        )
        return

    # add relevant details from command summary
    command_summary = ValidationSummary.load_json(summary_path)
    for detail in command_summary.details:
        if detail.loc[: len(test_loc)] == test_loc:
            descr.validation_summary.add_detail(detail)
# Typing-only overloads of `load_description_and_test`. They narrow the return
# type depending on `format_version` ("latest" -> latest description classes)
# and `expected_type` ("model" / "dataset"). Type checkers try overloads in
# declaration order, so the more specific signatures come first.


# latest format version, model expected
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


# latest format version, dataset expected
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


# latest format version, any resource type
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


# discovered/explicit format version, model expected
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


# discovered/explicit format version, dataset expected
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


# fallback: any format version, any resource type
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    An already loaded description is re-serialized and re-validated if it was
    loaded with a different **format_version**, or if it was not validated
    with io checks (or its validation context is unknown).

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            # without any detail we cannot tell whether io checks were performed
            or not source.validation_summary.details
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if stop_early and rd.validation_summary.status != "passed":
                break

            # parametrized (input size / batch size) tests require format v0_5
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status != "passed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing output **m** computed with weights **wf**.

    For v0_5 models the tolerance is taken from the model's config, in order of
    precedence: the first matching `config.bioimageio.reproducibility_tolerance`
    entry, legacy `test_kwargs` for **wf** (from the config's extra fields), and
    otherwise `v0_5.ReproducibilityTolerance()` defaults. **deprecated** is
    ignored for v0_5 models.

    For v0_4 models the deprecated keyword arguments are used: `decimal` (if not
    None) yields atol = 1.5 * 10**-decimal and rtol = 0; otherwise the given
    `absolute_tolerance` (default 1e-5) and `relative_tolerance` (default 1e-3).

    Returns:
        A tuple (relative tolerance, absolute tolerance,
        tolerated mismatched elements per million).
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Check that **model** reproduces its test outputs from its test inputs.

    Runs the test input sample through a prediction pipeline using
    **weight_format** and compares every output member to the expected test
    output, using the tolerances from `_get_tolerance`. NaN values must occur
    at the same positions in actual and expected outputs; elsewhere an element
    mismatches if |actual - expected| > atol + rtol * |expected|.
    If the share of mismatched elements exceeds the tolerated mismatches per
    million, an error is recorded, otherwise a warning; in both cases the actual
    output is saved as `actual_output_<member>_<weight format>.npy` in the
    current working directory.

    Exactly one `ValidationDetail` is added to `model.validation_summary`.
    Unexpected exceptions are re-raised if the validation context requests
    `raise_errors`, otherwise recorded as an error.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        error_entries.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    def add_warning_entry(msg: str):
        warning_entries.append(
            WarningEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
            )
        )

    try:
        test_input = get_test_input_sample(model)
        expected_sample = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected_sample.members):
            add_error_entry(
                f"Expected {len(expected_sample.members)} outputs, but got {len(results.members)}"
            )

        else:
            for m, expected in expected_sample.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    else:
                        continue

                if actual.dims != (dims := expected.dims):
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                if actual.tagged_shape != expected.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
                    )
                    if stop_early:
                        break
                    else:
                        continue

                try:
                    expected_np = expected.data.to_numpy().astype(np.float32)
                    actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)

                    rtol, atol, mismatched_tol = _get_tolerance(
                        model, wf=weight_format, m=m, **deprecated
                    )
                    rtol_value = rtol * abs(expected_np)
                    abs_diff = abs(actual_np - expected_np)
                    # Comparisons with NaN are always False, so without the second
                    # term a NaN output would silently count as reproduced.
                    # NaN in both actual and expected is accepted (like
                    # `numpy.testing.assert_allclose` with `equal_nan=True`).
                    mismatched = (abs_diff > atol + rtol_value) | (
                        np.isnan(actual_np) != np.isnan(expected_np)
                    )
                    mismatched_elements = mismatched.sum().item()
                    if not mismatched_elements:
                        continue

                    actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
                    try:
                        save_tensor(actual_output_path, actual)
                    except Exception as e:
                        logger.error(
                            "Failed to save actual output tensor to {}: {}",
                            actual_output_path,
                            e,
                        )

                    mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                    abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                    # location of the largest relative deviation
                    r_max_idx_flat = (
                        r_diff := (abs_diff / (abs(expected_np) + 1e-6))
                    ).argmax()
                    r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual_np[r_max_idx].item()
                    r_expected = expected_np[r_max_idx].item()

                    # Calculate the max absolute difference with the relative tolerance subtracted
                    abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                    a_max_idx = np.unravel_index(
                        abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                    )

                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual_np[a_max_idx].item()
                    a_expected = expected_np[a_max_idx].item()
                except Exception as e:
                    # keep the reason instead of silently discarding it
                    add_error_entry(
                        f"Output '{m}' disagrees with expected values."
                        + f" Failed to compare them: {e}",
                        with_traceback=True,
                    )
                    if stop_early:
                        break
                else:
                    msg = (
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected_np.size} expected values"
                        + f" ({mismatched_ppm:.1f} ppm)."
                        + f"\n Max relative difference: {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {dict(zip(dims, r_max_idx))}"
                        + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
                        + f"\n Saved actual output to {actual_output_path}."
                    )
                    if mismatched_ppm > mismatched_tol:
                        add_error_entry(msg)
                        if stop_early:
                            break
                    else:
                        add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference on resized test inputs and check the resulting output shapes.

    Test cases are all combinations of batch sizes and size parameters `n`:
    - n in {0, 1, 2} if any input axis has a `ParameterizedSize`, else only n=0;
    - the explicitly given (non-None) batch sizes, {1, 2} if batch axes only
      allow arbitrary sizes, or {1} if there is no batch axis.
    Cases resulting in an already tested input size are skipped.

    Only output shapes are checked (not values). One `ValidationDetail` per
    executed test case is added to `model.validation_summary`; an unexpected
    exception adds a single failed detail (or is re-raised if the validation
    context requests `raise_errors`).
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    # NOTE: the logged count may exceed the number of executed cases, because
    # cases with an already tested input size are skipped below
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        """Yield (n, batch size, resized test inputs, expected output sizes).

        Reads `test_input` from the enclosing scope; it is only assigned further
        below, before this generator is iterated.
        """
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            # the same size parameter n for every parameterized input axis
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            # skip test cases resulting in an input size that was already tested
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_input.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_input.stat,
                id=test_input.id,
            )
            # expected sizes per output tensor id, keyed by axis id
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, exptected_output_shape in generate_test_cases():
                error: Optional[str] = None
                try:
                    result = prediction_pipeline.predict_sample_without_blocking(inputs)
                except Exception as e:
                    error = str(e)
                else:
                    if len(result.members) != len(exptected_output_shape):
                        error = (
                            f"Expected {len(exptected_output_shape)} outputs,"
                            + f" but got {len(result.members)}"
                        )

                    else:
                        for m, exp in exptected_output_shape.items():
                            res = result.members.get(m)
                            if res is None:
                                error = "Output tensors may not be None for test case"
                                break

                            # collect actual axis sizes that violate expectations
                            diff: Dict[AxisId, int] = {}
                            for a, s in res.sizes.items():
                                if isinstance((e_aid := exp[AxisId(a)]), int):
                                    # fixed expected size
                                    if s != e_aid:
                                        diff[AxisId(a)] = s
                                # otherwise presumably a size range with `min` and
                                # optional `max` (e.g. data dependent) -- verify
                                elif (
                                    s < e_aid.min
                                    or e_aid.max is not None
                                    and s > e_aid.max
                                ):
                                    diff[AxisId(a)] = s
                            if diff:
                                error = (
                                    f"(n={n}) Expected output shape {exp},"
                                    + f" but got {res.sizes} (diff: {diff})"
                                )
                                break

                # one detail per executed test case
                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
) -> None:
    """Record on `rd` whether it has the expected resource type.

    Adds a "Has expected resource type" detail to `rd.validation_summary`.
    The detail is "passed" if `rd.type == expected_type`. Otherwise it is
    "failed" and carries a single error entry located at ``("type",)``.

    Args:
        rd: The (possibly invalid) resource description to check.
        expected_type: The resource type `rd` is expected to declare.
    """
    has_expected_type = rd.type == expected_type
    # Record via `add_detail`, as the other checks in this module do.
    # Appending to `details` directly bypasses the summary's own bookkeeping
    # (bioimageio.spec's `ValidationSummary.add_detail` also updates the
    # summary `status`), so a type mismatch would not be reflected in the
    # overall status.
    rd.validation_summary.add_detail(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )
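# TODO: Implement `debug_model()`. The commented-out draft below does not match
# the API used by the live code in this module. It calls
# `create_prediction_pipeline` without a `with` block, builds xarray inputs
# instead of a `Sample`, and calls `predict`/`apply_preprocessing` rather than
# `predict_sample_without_blocking`. Port it to the current API before re-enabling.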
1111# def debug_model(
1112# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1113# *,
1114# weight_format: Optional[WeightsFormat] = None,
1115# devices: Optional[List[str]] = None,
1116# ):
1117# """Run the model test and return dict with inputs, results, expected results and intermediates.
1119# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1120# """
1121# inputs_raw: Optional = None
1122# inputs_processed: Optional = None
1123# outputs_raw: Optional = None
1124# outputs: Optional = None
1125# expected: Optional = None
1126# diff: Optional = None
1128# model = load_description(
1129# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1130# )
1131# if not isinstance(model, Model):
1132# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1134# prediction_pipeline = create_prediction_pipeline(
1135# bioimageio_model=model, devices=devices, weight_format=weight_format
1136# )
1137# inputs = [
1138# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1139# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1140# ]
1141# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1143# # keep track of the non-processed inputs
1144# inputs_raw = [deepcopy(input) for input in inputs]
1146# computed_measures = {}
1148# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1149# inputs_processed = list(input_dict.values())
1150# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1151# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1152# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1153# outputs = list(output_dict.values())
1155# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1156# outputs = [outputs]
1158# expected = [
1159# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1160# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1161# ]
1162# if len(outputs) != len(expected):
1163# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1164# print(error)
1165# else:
1166# diff = []
1167# for res, exp in zip(outputs, expected):
1168# diff.append(res - exp)
1170# return {
1171# "inputs": inputs_raw,
1172# "inputs_processed": inputs_processed,
1173# "outputs_raw": outputs_raw,
1174# "outputs": outputs,
1175# "expected": expected,
1176# "diff": diff,
1177# }