Coverage for bioimageio/core/model_adapters/_model_adapter.py: 58%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import warnings 

2from abc import ABC, abstractmethod 

3from typing import List, Optional, Sequence, Tuple, Union, final 

4 

5from bioimageio.spec.model import v0_4, v0_5 

6 

7from ..tensor import Tensor 

8 

# Name of a weights format as declared by either supported model spec version
WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat]

# Weight formats known to this module, most preferred first.
# Formats are tried in this order; the first one that works is used.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_hdf5",
)

20 

21 

class ModelAdapter(ABC):
    """
    Represents model *without* any preprocessing or postprocessing.

    ```
    from bioimageio.core import load_description

    model = load_description(...)

    # option 1:
    adapter = ModelAdapter.create(model)
    adapter.forward(...)
    adapter.unload()

    # option 2:
    with ModelAdapter.create(model) as adapter:
        adapter.forward(...)
    ```
    """

    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
    ) -> "ModelAdapter":
        """
        Creates model adapter based on the passed spec
        Note: All specific adapters should happen inside this function to prevent different framework
        initializations interfering with each other

        Args:
            model_description: The model description to create an adapter for.
            devices: Passed on unchanged to the weight format specific adapter.
            weight_format_priority_order: Weight formats to try, most preferred
                first. Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.
                Formats without an entry in the model's weights are skipped.

        Returns:
            The first weight format specific adapter that could be created.

        Raises:
            TypeError: If `model_description` is not a v0_4/v0_5 `ModelDescr`.
            ValueError: If no adapter could be created, either because every
                attempt failed or because none of the requested weight formats
                is present in the model and handled by an adapter.
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        errors: List[Tuple[WeightsFormat, Exception]] = []
        requested_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order = [
            w for w in requested_order if getattr(weights, w) is not None
        ]

        # Framework specific adapters are imported lazily inside each branch so
        # that only the framework that is actually tried gets initialized.
        for wf in weight_format_priority_order:
            if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
                try:
                    from ._pytorch_model_adapter import PytorchModelAdapter

                    return PytorchModelAdapter(
                        outputs=model_description.outputs,
                        weights=weights.pytorch_state_dict,
                        devices=devices,
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif (
                wf == "tensorflow_saved_model_bundle"
                and weights.tensorflow_saved_model_bundle is not None
            ):
                try:
                    from ._tensorflow_model_adapter import TensorflowModelAdapter

                    return TensorflowModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif wf == "onnx" and weights.onnx is not None:
                try:
                    from ._onnx_model_adapter import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif wf == "torchscript" and weights.torchscript is not None:
                try:
                    from ._torchscript_model_adapter import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))
            elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    from ._keras_model_adapter import (
                        KerasModelAdapter,
                        keras,  # type: ignore
                    )

                    if keras is None:
                        from ._tensorflow_model_adapter import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append((wf, e))

        if not errors:
            # Nothing was attempted: either no requested format is present in the
            # model's weights, or the present ones have no adapter branch above.
            # Raise explicitly instead of relying on an `assert` (stripped by -O).
            raise ValueError(
                "None of the requested weight formats"
                + f" ({', '.join(map(str, requested_order))}) is both present in the"
                + " model description and supported by a model adapter."
            )

        if len(weight_format_priority_order) == 1:
            # a single candidate and a non-empty error list means exactly one error
            raise ValueError(
                f"The '{weight_format_priority_order[0]}' model adapter could not be created"
                + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
            ) from errors[0][1]

        error_list = "\n - ".join(
            f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
        )
        raise ValueError(
            "None of the weight format specific model adapters could be created"
            + f" in this environment. Errors are:\n\n{error_list}.\n\n"
        )

    def __enter__(self) -> "ModelAdapter":
        """Enter a `with` block; returns the adapter itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Leave a `with` block by calling `unload`; exceptions are not suppressed."""
        self.unload()

    @final
    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
        """Deprecated no-op that only emits a warning; `devices` is ignored."""
        # stacklevel=2 attributes the warning to the caller of `load`
        warnings.warn(
            "Deprecated. ModelAdapter is loaded on initialization", stacklevel=2
        )

    @abstractmethod
    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """
        Run forward pass of model to get model predictions
        """
        # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl

    @abstractmethod
    def unload(self):
        """
        Unload model from any devices, freeing their memory.
        The model adapter should be considered unusable afterwards.
        """

168 

169 

def get_weight_formats() -> List[str]:
    """
    Return the supported weight formats as a new list, in default priority order
    """
    supported_formats = [*DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER]
    return supported_formats

175 

176 

# Module-level alias for `ModelAdapter.create`
create_model_adapter = ModelAdapter.create