Coverage for bioimageio/core/weight_converter/torch/_onnx.py: 23%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1# type: ignore # TODO: type 

2import warnings 

3from pathlib import Path 

4from typing import Any, List, Sequence, cast 

5 

6import numpy as np 

7from numpy.testing import assert_array_almost_equal 

8 

9from bioimageio.spec import load_description 

10from bioimageio.spec.common import InvalidDescr 

11from bioimageio.spec.model import v0_4, v0_5 

12 

13from ...digest_spec import get_member_id, get_test_inputs 

14from ...weight_converter.torch._utils import load_torch_model 

15 

16try: 

17 import torch 

18except ImportError: 

19 torch = None 

20 

21 

def add_onnx_weights(
    model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
    *,
    output_path: Path,
    use_tracing: bool = True,
    test_decimal: int = 4,
    verbose: bool = False,
    opset_version: "int | None" = None,
) -> "int | None":
    """Convert model weights from format 'pytorch_state_dict' to 'onnx'.

    The PyTorch model is run on the description's test inputs, exported to
    ONNX via tracing, and -- if ``onnxruntime`` is installed -- the exported
    model is run on the same inputs and its outputs are compared with the
    PyTorch outputs.

    Args:
        model_spec: model description with 'pytorch_state_dict' weights, or a
            str/Path from which such a description is loaded.
        output_path: file path the ONNX model is written to.
        use_tracing: whether to use tracing (True) or scripting (False) to
            export the onnx format. Only tracing is currently implemented.
        test_decimal: precision (number of decimals) for testing whether the
            PyTorch and ONNX results agree.
        verbose: passed through to ``torch.onnx.export``.
        opset_version: onnx opset version, passed through to
            ``torch.onnx.export``.

    Returns:
        0 if the ONNX outputs agree with the PyTorch outputs; 1 (and a
        warning) if they do not; None (and a warning) if ``onnxruntime`` is not
        installed so the exported weights could not be checked.

    Raises:
        ValueError: if the loaded description is invalid or the model has no
            'pytorch_state_dict' weights.
        TypeError: if the loaded description is not a model description.
        ImportError: if torch is not installed.
        NotImplementedError: if ``use_tracing`` is False.
    """
    model_spec = _load_model_descr(model_spec)

    state_dict_weights_descr = model_spec.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if torch is None:
        raise ImportError(
            "Converting weights to onnx requires 'torch' to be installed."
        )

    # Fail fast, before loading the model and running inference.
    if not use_tracing:
        raise NotImplementedError(
            "Exporting to onnx via scripting is not implemented; use `use_tracing=True`."
        )

    with torch.no_grad():
        sample = get_test_inputs(model_spec)
        input_data = [
            sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs
        ]
        input_tensors = [torch.from_numpy(ipt) for ipt in input_data]
        model = load_torch_model(state_dict_weights_descr)

        # Reference outputs that the exported ONNX model is compared against.
        expected_tensors = model(*input_tensors)
        if isinstance(expected_tensors, torch.Tensor):
            expected_tensors = [expected_tensors]
        expected_outputs: List[np.ndarray[Any, Any]] = [
            out.numpy() for out in expected_tensors
        ]

        torch.onnx.export(
            model,
            # a single input is passed as-is, several inputs as a tuple
            tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0],
            str(output_path),
            verbose=verbose,
            opset_version=opset_version,
        )

    return _check_onnx_weights(
        output_path, input_data, expected_outputs, test_decimal
    )


def _load_model_descr(
    model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
) -> "v0_4.ModelDescr | v0_5.ModelDescr":
    """Return `model_spec` as a model description, loading it first if it is a str/Path.

    Raises:
        ValueError: if the loaded description is invalid.
        TypeError: if the loaded description is not a v0_4/v0_5 model description.
    """
    if not isinstance(model_spec, (str, Path)):
        return model_spec

    loaded_spec = load_description(Path(model_spec))
    if isinstance(loaded_spec, InvalidDescr):
        raise ValueError(f"Bad resource description: {loaded_spec}")
    if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise TypeError(
            f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr"
        )
    return loaded_spec


def _check_onnx_weights(
    output_path: Path,
    input_data: "Sequence[np.ndarray[Any, Any]]",
    expected_outputs: "Sequence[np.ndarray[Any, Any]]",
    test_decimal: int,
) -> "int | None":
    """Run the exported onnx model with onnxruntime and compare it with `expected_outputs`.

    Returns 0 if all outputs agree to `test_decimal` decimals, 1 (with a
    warning) if they do not, and None (with a warning) if onnxruntime is not
    installed.
    """
    try:
        import onnxruntime as rt  # pyright: ignore [reportMissingTypeStubs]
    except ImportError:
        msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
        warnings.warn(msg)
        return None

    # Request the CPU provider explicitly: since onnxruntime 1.9, builds that
    # ship several execution providers (e.g. GPU builds) raise if `providers`
    # is omitted. The inputs are numpy arrays on the CPU anyway.
    sess = rt.InferenceSession(str(output_path), providers=["CPUExecutionProvider"])
    onnx_input_node_args = cast(
        List[Any], sess.get_inputs()
    )  # fixme: remove cast, try using rt.NodeArg instead of Any
    onnx_inputs = {
        node_arg.name: inp for node_arg, inp in zip(onnx_input_node_args, input_data)
    }
    outputs = cast(
        Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)
    )  # FIXME: remove cast

    # `zip` below would silently skip surplus outputs on either side, so a
    # mismatch in the number of outputs has to be detected explicitly.
    if len(outputs) != len(expected_outputs):
        warnings.warn(
            "The onnx weights were exported, but the number of outputs differs "
            f"before ({len(expected_outputs)}) and after ({len(outputs)}) conversion."
        )
        return 1

    try:
        for exp, out in zip(expected_outputs, outputs):
            assert_array_almost_equal(exp, out, decimal=test_decimal)
    except AssertionError as e:
        msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
        warnings.warn(msg)
        return 1

    return 0