Coverage for bioimageio/core/weight_converter/torch/_onnx.py: 23%
56 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1# type: ignore # TODO: type
2import warnings
3from pathlib import Path
4from typing import Any, List, Sequence, cast
6import numpy as np
7from numpy.testing import assert_array_almost_equal
9from bioimageio.spec import load_description
10from bioimageio.spec.common import InvalidDescr
11from bioimageio.spec.model import v0_4, v0_5
13from ...digest_spec import get_member_id, get_test_inputs
14from ...weight_converter.torch._utils import load_torch_model
16try:
17 import torch
18except ImportError:
19 torch = None
def add_onnx_weights(
    model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
    *,
    output_path: Path,
    use_tracing: bool = True,
    test_decimal: int = 4,
    verbose: bool = False,
    opset_version: "int | None" = None,
):
    """Convert model weights from format 'pytorch_state_dict' to 'onnx'.

    The PyTorch model is run on the model's test inputs and exported to ONNX.
    If ``onnxruntime`` is installed, the exported model is run on the same
    inputs and its outputs are compared with the PyTorch outputs.

    Args:
        model_spec: model description with 'pytorch_state_dict' weights, or a
            path to one (loaded with ``load_description``).
        output_path: file the onnx weights are written to.
        use_tracing: whether to use tracing or scripting to export the onnx
            format. Only tracing is currently implemented.
        test_decimal: precision for testing whether the results agree
            (the ``decimal`` argument of ``assert_array_almost_equal``).
        verbose: forwarded to ``torch.onnx.export``.
        opset_version: onnx opset version, forwarded to ``torch.onnx.export``.

    Returns:
        ``0`` if the onnx outputs agree with the pytorch outputs; ``1`` if
        they do not (a warning is issued); ``None`` if ``onnxruntime`` is not
        available, so the exported weights could not be checked (a warning
        is issued).

    Raises:
        ValueError: if the loaded description is invalid or the model has no
            'pytorch_state_dict' weights.
        TypeError: if the description loaded from a path is not a model.
        ImportError: if torch is not installed.
        NotImplementedError: if ``use_tracing`` is False.
    """
    if isinstance(model_spec, (str, Path)):
        loaded_spec = load_description(Path(model_spec))
        if isinstance(loaded_spec, InvalidDescr):
            raise ValueError(f"Bad resource description: {loaded_spec}")
        if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr"
            )
        model_spec = loaded_spec

    state_dict_weights_descr = model_spec.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # An explicit check (not `assert`) so it still fires under `python -O`.
    if torch is None:
        raise ImportError("torch is required to convert weights to onnx")

    # Fail before the (potentially expensive) model loading and inference.
    if not use_tracing:
        raise NotImplementedError("Only tracing-based onnx export is implemented")

    with torch.no_grad():
        sample = get_test_inputs(model_spec)
        input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs]
        input_tensors = [torch.from_numpy(ipt) for ipt in input_data]
        model = load_torch_model(state_dict_weights_descr)

        expected_tensors = model(*input_tensors)
        if isinstance(expected_tensors, torch.Tensor):
            expected_tensors = [expected_tensors]
        expected_outputs: List[np.ndarray[Any, Any]] = [
            out.numpy() for out in expected_tensors
        ]

        # A single input is passed bare; multiple inputs as a tuple of
        # positional arguments.
        torch.onnx.export(
            model,
            tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0],
            str(output_path),
            verbose=verbose,
            opset_version=opset_version,
        )

    try:
        import onnxruntime as rt  # pyright: ignore [reportMissingTypeStubs]
    except ImportError:
        msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
        warnings.warn(msg)
        return

    # check the onnx model
    sess = rt.InferenceSession(str(output_path))
    onnx_input_node_args = cast(
        List[Any], sess.get_inputs()
    )  # fixme: remove cast, try using rt.NodeArg instead of Any
    # ONNX inputs are matched to the test inputs by position.
    onnx_inputs = {
        node_arg.name: inp
        for node_arg, inp in zip(onnx_input_node_args, input_data)
    }
    outputs = cast(
        Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)
    )  # FIXME: remove cast

    # zip() below would silently skip unmatched outputs, so compare counts first.
    if len(outputs) != len(expected_outputs):
        msg = (
            "The onnx weights were exported, but the onnx model returns "
            f"{len(outputs)} outputs instead of {len(expected_outputs)}"
        )
        warnings.warn(msg)
        return 1

    try:
        for exp, out in zip(expected_outputs, outputs):
            assert_array_almost_equal(exp, out, decimal=test_decimal)
    except AssertionError as e:
        msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
        warnings.warn(msg)
        return 1

    return 0