Coverage for bioimageio/core/weight_converter/torch/_torchscript.py: 16%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1# type: ignore # TODO: type 

2from pathlib import Path 

3from typing import List, Sequence, Union 

4 

5import numpy as np 

6from numpy.testing import assert_array_almost_equal 

7from typing_extensions import Any, assert_never 

8 

9from bioimageio.spec.model import v0_4, v0_5 

10from bioimageio.spec.model.v0_5 import Version 

11 

12from ._utils import load_torch_model 

13 

14try: 

15 import torch 

16except ImportError: 

17 torch = None 

18 

19 

# FIXME: remove Any
def _check_predictions(
    model: Any,
    scripted_model: Any,
    model_spec: "v0_4.ModelDescr | v0_5.ModelDescr",
    input_data: Sequence["torch.Tensor"],
):
    """Verify that `scripted_model` reproduces the predictions of `model`.

    Both models are first compared on the full `input_data`. For single-input
    models the comparison is then repeated on up to four successively smaller
    crops of the input. Each crop trims `step_factor * (step // 2)` elements
    from both ends of every axis with a non-zero step. Cropping stops as soon
    as a crop would fall below the minimum shape from `model_spec`.

    Args:
        model: the original (eager) PyTorch model.
        scripted_model: the traced or scripted counterpart of `model`.
        model_spec: model description providing the input shape constraints.
        input_data: test input tensors, passed positionally to both models.

    Raises:
        ValueError: if the two models return a different number of outputs,
            or if any pair of outputs differs at 4 decimal places.
        NotImplementedError: if a (v0.5) input axis size cannot be handled.
    """
    assert torch is not None

    def _as_numpy(tensors: Any) -> "List[np.ndarray[Any, Any]]":
        # a model may return a single tensor or a sequence of tensors
        if isinstance(tensors, torch.Tensor):
            tensors = [tensors]
        # detach()/cpu() keep .numpy() from failing on tensors that track
        # gradients or live on an accelerator; both are no-ops otherwise
        return [t.detach().cpu().numpy() for t in tensors]

    def _check(input_: Sequence[torch.Tensor]) -> None:
        expected_outputs = _as_numpy(model(*input_))
        outputs = _as_numpy(scripted_model(*input_))

        # zip() would silently ignore surplus outputs of either model
        if len(expected_outputs) != len(outputs):
            raise ValueError(
                "Results before and after weights conversion do not agree: expected "
                f"{len(expected_outputs)} output(s), got {len(outputs)}"
            )

        try:
            for exp, out in zip(expected_outputs, outputs):
                assert_array_almost_equal(exp, out, decimal=4)
        except AssertionError as e:
            raise ValueError(
                f"Results before and after weights conversion do not agree:\n {str(e)}"
            ) from e

    _check(input_data)

    if len(model_spec.inputs) > 1:
        return  # FIXME: why don't we check multiple inputs?

    input_descr = model_spec.inputs[0]
    if isinstance(input_descr, v0_4.InputTensorDescr):
        if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape):
            return
        min_shape = input_descr.shape.min
        step = input_descr.shape.step
    else:
        min_shape: List[int] = []
        step: List[int] = []
        for axis in input_descr.axes:
            if isinstance(axis.size, v0_5.ParameterizedSize):
                min_shape.append(axis.size.min)
                step.append(axis.size.step)
            elif isinstance(axis.size, int):
                min_shape.append(axis.size)
                step.append(0)
            elif axis.size is None:
                if isinstance(axis, v0_5.BatchAxis):
                    # a batch axis of unspecified size is never cropped
                    min_shape.append(1)
                    step.append(0)
                else:
                    raise NotImplementedError(
                        f"Can't verify inputs that don't specify their shape fully: {axis}"
                    )
            elif isinstance(axis.size, v0_5.SizeReference):
                raise NotImplementedError(f"Can't handle axes like '{axis}' yet")
            else:
                assert_never(axis.size)

    half_step = [st // 2 for st in step]
    if not any(half_step):
        # nothing can be cropped: every "smaller" input would be identical to
        # the full input that was already checked above
        return

    max_steps = 4

    # check that input and output agree for decreasing input sizes
    for step_factor in range(1, max_steps + 1):
        slice_ = tuple(
            slice(None) if st == 0 else slice(step_factor * st, -step_factor * st)
            for st in half_step
        )
        this_input = [inp[slice_] for inp in input_data]
        this_shape = this_input[0].shape
        if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)):
            # the test input is too small for further cropping; every later
            # crop would be smaller still, so stop instead of failing
            break
        _check(this_input)

96 

97 

def convert_weights_to_torchscript(
    model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    output_path: Path,
    use_tracing: bool = True,
) -> v0_5.TorchscriptWeightsDescr:
    """Convert model weights from format 'pytorch_state_dict' to 'torchscript'.

    The state-dict model is loaded and exported to TorchScript, by tracing or
    by scripting. Its predictions are checked against the original model on
    the description's test inputs, and the result is saved to `output_path`.

    Args:
        model_descr: bioimageio model description whose 'pytorch_state_dict'
            weights are converted
        output_path: where to save the torchscript weights
        use_tracing: if True, export with `torch.jit.trace`, using the test
            inputs as example inputs; otherwise export with `torch.jit.script`

    Returns:
        A torchscript weights description pointing at `output_path`, with
        parent 'pytorch_state_dict' and the installed PyTorch version.

    Raises:
        ValueError: if `model_descr` has no 'pytorch_state_dict' weights, or if
            the converted model's predictions differ from the original's.
        ImportError: if PyTorch is not installed.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # the module-level import falls back to `torch = None`; fail clearly here
    # rather than with an AttributeError on `None.no_grad`
    if torch is None:
        raise ImportError(
            "PyTorch is required to convert weights to torchscript,"
            " but 'torch' could not be imported"
        )

    input_data = model_descr.get_input_test_arrays()

    with torch.no_grad():
        input_tensors = [
            torch.from_numpy(inp.astype("float32")) for inp in input_data
        ]

        # NOTE(review): assumes load_torch_model returns the model in eval mode
        # (dropout/batchnorm would make tracing and the check nondeterministic)
        # — confirm
        model = load_torch_model(state_dict_weights_descr)

        # FIXME: remove Any
        if use_tracing:
            # torch.jit.trace takes the example inputs as a tuple; each element
            # is passed to the model as a positional argument
            scripted_model: Any = torch.jit.trace(model, tuple(input_tensors))
        else:
            scripted_model: Any = torch.jit.script(model)

        _check_predictions(
            model=model,
            scripted_model=scripted_model,
            model_spec=model_descr,
            input_data=input_tensors,
        )

        # save the torchscript model; save() does not accept Path, so cast to str
        scripted_model.save(str(output_path))

    return v0_5.TorchscriptWeightsDescr(
        source=output_path,
        pytorch_version=Version(torch.__version__),
        parent="pytorch_state_dict",
    )