Coverage for bioimageio/core/_resource_tests.py: 75% (169 statements)
coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
import traceback
import warnings
from itertools import product
from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union

import numpy as np
from loguru import logger

from bioimageio.spec import (
    InvalidDescr,
    ResourceDescr,
    build_description,
    dump_description,
    load_description,
)
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource
from bioimageio.spec.get_conda_env import get_conda_env
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
    ErrorEntry,
    InstalledPackage,
    ValidationDetail,
    ValidationSummary,
)

from ._prediction_pipeline import create_prediction_pipeline
from .axis import AxisId, BatchSize
from .digest_spec import get_test_inputs, get_test_outputs
from .sample import Sample
from .utils import VERSION


def enable_determinism(mode: Literal["seed_only", "full"]):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode** == "full") request ML frameworks to use
    deterministic algorithms.

    Notes:
        - **mode** == "full" might degrade performance and throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import torch
        except ImportError:
            pass
        else:
            _ = torch.manual_seed(0)
            torch.use_deterministic_algorithms(mode == "full")
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import keras
        except ImportError:
            pass
        else:
            keras.utils.set_random_seed(0)
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import tensorflow as tf  # pyright: ignore[reportMissingImports]
        except ImportError:
            pass
        else:
            # use the TF2 API `tf.random.set_seed`; `tf.random.seed` does not exist
            tf.random.set_seed(0)
            if mode == "full":
                tf.config.experimental.enable_op_determinism()
                # TODO: find possibility to switch it off again??
    except Exception as e:
        logger.debug(str(e))
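
# Minimal usage sketch (illustrative only, not executed here): seed all available ML
# frameworks before a reproducibility check; "full" additionally requests deterministic
# algorithms, which may raise exceptions for operations without a deterministic variant.
#
#     enable_determinism("seed_only")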


def test_model(
    source: Union[v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[List[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
) -> ValidationSummary:
    """Test model inference"""
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        absolute_tolerance=absolute_tolerance,
        relative_tolerance=relative_tolerance,
        decimal=decimal,
        determinism=determinism,
        expected_type="model",
    )
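
# Hypothetical example (the file path is a placeholder): test only the ONNX weights of
# a model description and inspect the outcome via the returned ValidationSummary.
#
#     summary = test_model("path/to/model/rdf.yaml", weight_format="onnx")
#     print(summary.status)  # e.g. "passed" or "failed"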


def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
    rd = load_description_and_test(
        source,
        format_version=format_version,
        weight_format=weight_format,
        devices=devices,
        absolute_tolerance=absolute_tolerance,
        relative_tolerance=relative_tolerance,
        decimal=decimal,
        determinism=determinism,
        expected_type=expected_type,
    )
    return rd.validation_summary


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
) -> Union[ResourceDescr, InvalidDescr]:
    """Test a resource description (RDF) dynamically, e.g. run model inference on test inputs."""
    if (
        isinstance(source, ResourceDescrBase)
        and format_version != "discover"
        and source.format_version != format_version
    ):
        warnings.warn(
            f"deserializing source to ensure we validate and test using format {format_version}"
        )
        source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        rd = build_description(source, format_version=format_version)
    else:
        rd = load_description(source, format_version=format_version)

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=VERSION)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[WeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        if decimal is None:
            atol = absolute_tolerance
            rtol = relative_tolerance
        else:
            warnings.warn(
                "The argument `decimal` has been deprecated in favour of"
                + " `relative_tolerance` and `absolute_tolerance`, which use different"
                + " validation logic based on `numpy.testing.assert_allclose` (see"
                + " https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html)."
                + " Passing a value for `decimal` reverts validation to the old behaviour."
            )
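            # Legacy behaviour: `decimal=d` corresponds to an absolute tolerance of
            # 1.5 * 10**-d (as in `numpy.testing.assert_almost_equal`) and no relative
            # tolerance.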
            atol = 1.5 * 10 ** (-decimal)
            rtol = 0

        enable_determinism(determinism)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, atol, rtol)
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(rd, w, devices)

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd


def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: WeightsFormat,
    devices: Optional[Sequence[str]],
    atol: float,
    rtol: float,
) -> None:
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.info("starting '{}'", test_name)
    error: Optional[str] = None
    tb: List[str] = []

    try:
        inputs = get_test_inputs(model)
        expected = get_test_outputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            results = prediction_pipeline.predict_sample_without_blocking(inputs)

        if len(results.members) != len(expected.members):
            error = f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
        else:
            for m, exp in expected.members.items():
                res = results.members.get(m)
                if res is None:
                    error = "Output tensors for test case may not be None"
                    break
                try:
                    np.testing.assert_allclose(
                        res.data,
                        exp.data,
                        rtol=rtol,
                        atol=atol,
                    )
                except AssertionError as e:
                    error = f"Output and expected output disagree:\n {e}"
                    break
    except Exception as e:
        error = str(e)
        tb = traceback.format_tb(e.__traceback__)
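
    # Record the outcome for this weight format in the model's validation summary,
    # together with the conda environment recommended for running these weights.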
    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=("weights", weight_format),
            status="passed" if error is None else "failed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=(
                []
                if error is None
                else [
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=error,
                        type="bioimageio.core",
                        traceback=tb,
                    )
                ]
            ),
        )
    )


def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: WeightsFormat,
    devices: Optional[Sequence[str]],
) -> None:
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}
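
    # Collect the batch sizes fixed by the model's batch axes; `None` entries mean the
    # batch axis accepts an arbitrary batch size.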
    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = {
        (n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
    }
    logger.info(
        "Testing inference with {} different input tensor sizes", len(test_cases)
    )
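
    # Each test case pairs a size parameter `n` with a batch size; cases that resolve
    # to the same concrete input sizes are evaluated only once (see `tested` below).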
    def generate_test_cases():
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for n, batch_size in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            resized_test_inputs = Sample(
                members={
                    t.id: test_inputs.members[t.id].resize_to(
                        {
                            aid: s
                            for (tid, aid), s in input_target_sizes.items()
                            if tid == t.id
                        },
                    )
                    for t in model.inputs
                },
                stat=test_inputs.stat,
                id=test_inputs.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_inputs = get_test_inputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shape in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_output_shape):
                    error = (
                        f"Expected {len(expected_output_shape)} outputs,"
                        + f" but got {len(result.members)}"
                    )
                else:
                    for m, exp in expected_output_shape.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break
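
                        # Expected axis sizes are exact integers or (for data-dependent
                        # axes) objects with `min`/`max` bounds; collect every axis
                        # whose actual size falls outside the expectation.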
                        diff: Dict[AxisId, int] = {}
                        for a, s in res.sizes.items():
                            if isinstance((e_size := exp[AxisId(a)]), int):
                                if s != e_size:
                                    diff[AxisId(a)] = s
                            elif (
                                s < e_size.min
                                or (e_size.max is not None and s > e_size.max)
                            ):
                                diff[AxisId(a)] = s
                        if diff:
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} (diff: {diff})"
                            )
                            break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
    except Exception as e:
        error = str(e)
        tb = traceback.format_tb(e.__traceback__)
        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=error,
                        type="bioimageio.core",
                        traceback=tb,
                    )
                ],
            )
        )


def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    has_expected_type = rd.type == expected_type
    rd.validation_summary.details.append(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )


# TODO: Implement `debug_model()`
# def debug_model(
#     model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
#     *,
#     weight_format: Optional[WeightsFormat] = None,
#     devices: Optional[List[str]] = None,
# ):
#     """Run the model test and return dict with inputs, results, expected results and intermediates.
#
#     Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
#     """
#     inputs_raw: Optional = None
#     inputs_processed: Optional = None
#     outputs_raw: Optional = None
#     outputs: Optional = None
#     expected: Optional = None
#     diff: Optional = None
#
#     model = load_description(
#         model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
#     )
#     if not isinstance(model, Model):
#         raise ValueError(f"Not a bioimageio.model: {model_rdf}")
#
#     prediction_pipeline = create_prediction_pipeline(
#         bioimageio_model=model, devices=devices, weight_format=weight_format
#     )
#     inputs = [
#         xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
#         for in_path, input_spec in zip(model.test_inputs, model.inputs)
#     ]
#     input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
#
#     # keep track of the non-processed inputs
#     inputs_raw = [deepcopy(input) for input in inputs]
#
#     computed_measures = {}
#
#     prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
#     inputs_processed = list(input_dict.values())
#     outputs_raw = prediction_pipeline.predict(*inputs_processed)
#     output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
#     prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
#     outputs = list(output_dict.values())
#
#     if isinstance(outputs, (np.ndarray, xr.DataArray)):
#         outputs = [outputs]
#
#     expected = [
#         xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
#         for out_path, output_spec in zip(model.test_outputs, model.outputs)
#     ]
#     if len(outputs) != len(expected):
#         error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
#         print(error)
#     else:
#         diff = []
#         for res, exp in zip(outputs, expected):
#             diff.append(res - exp)
#
#     return {
#         "inputs": inputs_raw,
#         "inputs_processed": inputs_processed,
#         "outputs_raw": outputs_raw,
#         "outputs": outputs,
#         "expected": expected,
#         "diff": diff,
#     }