Coverage for src/bioimageio/core/_resource_tests.py: 54%
351 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
1import hashlib
2import os
3import platform
4import subprocess
5import sys
6import warnings
7from io import StringIO
8from itertools import product
9from pathlib import Path
10from tempfile import TemporaryDirectory
11from typing import (
12 Any,
13 Callable,
14 Dict,
15 Hashable,
16 List,
17 Literal,
18 Optional,
19 Sequence,
20 Set,
21 Tuple,
22 Union,
23 overload,
24)
26import numpy as np
27from bioimageio.spec import (
28 AnyDatasetDescr,
29 AnyModelDescr,
30 BioimageioCondaEnv,
31 DatasetDescr,
32 InvalidDescr,
33 LatestResourceDescr,
34 ModelDescr,
35 ResourceDescr,
36 ValidationContext,
37 build_description,
38 dump_description,
39 get_conda_env,
40 load_description,
41 save_bioimageio_package,
42)
43from bioimageio.spec._description_impl import DISCOVER
44from bioimageio.spec._internal.common_nodes import ResourceDescrBase
45from bioimageio.spec._internal.io import is_yaml_value
46from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
47from bioimageio.spec._internal.types import (
48 AbsoluteTolerance,
49 FormatVersionPlaceholder,
50 MismatchedElementsPerMillion,
51 RelativeTolerance,
52)
53from bioimageio.spec._internal.validation_context import get_validation_context
54from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
55from bioimageio.spec.model import v0_4, v0_5
56from bioimageio.spec.model.v0_5 import WeightsFormat
57from bioimageio.spec.summary import (
58 ErrorEntry,
59 InstalledPackage,
60 ValidationDetail,
61 ValidationSummary,
62 WarningEntry,
63)
64from loguru import logger
65from numpy.typing import NDArray
66from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
68from bioimageio.core import __version__
70from ._prediction_pipeline import create_prediction_pipeline
71from .axis import AxisId, BatchSize
72from .common import MemberId, SupportedWeightsFormat
73from .digest_spec import get_test_input_sample, get_test_output_sample
74from .sample import Sample
class DeprecatedKwargs(TypedDict):
    """Legacy tolerance keyword arguments still accepted by the test functions.

    Every key is optional. `_get_tolerance` only consults these values for
    models that are not `v0_5.ModelDescr` instances.
    """

    # absolute tolerance for comparing actual with expected test outputs
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    # relative tolerance for comparing actual with expected test outputs
    relative_tolerance: NotRequired[RelativeTolerance]
    # number of decimals; if given (and not None) it takes precedence over the
    # two tolerances above and is converted to an absolute tolerance
    decimal: NotRequired[Optional[int]]
def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' determinism features (might degrade performance or throw exceptions)
        weight_formats: Limit the import of deep learning frameworks
            based on weight_formats.
            E.g. this allows to avoid importing tensorflow when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    full_determinism = mode == "full"

    # Decide up front which frameworks are relevant; `None` means "all of them".
    any_format = weight_formats is None
    use_torch = (
        any_format
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    )
    use_tensorflow = (
        any_format
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    )
    use_keras = any_format or "keras_hdf5" in weight_formats

    # Each framework is handled best-effort: a missing package is skipped
    # silently, any other problem is only logged at debug level.

    # NumPy is always seeded
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if use_torch:
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(full_determinism)
        except Exception as e:
            logger.debug(str(e))

    if use_tensorflow:
        try:
            # set before tensorflow is imported
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if full_determinism:
                    tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
        except Exception as e:
            logger.debug(str(e))

    if use_keras:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` that fixes
    `expected_type="model"` and forwards all other arguments unchanged.

    Returns:
        The validation summary produced by `test_description`.
    """
    summary = test_description(
        source,
        expected_type="model",
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
    return summary
def default_run_command(args: Sequence[str]):
    """Run **args** as a subprocess and raise if it fails.

    Args:
        args: The command and its arguments, one list item each.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero code
            (due to `check=True`).
    """
    logger.info("running '{}'...", " ".join(args))
    # On POSIX, `shell=True` combined with a sequence executes only `args[0]`;
    # the remaining items become positional parameters of the shell itself and
    # are silently dropped from the command. Therefore the shell is only used
    # on Windows, where the sequence is converted to a single command line.
    use_shell = platform.system() == "Windows"
    _ = subprocess.run(args, shell=use_shell, text=True, check=True)
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Format version to validate and test with.
            (Only forwarded if **runtime_env** is `"currently-active"`.)
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.

    Returns:
        The validation summary of the tested resource.
    """
    # Fast path: test within the running interpreter.
    if runtime_env == "currently-active":
        return load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        ).validation_summary

    # Resolve the conda environment to test in;
    # `None` lets `_test_in_env` derive it from the weights description.
    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        env_content = read_yaml(Path(runtime_env))
        conda_env = BioimageioCondaEnv.model_validate(env_content)
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # `ignore_cleanup_errors` is only passed for Python >= 3.10
    td_kwargs: Dict[str, Any] = {}
    if sys.version_info >= (3, 10):
        td_kwargs["ignore_cleanup_errors"] = True

    with TemporaryDirectory(**td_kwargs) as _d:
        working_dir = Path(_d)

        # in-memory descriptions are packaged first, such that a file source
        # can be handed to the command run in the other environment
        file_source = source
        if isinstance(source, (dict, ResourceDescrBase)):
            file_source = save_bioimageio_package(
                source, output_path=working_dir / "package.zip"
            )

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )
def _test_in_env(
    source: PermissiveFileSource,
    *,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a model by calling the `bioimageio test` command in a conda environment.

    Args:
        source: File source of the model description.
        working_dir: Directory for the conda environment file and the summary file.
        weight_format: Weight format to test. If `None`, this function calls itself
            once per testable weight format present in the model (each with its own
            sub-directory of **working_dir**) and merges the resulting details.
        conda_env: Conda environment to use. If `None`, it is derived from the
            weights entry with `get_conda_env`.
        devices: Forwarded to recursive calls.
        determinism: Passed on as `--determinism`.
        run_command: Function executing a terminal command; expected to raise on failure.
        stop_early: Adds `--stop-early` if true.
        expected_type: Passed on as `--expected-type` if given.
        sha256: Forwarded to recursive calls.
        **deprecated: Forwarded to recursive calls.

    Returns:
        The summary loaded from the summary file written by the test command, or a
        failed summary if the environment could not be created or no summary file
        was produced.

    Raises:
        NotImplementedError: If **source** does not describe a model.
        RuntimeError: If conda is not available or no testable weight format is found.
        ValueError: If the conda environment cannot be dumped to valid YAML.
    """
    descr = load_description(source)

    if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise NotImplementedError("Not yet implemented for non-model resources")

    if weight_format is None:
        all_present_wfs = [
            wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
        ]
        ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
        logger.info(
            "Found weight formats {}. Start testing all{}...",
            all_present_wfs,
            f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
        )
        # Actually skip the ignored formats (as announced in the log message);
        # testing them would only raise the 'not supported' error below.
        wfs_to_test = [wf for wf in all_present_wfs if wf not in ignore_wfs]
        if not wfs_to_test:
            raise RuntimeError(
                f"Found no weight format that can be tested (present: {all_present_wfs})"
            )

        summary: Optional[ValidationSummary] = None
        for wf_name in wfs_to_test:
            wf_summary = _test_in_env(
                source,
                working_dir=working_dir / wf_name,
                weight_format=wf_name,
                devices=devices,
                determinism=determinism,
                conda_env=conda_env,
                run_command=run_command,
                expected_type=expected_type,
                sha256=sha256,
                stop_early=stop_early,
                **deprecated,
            )
            if summary is None:
                # the first summary collects the details of all following ones
                summary = wf_summary
            else:
                for d in wf_summary.details:
                    # TODO: filter redundant details; group details
                    summary.add_detail(d)

        assert summary is not None  # guaranteed by the non-empty `wfs_to_test`
        return summary

    if weight_format == "pytorch_state_dict":
        wf = descr.weights.pytorch_state_dict
    elif weight_format == "torchscript":
        wf = descr.weights.torchscript
    elif weight_format == "keras_hdf5":
        wf = descr.weights.keras_hdf5
    elif weight_format == "onnx":
        wf = descr.weights.onnx
    elif weight_format == "tensorflow_saved_model_bundle":
        wf = descr.weights.tensorflow_saved_model_bundle
    elif weight_format == "tensorflow_js":
        raise RuntimeError(
            "testing 'tensorflow_js' is not supported by bioimageio.core"
        )
    else:
        assert_never(weight_format)

    assert wf is not None
    if conda_env is None:
        conda_env = get_conda_env(entry=wf)

    # remove name as we create a name based on the env description hash value
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", "conda"])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    # The working directory must exist before the env file is written to it
    # (it may be a not yet existing per weight format sub-directory).
    working_dir.mkdir(parents=True, exist_ok=True)

    try:
        # a failing activation is taken as "environment does not exist yet"
        run_command(["conda", "activate", env_name])
    except Exception:
        path = working_dir / "env.yaml"
        try:
            _ = path.write_bytes(encoded_env)
            logger.debug("written conda env to {}", path)
            run_command(
                ["conda", "env", "create", f"--file={path}", f"--name={env_name}"]
            )
            run_command(["conda", "activate", env_name])
        except Exception as e:
            summary = descr.validation_summary
            summary.add_detail(
                ValidationDetail(
                    name="Conda environment creation",
                    status="failed",
                    loc=("weights", weight_format),
                    recommended_env=conda_env,
                    errors=[
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=str(e),
                            type="conda",
                            with_traceback=True,
                        )
                    ],
                )
            )
            return summary

    summary_path = working_dir / "summary.json"
    assert not summary_path.exists(), "Summary file already exists"
    cmd: List[str] = []
    cmd_error = None
    # NOTE(review): weight_format, devices, sha256 and the deprecated kwargs are
    # not part of the command below -- confirm if that is intended.
    # Two spellings of the summary path option are tried in turn.
    for summary_path_arg_name in ("summary", "summary-path"):
        try:
            run_command(
                cmd := (
                    [
                        "conda",
                        "run",
                        "-n",
                        env_name,
                        "bioimageio",
                        "test",
                        str(source),
                        f"--{summary_path_arg_name}={summary_path.as_posix()}",
                        f"--determinism={determinism}",
                    ]
                    + ([f"--expected-type={expected_type}"] if expected_type else [])
                    + (["--stop-early"] if stop_early else [])
                )
            )
        except Exception as e:
            cmd_error = f"Failed to run command '{' '.join(cmd)}': {e}."

        # a written summary file counts as success, even if the command raised
        if summary_path.exists():
            break
    else:
        # no attempt produced a summary file
        if cmd_error is not None:
            logger.warning(cmd_error)

        return ValidationSummary(
            name="calling bioimageio test command",
            source_name=str(source),
            status="failed",
            type="unknown",
            format_version="unknown",
            details=[
                ValidationDetail(
                    name="run 'bioimageio test'",
                    errors=[
                        ErrorEntry(
                            loc=(),
                            type="bioimageio cli",
                            msg=f"test command '{' '.join(cmd)}' did not produce a summary file at {summary_path}",
                        )
                    ],
                    status="failed",
                )
            ],
            env=set(),
        )

    return ValidationSummary.load_json(summary_path)
@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[DatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["model"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyModelDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Literal["dataset"],
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[AnyDatasetDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    root = Path()
    file_name = None
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name

        # An already loaded description is only reused if it matches the
        # requested format version and was validated with io checks.
        matching_versions = (
            DISCOVER,
            source.format_version,
            ".".join(source.format_version.split(".")[:2]),
        )
        reload_required = format_version not in matching_versions
        if not reload_required:
            ctx = source.validation_summary.details[0].context
            reload_required = ctx is None or not ctx.perform_io_checks

        if reload_required:
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        base_context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        )
        # make sure we perform io checks though
        context = base_context.replace(perform_io_checks=True)

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=__version__)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        weight_formats: List[SupportedWeightsFormat]
        if weight_format is not None:
            weight_formats = [weight_format]
        else:
            # default: all weight formats present in the model description
            weight_formats = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
            if stop_early and rd.validation_summary.status == "failed":
                break

            # the parametrized inference test is skipped for v0_4 models
            if isinstance(rd, v0_4.ModelDescr):
                continue

            _test_model_inference_parametrized(rd, w, devices, stop_early=stop_early)
            if stop_early and rd.validation_summary.status == "failed":
                break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    # a format-valid resource without failed tests counts as passed
    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    return rd
def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    """Determine the tolerances for comparing an actual with an expected output.

    For `v0_5.ModelDescr` models the tolerance is taken from the model description
    (**deprecated** is not consulted): legacy `test_kwargs` found in the extra
    fields of `model.config.bioimageio` are considered first and are overruled by
    the first entry of `model.config.bioimageio.reproducibility_tolerance` that
    applies to **wf** and **m**.
    For other models the (deprecated) keyword arguments are used.

    Args:
        model: The model description.
        wf: The weight format being tested.
        m: The id of the output tensor being compared.
        **deprecated: Legacy tolerance arguments, see `DeprecatedKwargs`.

    Returns:
        Relative tolerance, absolute tolerance and tolerated
        mismatched elements per million.
    """
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        # (an empty `weights_formats`/`output_ids` matches any format/output)
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        # The URL is kept in a single literal: concatenating it from pieces
        # previously embedded a space and rendered the link unusable.
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html'."
            + " Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    stop_early: bool,
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    """Check that the model reproduces its test outputs from its test inputs.

    The outcome is added as a `ValidationDetail` to `model.validation_summary`.

    Args:
        model: The model description to test.
        weight_format: The weight format to run inference with.
        devices: Devices handed to `create_prediction_pipeline`.
        stop_early: Stop comparing further outputs after the first error.
        **deprecated: Legacy tolerance arguments handed to `_get_tolerance`.

    Raises:
        Exception: Any exception raised while testing, but only if the current
            validation context has `raise_errors` set; otherwise exceptions are
            recorded as error entries.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    error_entries: List[ErrorEntry] = []
    warning_entries: List[WarningEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        entry = ErrorEntry(
            loc=("weights", weight_format),
            msg=msg,
            type="bioimageio.core",
            with_traceback=with_traceback,
        )
        error_entries.append(entry)

    def add_warning_entry(msg: str):
        entry = WarningEntry(
            loc=("weights", weight_format),
            msg=msg,
            type="bioimageio.core",
        )
        warning_entries.append(entry)

    try:
        test_input = get_test_input_sample(model)
        expected_sample = get_test_output_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(test_input)

        if len(results.members) != len(expected_sample.members):
            add_error_entry(
                f"Expected {len(expected_sample.members)} outputs, but got {len(results.members)}"
            )
        else:
            for m, expected in expected_sample.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    if stop_early:
                        break
                    continue

                # kept for the report below, as `expected` is deleted before
                dims = expected.dims
                if actual.dims != dims:
                    add_error_entry(
                        f"Output '{m}' has dims {actual.dims}, but expected {expected.dims}"
                    )
                    if stop_early:
                        break
                    continue

                if actual.tagged_shape != expected.tagged_shape:
                    add_error_entry(
                        f"Output '{m}' has shape {actual.tagged_shape}, but expected {expected.tagged_shape}"
                    )
                    if stop_early:
                        break
                    continue

                # compare as float32 arrays; drop the tensor references early
                expected_np = expected.data.to_numpy().astype(np.float32)
                del expected
                actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
                del actual

                rtol, atol, mismatched_tol = _get_tolerance(
                    model, wf=weight_format, m=m, **deprecated
                )
                rtol_value = rtol * abs(expected_np)
                abs_diff = abs(actual_np - expected_np)
                # an element mismatches if |actual - expected| > atol + rtol * |expected|
                mismatched = abs_diff > atol + rtol_value
                mismatched_elements = mismatched.sum().item()
                if not mismatched_elements:
                    continue

                mismatched_ppm = mismatched_elements / expected_np.size * 1e6
                abs_diff[~mismatched] = 0  # ignore non-mismatched elements

                # location and value of the largest relative difference
                r_diff = abs_diff / (abs(expected_np) + 1e-6)
                r_max_idx = np.unravel_index(r_diff.argmax(), r_diff.shape)
                r_max = r_diff[r_max_idx].item()
                r_actual = actual_np[r_max_idx].item()
                r_expected = expected_np[r_max_idx].item()

                # Calculate the max absolute difference with the relative tolerance subtracted
                abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
                a_max_idx = np.unravel_index(
                    abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
                )

                a_max = abs_diff[a_max_idx].item()
                a_actual = actual_np[a_max_idx].item()
                a_expected = expected_np[a_max_idx].item()

                msg = "".join(
                    [
                        f"Output '{m}' disagrees with {mismatched_elements} of",
                        f" {expected_np.size} expected values",
                        f" ({mismatched_ppm:.1f} ppm).",
                        f"\n Max relative difference: {r_max:.2e}",
                        rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)",
                        f" at {dict(zip(dims, r_max_idx))}",
                        f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}",
                        rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}",
                    ]
                )
                # mismatches within the tolerated ppm are only reported as warning
                if mismatched_ppm > mismatched_tol:
                    add_error_entry(msg)
                    if stop_early:
                        break
                else:
                    add_warning_entry(msg)

    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    weights_entry = dict(model.weights)[weight_format]
    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if error_entries else "passed",
            recommended_env=get_conda_env(entry=weights_entry),
            errors=error_entries,
            warnings=warning_entries,
        )
    )
def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    """Run inference for resized test inputs and check the output shapes.

    Test cases are combinations of batch sizes and size parameters `n`.
    One `ValidationDetail` per executed test case is added to
    `model.validation_summary`.

    Args:
        model: The model description to test.
        weight_format: The weight format to run inference with.
        devices: Devices handed to `create_prediction_pipeline`.
        stop_early: Stop after the first failed test case.

    Raises:
        Exception: Any exception raised while testing, but only if the current
            validation context has `raise_errors` set; otherwise a failed
            detail is added instead.
    """
    has_parameterized_size = any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    )
    ns: Set[v0_5.ParameterizedSize_N]
    if has_parameterized_size:
        ns = {0, 1, 2}
    else:
        # no parameterized sizes => set n=0
        ns = {0}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if not given_batch_sizes:
        # no batch axis
        batch_sizes = {1}
    else:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = set(
        product(sorted(batch_sizes), sorted(ns))
    )
    logger.info(
        "Testing inference with '{}' for {} different inputs (B, N): {}",
        weight_format,
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        """Yield `(n, batch_size, inputs, expected output shapes)` per distinct input size."""
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            # use the same `n` for every parameterized input axis
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            # skip cases resulting in input sizes that were tested already
            if hashable_target_size in tested:
                continue

            tested.add(hashable_target_size)

            resized_members = {}
            for t in model.inputs:
                target_sizes = {
                    aid: s
                    for (tid, aid), s in input_target_sizes.items()
                    if tid == t.id
                }
                resized_members[t.id] = test_input.members[t.id].resize_to(
                    target_sizes,
                )

            resized_test_inputs = Sample(
                members=resized_members,
                stat=test_input.stat,
                id=test_input.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        # `test_input` is also read by `generate_test_cases` above
        test_input = get_test_input_sample(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_shapes in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_shapes):
                    error = (
                        f"Expected {len(expected_shapes)} outputs,"
                        + f" but got {len(result.members)}"
                    )
                else:
                    for m, exp in expected_shapes.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        # collect all axes with an unexpected size
                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            axis_id = AxisId(a)
                            expected_size = exp[axis_id]
                            if isinstance(expected_size, int):
                                size_mismatch = s != expected_size
                            else:
                                # expected size given as a range (`max` may be None)
                                size_mismatch = s < expected_size.min or (
                                    expected_size.max is not None
                                    and s > expected_size.max
                                )

                            if size_mismatch:
                                diff[axis_id] = s

                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                if error is None:
                    status = "passed"
                    errors = []
                else:
                    status = "failed"
                    errors = [
                        ErrorEntry(
                            loc=("weights", weight_format),
                            msg=error,
                            type="bioimageio.core",
                        )
                    ]

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status=status,
                        errors=errors,
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )
def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record whether **rd** has the expected resource type.

    A passed/failed `ValidationDetail` is added to `rd.validation_summary`.

    Args:
        rd: The (possibly invalid) resource description to check.
        expected_type: The expected value of `rd.type`.
    """
    has_expected_type = rd.type == expected_type
    if has_expected_type:
        status = "passed"
        errors = []
    else:
        status = "failed"
        errors = [
            ErrorEntry(
                loc=("type",),
                type="type",
                msg=f"Expected type {expected_type}, found {rd.type}",
            )
        ]

    # Use `add_detail` (like all other tests in this module) instead of appending
    # to `details` directly, such that the summary registers the detail itself.
    # NOTE(review): presumably `add_detail` also marks the summary as failed for
    # a failed detail -- confirm against bioimageio.spec.
    rd.validation_summary.add_detail(
        ValidationDetail(
            name="Has expected resource type",
            status=status,
            loc=("type",),
            errors=errors,
        )
    )
1058# TODO: Implement `debug_model()`
1059# def debug_model(
1060# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
1061# *,
1062# weight_format: Optional[WeightsFormat] = None,
1063# devices: Optional[List[str]] = None,
1064# ):
1065# """Run the model test and return dict with inputs, results, expected results and intermediates.
1067# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
1068# """
1069# inputs_raw: Optional = None
1070# inputs_processed: Optional = None
1071# outputs_raw: Optional = None
1072# outputs: Optional = None
1073# expected: Optional = None
1074# diff: Optional = None
1076# model = load_description(
1077# model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
1078# )
1079# if not isinstance(model, Model):
1080# raise ValueError(f"Not a bioimageio.model: {model_rdf}")
1082# prediction_pipeline = create_prediction_pipeline(
1083# bioimageio_model=model, devices=devices, weight_format=weight_format
1084# )
1085# inputs = [
1086# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
1087# for in_path, input_spec in zip(model.test_inputs, model.inputs)
1088# ]
1089# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
1091# # keep track of the non-processed inputs
1092# inputs_raw = [deepcopy(input) for input in inputs]
1094# computed_measures = {}
1096# prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
1097# inputs_processed = list(input_dict.values())
1098# outputs_raw = prediction_pipeline.predict(*inputs_processed)
1099# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
1100# prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
1101# outputs = list(output_dict.values())
1103# if isinstance(outputs, (np.ndarray, xr.DataArray)):
1104# outputs = [outputs]
1106# expected = [
1107# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
1108# for out_path, output_spec in zip(model.test_outputs, model.outputs)
1109# ]
1110# if len(outputs) != len(expected):
1111# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
1112# print(error)
1113# else:
1114# diff = []
1115# for res, exp in zip(outputs, expected):
1116# diff.append(res - exp)
1118# return {
1119# "inputs": inputs_raw,
1120# "inputs_processed": inputs_processed,
1121# "outputs_raw": outputs_raw,
1122# "outputs": outputs,
1123# "expected": expected,
1124# "diff": diff,
1125# }