Coverage for bioimageio/core/_resource_tests.py: 61%
304 statements
coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

import hashlib
import os
import platform
import subprocess
import warnings
from io import StringIO
from itertools import product
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Callable,
    Dict,
    Hashable,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    overload,
)

from loguru import logger
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args

from bioimageio.spec import (
    BioimageioCondaEnv,
    InvalidDescr,
    LatestResourceDescr,
    ResourceDescr,
    ValidationContext,
    build_description,
    dump_description,
    get_conda_env,
    load_description,
    save_bioimageio_package,
)
from bioimageio.spec._description_impl import DISCOVER
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
from bioimageio.spec._internal.io import is_yaml_value
from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
from bioimageio.spec._internal.types import (
    AbsoluteTolerance,
    FormatVersionPlaceholder,
    MismatchedElementsPerMillion,
    RelativeTolerance,
)
from bioimageio.spec._internal.validation_context import get_validation_context
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
    ErrorEntry,
    InstalledPackage,
    ValidationDetail,
    ValidationSummary,
)

from ._prediction_pipeline import create_prediction_pipeline
from .axis import AxisId, BatchSize
from .common import MemberId, SupportedWeightsFormat
from .digest_spec import get_test_inputs, get_test_outputs
from .sample import Sample
from .utils import VERSION


class DeprecatedKwargs(TypedDict):
    absolute_tolerance: NotRequired[AbsoluteTolerance]
    relative_tolerance: NotRequired[RelativeTolerance]
    decimal: NotRequired[Optional[int]]


def enable_determinism(
    mode: Literal["seed_only", "full"] = "full",
    weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.

    Args:
        mode: determinism mode
            - 'seed_only' -- only set seeds, or
            - 'full' -- additionally enable framework determinism features
              (might degrade performance or throw exceptions)
        weight_formats: Limit which deep learning frameworks are imported
            based on **weight_formats**, e.g. to avoid importing tensorflow
            when testing with pytorch.

    Notes:
        - **mode** == "full" might degrade performance or throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    if (
        weight_formats is None
        or "pytorch_state_dict" in weight_formats
        or "torchscript" in weight_formats
    ):
        try:
            try:
                import torch
            except ImportError:
                pass
            else:
                _ = torch.manual_seed(0)
                torch.use_deterministic_algorithms(mode == "full")
        except Exception as e:
            logger.debug(str(e))

    if (
        weight_formats is None
        or "tensorflow_saved_model_bundle" in weight_formats
        or "keras_hdf5" in weight_formats
    ):
        try:
            os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
            try:
                import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                tf.random.set_seed(0)
                if mode == "full":
                    tf.config.experimental.enable_op_determinism()
                    # TODO: find a way to switch it off again?
        except Exception as e:
            logger.debug(str(e))

    if weight_formats is None or "keras_hdf5" in weight_formats:
        try:
            try:
                import keras  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                pass
            else:
                keras.utils.set_random_seed(0)
        except Exception as e:
            logger.debug(str(e))
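
# Example (sketch, not executed): seed only the frameworks relevant for the
# weight format under test before reproducing a model's test outputs.
# "my_model.zip" is a placeholder source.
#
#     enable_determinism("seed_only", weight_formats=["pytorch_state_dict"])
#     summary = test_model("my_model.zip", weight_format="pytorch_state_dict")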


def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference"""
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
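
# Example (sketch, not executed): `source` may be any PermissiveFileSource
# (local path, URL, or bioimage.io resource ID) or an already loaded model
# description; "my_model.zip" below is a placeholder. `ValidationSummary` is a
# pydantic model, so it can be serialized with `model_dump_json`.
#
#     summary = test_model(
#         "my_model.zip",
#         weight_format="onnx",
#         devices=["cpu"],
#         determinism="seed_only",
#     )
#     print(summary.status)
#     print(summary.model_dump_json(indent=4))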


def default_run_command(args: Sequence[str]):
    logger.info("running '{}'...", " ".join(args))
    _ = subprocess.run(args, shell=True, text=True, check=True)


def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = ("currently-active"),
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in:
            - `"currently-active"`: Use the active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
              environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
              Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess
            (ignored if **runtime_env** is `"currently-active"`).
    """
    if runtime_env == "currently-active":
        rd = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return rd.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        conda_env = BioimageioCondaEnv.model_validate(read_yaml(Path(runtime_env)))
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    with TemporaryDirectory(ignore_cleanup_errors=True) as _d:
        working_dir = Path(_d)
        if isinstance(source, (dict, ResourceDescrBase)):
            file_source = save_bioimageio_package(
                source, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )
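
# Example (sketch, not executed): run the test in a dedicated conda environment
# generated from the model description instead of the currently active Python
# interpreter. This requires a working `conda` executable on PATH (checked in
# `_test_in_env` below); "my_model.zip" is a placeholder source.
#
#     summary = test_description(
#         "my_model.zip",
#         runtime_env="as-described",
#         stop_early=True,
#     )
#     print(summary.status)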


def _test_in_env(
    source: PermissiveFileSource,
    *,
    working_dir: Path,
    weight_format: Optional[SupportedWeightsFormat],
    conda_env: Optional[BioimageioCondaEnv],
    devices: Optional[Sequence[str]],
    determinism: Literal["seed_only", "full"],
    run_command: Callable[[Sequence[str]], None],
    stop_early: bool,
    expected_type: Optional[str],
    sha256: Optional[Sha256],
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    descr = load_description(source)

    if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise NotImplementedError("Not yet implemented for non-model resources")

    if weight_format is None:
        all_present_wfs = [
            wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
        ]
        ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
        logger.info(
            "Found weight formats {}. Start testing all{}...",
            all_present_wfs,
            f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
        )
        summary = _test_in_env(
            source,
            working_dir=working_dir / all_present_wfs[0],
            weight_format=all_present_wfs[0],
            devices=devices,
            determinism=determinism,
            conda_env=conda_env,
            run_command=run_command,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        for wf in all_present_wfs[1:]:
            additional_summary = _test_in_env(
                source,
                working_dir=working_dir / wf,
                weight_format=wf,
                devices=devices,
                determinism=determinism,
                conda_env=conda_env,
                run_command=run_command,
                expected_type=expected_type,
                sha256=sha256,
                stop_early=stop_early,
                **deprecated,
            )
            for d in additional_summary.details:
                # TODO: filter redundant details; group details
                summary.add_detail(d)
        return summary

    if weight_format == "pytorch_state_dict":
        wf = descr.weights.pytorch_state_dict
    elif weight_format == "torchscript":
        wf = descr.weights.torchscript
    elif weight_format == "keras_hdf5":
        wf = descr.weights.keras_hdf5
    elif weight_format == "onnx":
        wf = descr.weights.onnx
    elif weight_format == "tensorflow_saved_model_bundle":
        wf = descr.weights.tensorflow_saved_model_bundle
    elif weight_format == "tensorflow_js":
        raise RuntimeError(
            "testing 'tensorflow_js' is not supported by bioimageio.core"
        )
    else:
        assert_never(weight_format)

    assert wf is not None
    if conda_env is None:
        conda_env = get_conda_env(entry=wf)

    # remove the name as we create a name based on the hash of the env description
    conda_env.name = None

    dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
    if not is_yaml_value(dumped_env):
        raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")

    env_io = StringIO()
    write_yaml(dumped_env, file=env_io)
    encoded_env = env_io.getvalue().encode()
    env_name = hashlib.sha256(encoded_env).hexdigest()

    try:
        run_command(["where" if platform.system() == "Windows" else "which", "conda"])
    except Exception as e:
        raise RuntimeError("Conda not available") from e

    working_dir.mkdir(parents=True, exist_ok=True)
    try:
        run_command(["conda", "activate", env_name])
    except Exception:
        path = working_dir / "env.yaml"
        _ = path.write_bytes(encoded_env)
        logger.debug("written conda env to {}", path)
        run_command(["conda", "env", "create", f"--file={path}", f"--name={env_name}"])
        run_command(["conda", "activate", env_name])

    summary_path = working_dir / "summary.json"
    run_command(
        [
            "conda",
            "run",
            "-n",
            env_name,
            "bioimageio",
            "test",
            str(source),
            f"--summary-path={summary_path}",
            f"--determinism={determinism}",
        ]
        + ([f"--expected-type={expected_type}"] if expected_type else [])
        + (["--stop-early"] if stop_early else [])
    )
    return ValidationSummary.model_validate_json(summary_path.read_bytes())
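
# Sketch (not executed) of the environment naming scheme used above: the conda
# environment name is the SHA-256 hex digest of the serialized environment
# YAML, so identical weight entries re-use the same environment across runs.
# `model_descr` is a placeholder for a loaded model description.
#
#     conda_env = get_conda_env(entry=model_descr.weights.onnx)
#     conda_env.name = None
#     buf = StringIO()
#     write_yaml(conda_env.model_dump(mode="json", exclude_none=True), file=buf)
#     env_name = hashlib.sha256(buf.getvalue().encode()).hexdigest()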


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Literal["latest"],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[LatestResourceDescr, InvalidDescr]: ...


@overload
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]: ...


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = False,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    See `test_description` for more details.

    Returns:
        A (possibly invalid) resource description object
        with a populated `.validation_summary` attribute.
    """
    if isinstance(source, ResourceDescrBase):
        root = source.root
        file_name = source.file_name
        if (
            (
                format_version
                not in (
                    DISCOVER,
                    source.format_version,
                    ".".join(source.format_version.split(".")[:2]),
                )
            )
            or (c := source.validation_summary.details[0].context) is None
            or not c.perform_io_checks
        ):
            logger.debug(
                "deserializing source to ensure we validate and test using format {} and perform io checks",
                format_version,
            )
            source = dump_description(source)
    else:
        root = Path()
        file_name = None

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        # check context for a given root; default to root of source
        context = get_validation_context(
            ValidationContext(root=root, file_name=file_name)
        ).replace(
            perform_io_checks=True  # make sure we perform io checks though
        )

        rd = build_description(
            source,
            format_version=format_version,
            context=context,
        )
    else:
        rd = load_description(
            source, format_version=format_version, sha256=sha256, perform_io_checks=True
        )

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=VERSION)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[SupportedWeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        enable_determinism(determinism, weight_formats=weight_formats)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, **deprecated)
            if stop_early and rd.validation_summary.status == "failed":
                break

            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(
                    rd, w, devices, stop_early=stop_early
                )
                if stop_early and rd.validation_summary.status == "failed":
                    break

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    if rd.validation_summary.status == "valid-format":
        rd.validation_summary.status = "passed"

    return rd
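
# Example (sketch, not executed): unlike `test_description`, this returns the
# (possibly invalid) resource description itself, with the test results
# attached to `.validation_summary`; "my_model.zip" is a placeholder source.
#
#     rd = load_description_and_test("my_model.zip", format_version="latest")
#     print(rd.validation_summary.status)
#     for detail in rd.validation_summary.details:
#         print(detail.name, detail.status)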


def _get_tolerance(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    wf: SupportedWeightsFormat,
    m: MemberId,
    **deprecated: Unpack[DeprecatedKwargs],
) -> Tuple[RelativeTolerance, AbsoluteTolerance, MismatchedElementsPerMillion]:
    if isinstance(model, v0_5.ModelDescr):
        applicable = v0_5.ReproducibilityTolerance()

        # check legacy test kwargs for weight format specific tolerance
        if model.config.bioimageio.model_extra is not None:
            for weights_format, test_kwargs in model.config.bioimageio.model_extra.get(
                "test_kwargs", {}
            ).items():
                if wf == weights_format:
                    applicable = v0_5.ReproducibilityTolerance(
                        relative_tolerance=test_kwargs.get("relative_tolerance", 1e-3),
                        absolute_tolerance=test_kwargs.get("absolute_tolerance", 1e-4),
                    )
                    break

        # check for weights format and output tensor specific tolerance
        for a in model.config.bioimageio.reproducibility_tolerance:
            if (not a.weights_formats or wf in a.weights_formats) and (
                not a.output_ids or m in a.output_ids
            ):
                applicable = a
                break

        rtol = applicable.relative_tolerance
        atol = applicable.absolute_tolerance
        mismatched_tol = applicable.mismatched_elements_per_million
    elif (decimal := deprecated.get("decimal")) is not None:
        warnings.warn(
            "The argument `decimal` has been deprecated in favour of"
            + " `relative_tolerance` and `absolute_tolerance`, with different"
            + " validation logic, using `numpy.testing.assert_allclose`, see"
            + " 'https://numpy.org/doc/stable/reference/generated/"
            + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
            + " will cause validation to revert to the old behaviour."
        )
        atol = 1.5 * 10 ** (-decimal)
        rtol = 0
        mismatched_tol = 0
    else:
        # use given (deprecated) test kwargs
        atol = deprecated.get("absolute_tolerance", 1e-5)
        rtol = deprecated.get("relative_tolerance", 1e-3)
        mismatched_tol = 0

    return rtol, atol, mismatched_tol
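
# Sketch (illustrative values): a v0.5 model description can declare the
# tolerances picked up above via `config.bioimageio.reproducibility_tolerance`.
# Field names follow `v0_5.ReproducibilityTolerance` as used in `_get_tolerance`;
# the output id "prediction" is a hypothetical example.
#
#     v0_5.ReproducibilityTolerance(
#         weights_formats=["onnx"],
#         output_ids=["prediction"],
#         relative_tolerance=1e-3,
#         absolute_tolerance=1e-4,
#         mismatched_elements_per_million=10,
#     )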


def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    **deprecated: Unpack[DeprecatedKwargs],
) -> None:
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.debug("starting '{}'", test_name)
    errors: List[ErrorEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        errors.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                type="bioimageio.core",
                with_traceback=with_traceback,
            )
        )

    try:
        inputs = get_test_inputs(model)
        expected = get_test_outputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(inputs)

        if len(results.members) != len(expected.members):
            add_error_entry(
                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
            )
        else:
            for m, expected in expected.members.items():
                actual = results.members.get(m)
                if actual is None:
                    add_error_entry("Output tensors for test case may not be None")
                    break

                rtol, atol, mismatched_tol = _get_tolerance(
                    model, wf=weight_format, m=m, **deprecated
                )
                mismatched = (abs_diff := abs(actual - expected)) > atol + rtol * abs(
                    expected
                )
                mismatched_elements = mismatched.sum().item()
                if mismatched_elements / expected.size > mismatched_tol / 1e6:
                    r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
                    r_max = r_diff[r_max_idx].item()
                    r_actual = actual[r_max_idx].item()
                    r_expected = expected[r_max_idx].item()
                    a_max_idx = abs_diff.argmax()
                    a_max = abs_diff[a_max_idx].item()
                    a_actual = actual[a_max_idx].item()
                    a_expected = expected[a_max_idx].item()
                    add_error_entry(
                        f"Output '{m}' disagrees with {mismatched_elements} of"
                        + f" {expected.size} expected values."
                        + f"\n Max relative difference: {r_max:.2e}"
                        + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
                        + f" at {r_max_idx}"
                        + f"\n Max absolute difference: {a_max:.2e}"
                        + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
                    )
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        add_error_entry(str(e), with_traceback=True)

    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="failed" if errors else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=errors,
        )
    )
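
# Minimal numpy sketch (not executed) of the comparison criterion applied
# above: an element is mismatched if |actual - expected| > atol + rtol * |expected|,
# and the subtest only fails once the mismatched fraction exceeds
# `mismatched_tol` (given in elements per million).
#
#     import numpy as np
#
#     def fraction_mismatched(actual, expected, rtol, atol):
#         mismatched = np.abs(actual - expected) > atol + rtol * np.abs(expected)
#         return mismatched.sum() / expected.size
#
#     fails = fraction_mismatched(actual, expected, 1e-3, 1e-4) > mismatched_tol / 1e6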


def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: SupportedWeightsFormat,
    devices: Optional[Sequence[str]],
    *,
    stop_early: bool,
) -> None:
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[BatchSize, v0_5.ParameterizedSize_N]] = {
        (b, n) for b, n in product(sorted(batch_sizes), sorted(ns))
    }
    logger.info(
        "Testing inference with {} different inputs (B, N): {}",
        len(test_cases),
        test_cases,
    )

    def generate_test_cases():
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for batch_size, n in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: (
                        test_inputs.members[t.id].resize_to(
                            {
                                aid: s
                                for (tid, aid), s in input_target_sizes.items()
                                if tid == t.id
                            },
                        )
                    )
                    for t in model.inputs
                },
                stat=test_inputs.stat,
                id=test_inputs.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_inputs = get_test_inputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shape in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_output_shape):
                    error = (
                        f"Expected {len(expected_output_shape)} outputs,"
                        + f" but got {len(result.members)}"
                    )
                else:
                    for m, exp in expected_output_shape.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            if isinstance((e_aid := exp[AxisId(a)]), int):
                                if s != e_aid:
                                    diff[AxisId(a)] = s
                            elif (
                                s < e_aid.min or e_aid.max is not None and s > e_aid.max
                            ):
                                diff[AxisId(a)] = s
                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
                if stop_early and error is not None:
                    break
    except Exception as e:
        if get_validation_context().raise_errors:
            raise e

        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=str(e),
                        type="bioimageio.core",
                        with_traceback=True,
                    )
                ],
            )
        )


def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"Expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )


# TODO: Implement `debug_model()`
# def debug_model(
#     model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
#     *,
#     weight_format: Optional[WeightsFormat] = None,
#     devices: Optional[List[str]] = None,
# ):
#     """Run the model test and return dict with inputs, results, expected results and intermediates.
#
#     Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
#     """
#     inputs_raw: Optional = None
#     inputs_processed: Optional = None
#     outputs_raw: Optional = None
#     outputs: Optional = None
#     expected: Optional = None
#     diff: Optional = None
#
#     model = load_description(
#         model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
#     )
#     if not isinstance(model, Model):
#         raise ValueError(f"Not a bioimageio.model: {model_rdf}")
#
#     prediction_pipeline = create_prediction_pipeline(
#         bioimageio_model=model, devices=devices, weight_format=weight_format
#     )
#     inputs = [
#         xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
#         for in_path, input_spec in zip(model.test_inputs, model.inputs)
#     ]
#     input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
#
#     # keep track of the non-processed inputs
#     inputs_raw = [deepcopy(input) for input in inputs]
#
#     computed_measures = {}
#
#     prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
#     inputs_processed = list(input_dict.values())
#     outputs_raw = prediction_pipeline.predict(*inputs_processed)
#     output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
#     prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
#     outputs = list(output_dict.values())
#
#     if isinstance(outputs, (np.ndarray, xr.DataArray)):
#         outputs = [outputs]
#
#     expected = [
#         xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
#         for out_path, output_spec in zip(model.test_outputs, model.outputs)
#     ]
#     if len(outputs) != len(expected):
#         error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
#         print(error)
#     else:
#         diff = []
#         for res, exp in zip(outputs, expected):
#             diff.append(res - exp)
#
#     return {
#         "inputs": inputs_raw,
#         "inputs_processed": inputs_processed,
#         "outputs_raw": outputs_raw,
#         "outputs": outputs,
#         "expected": expected,
#         "diff": diff,
#     }