Coverage for src/bioimageio/core/backends/__init__.py: 59%

68 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from typing import ( 

2 List, 

3 Optional, 

4 Sequence, 

5 Tuple, 

6 Union, 

7) 

8 

9from exceptiongroup import ExceptionGroup 

10from typing_extensions import assert_never 

11 

12from bioimageio.spec.model import v0_4, v0_5 

13 

14from ..common import SupportedWeightsFormat 

15 

# Default preference order of the supported weight formats, most preferred first.
# `create_model_adapter` walks this sequence (restricted to the formats a model
# actually provides) and returns the first adapter it manages to create.
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[SupportedWeightsFormat, ...] = (
    "pytorch_state_dict",
    "tensorflow_saved_model_bundle",
    "torchscript",
    "onnx",
    "keras_v3",
    "keras_hdf5",
)

26 

27 

def create_model_adapter(
    model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
):
    """Create a model adapter for `model_description`.

    The weight formats in `weight_format_priority_order` that are present in
    `model_description.weights` are tried in order. The adapter of the first
    format whose backend can be imported and instantiated is returned.

    Args:
        model_description: Model description to create an adapter for.
        devices: Optional device names, forwarded unchanged to the
            weight-format specific adapter.
        weight_format_priority_order: Weight formats to try, in order of
            preference. Defaults to `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.
            Repeated entries are only tried once.

    Returns:
        The adapter created for the first weight format that succeeded.

    Raises:
        TypeError: If `model_description` is neither a v0.4 nor a v0.5
            model description.
        ValueError: If none of the requested weight formats is present in
            the model's weights.
        Exception: If exactly one weight format was requested and its
            adapter could not be created, that error is re-raised as-is.
        ExceptionGroup: If several weight formats were requested and no
            adapter could be created; it collects every attempt's error.
    """
    if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
        raise TypeError(
            f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
        )

    weights = model_description.weights
    errors: List[Exception] = []
    weight_format_priority_order = (
        DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
        if weight_format_priority_order is None
        else weight_format_priority_order
    )
    # Limit weight formats to the ones present. `dict.fromkeys` drops repeated
    # entries (keeping the first occurrence), so no backend is attempted twice.
    weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
        w
        for w in dict.fromkeys(weight_format_priority_order)
        if getattr(weights, w, None) is not None
    ]
    if not weight_format_priority_order_present:
        # Report which supported formats the model *does* provide. The filtered
        # list above is necessarily empty here and would carry no information.
        available = [
            w
            for w in DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if getattr(weights, w, None) is not None
        ]
        raise ValueError(
            f"None of the specified weight formats ({weight_format_priority_order}) is present (available: {available})"
        )

    for wf in weight_format_priority_order_present:
        # The `assert ... is not None` lines narrow types for static checkers;
        # the filter above already guarantees each attribute is set.
        if wf == "pytorch_state_dict":
            assert weights.pytorch_state_dict is not None
            try:
                from .pytorch_backend import PytorchModelAdapter

                return PytorchModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "tensorflow_saved_model_bundle":
            assert weights.tensorflow_saved_model_bundle is not None
            try:
                from .tensorflow_backend import create_tf_model_adapter

                return create_tf_model_adapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "onnx":
            assert weights.onnx is not None
            try:
                from .onnx_backend import ONNXModelAdapter

                return ONNXModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "torchscript":
            assert weights.torchscript is not None
            try:
                from .torchscript_backend import TorchscriptModelAdapter

                return TorchscriptModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "keras_hdf5":
            assert weights.keras_hdf5 is not None
            # keras can either be installed as a separate package or used as part of tensorflow.
            # We first try to import the keras model adapter that uses the separate package and,
            # if it is not available, fall back to the one that uses tf.
            try:
                try:
                    from .keras_backend import KerasModelAdapter
                except Exception:
                    from .tensorflow_backend import KerasModelAdapter

                return KerasModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        elif wf == "keras_v3":
            assert not isinstance(weights, v0_4.WeightsDescr), (
                "keras_v3 weights not supported for v0.4 specs"
            )
            assert weights.keras_v3 is not None
            try:
                from .keras_backend import KerasModelAdapter

                return KerasModelAdapter(model_description, devices=devices)
            except Exception as e:
                errors.append(e)
        else:
            assert_never(wf)

    # Every attempted format failed and recorded its error.
    assert errors
    if len(weight_format_priority_order) == 1:
        # A single requested format: surface its error directly instead of
        # wrapping it in an exception group.
        assert len(errors) == 1
        raise errors[0]

    else:
        msg = (
            "None of the weight format specific model adapters could be created"
            + " in this environment."
        )
        raise ExceptionGroup(msg, errors)