Coverage for src/bioimageio/core/weight_converters/_add_weights.py: 10%

80 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

1import traceback 

2from typing import Optional, Union 

3 

4from bioimageio.spec import ( 

5 InvalidDescr, 

6 load_model_description, 

7 save_bioimageio_package_as_folder, 

8) 

9from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat 

10from loguru import logger 

11from pydantic import DirectoryPath 

12 

13from .._resource_tests import load_description_and_test 

14 

15 

def _report_conversion_error(error: Exception, *, verbose: bool) -> None:
    """Report a failed weights conversion attempt.

    The error is always logged. When `verbose` is set, the full traceback is
    additionally printed via `traceback.print_exception`.
    """
    if verbose:
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)


def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    Args:
        model_descr: The model description whose weights should be converted.
            It is saved to `output_path` and reloaded from there, so the given
            object itself is not edited.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output, i.e. print the full traceback of
            every failed conversion attempt.
        allow_tracing: if converting 'pytorch_state_dict' weights to
            'torchscript' without tracing fails, retry the conversion by tracing.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added
        weights if any conversion was possible. If no weights could be added,
        the reloaded copy is returned without running the model tests.

    Raises:
        TypeError: if `model_descr` is not a `ModelDescr` (e.g. a model of an
            older format version, or an invalid/non-model description).
    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            # A valid model description of an older format version.
            # Descriptions expose `format_version` (there is no `format` attribute).
            raise TypeError(
                f"Model format {model_descr.format_version} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    # Formats we may convert from (grows as conversions succeed).
    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    # Formats we still want to add (shrinks as conversions succeed).
    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _report_conversion_error(e, verbose=verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # Fallback: only reached if the conversion above failed (or was skipped).
    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _report_conversion_error(e, verbose=verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _report_conversion_error(e, verbose=verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # May use torchscript weights produced by one of the conversions above.
    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _report_conversion_error(e, verbose=verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        logger.warning("failed to add any converted weights")
        return model_descr
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(
            model_descr, format_version="latest", expected_type="model"
        )
        if not isinstance(tested_model_descr, ModelDescr):
            logger.error(
                f"The updated model description at {output_path} did not pass testing."
            )

        return tested_model_descr