Coverage for bioimageio/core/weight_converters/_add_weights.py: 34%

82 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import traceback 

2from typing import Optional 

3 

4from loguru import logger 

5from pydantic import DirectoryPath 

6 

7from bioimageio.spec import ( 

8 InvalidDescr, 

9 load_model_description, 

10 save_bioimageio_package_as_folder, 

11) 

12from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat 

13 

14from .._resource_tests import load_description_and_test 

15 

16 

def _report_conversion_error(error: Exception, verbose: bool) -> None:
    """Log a failed weights conversion, printing the full traceback if `verbose`."""
    if verbose:
        # Use the explicit 3-argument form: the single-argument
        # `traceback.print_exception(exc)` signature only exists in Python >= 3.10
        # and would raise a TypeError here (masking `error`) on older versions.
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)


def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
) -> Optional[ModelDescr]:
    """Convert model weights to other formats and add them to the model description

    The given `model_descr` itself is not edited: it is first saved as a package
    folder to `output_path` and reloaded from there. Converted weights are written
    into that folder, and the updated description is saved there again.

    Args:
        model_descr: Model description (of the currently implemented format
            version) whose weights should be converted.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output

    Returns:
        - An updated model description (reloaded and tested via
          `load_description_and_test`) if any converted weights were added.
        - `None` if no conversion was possible or no weights format was missing.

    Raises:
        TypeError: if `model_descr` is not a `ModelDescr` of the currently
            implemented format version.
    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            # a valid model description, but of an older format version
            raise TypeError(
                f"Model format {model_descr.format_version} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    _model_descr = load_model_description(output_path, perform_io_checks=False)
    assert isinstance(_model_descr, ModelDescr)
    model_descr = _model_descr
    del _model_descr

    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # remember what was requested to report which formats were actually added
    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # fall back to tracing if scripting did not succeed
    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # torchscript weights (original or just converted) may be exported to onnx
    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # fall back to converting pytorch_state_dict weights to onnx directly
    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        if originally_missing:
            logger.warning("failed to add any converted weights")
        else:
            # nothing was requested, so nothing failed
            logger.info("no weights formats missing; nothing to convert")
        return None
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(model_descr)
        assert isinstance(tested_model_descr, ModelDescr)
        return tested_model_descr