Coverage for src/bioimageio/core/weight_converters/_add_weights.py: 10%

80 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-22 09:21 +0000

1import traceback 

2from typing import Optional, Union 

3 

4from loguru import logger 

5from pydantic import DirectoryPath 

6 

7from bioimageio.spec import ( 

8 InvalidDescr, 

9 load_model_description, 

10 save_bioimageio_package_as_folder, 

11) 

12from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat 

13 

14from .._resource_tests import load_description_and_test 

15 

16 

def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    The given `model_descr` is not edited: it is saved to `output_path` and
    reloaded from there, and all conversions operate on that reloaded copy.

    Args:
        model_descr: Model description whose weights should be converted.
            Must be a `ModelDescr` of the implemented (latest) format version.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output
        allow_tracing: if converting 'pytorch_state_dict' weights to
            'torchscript' without tracing fails, retry the conversion by tracing.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added
        weights if any conversion was possible.

    Raises:
        TypeError: if `model_descr` is not a `ModelDescr` of the implemented
            format version (e.g. an older model format or an `InvalidDescr`).
    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            # Description objects expose `format_version`, not `format`;
            # reading `.format` here would mask this TypeError with an AttributeError.
            raise TypeError(
                f"Model format {model_descr.format_version} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # snapshot used at the end to detect whether any conversion succeeded
    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # fallback: only reached if the non-tracing conversion above failed
    # (otherwise 'torchscript' was already discarded from `missing`)
    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # may use torchscript weights freshly produced by one of the steps above
    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # fallback route to onnx if the torchscript route was unavailable or failed
    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        logger.warning("failed to add any converted weights")
        return model_descr
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(
            model_descr, format_version="latest", expected_type="model"
        )
        if not isinstance(tested_model_descr, ModelDescr):
            logger.error(
                f"The updated model description at {output_path} did not pass testing."
            )

        return tested_model_descr


def _log_conversion_error(error: Exception, verbose: bool) -> None:
    """Report a failed weights conversion without aborting the remaining ones.

    Prints the full traceback when `verbose` is set, then logs the error.
    """
    if verbose:
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)