Coverage for src / bioimageio / core / backends / onnx_backend.py: 68%

73 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1# pyright: reportUnknownVariableType=false 

2import shutil 

3import tempfile 

4import warnings 

5from contextlib import contextmanager, nullcontext 

6from pathlib import Path 

7from typing import Any, List, Optional, Sequence, Union, cast 

8 

9import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] 

10from exceptiongroup import ExceptionGroup 

11from loguru import logger 

12from numpy.typing import NDArray 

13 

14from bioimageio.spec.model import v0_4, v0_5 

15 

16from ..model_adapters import ModelAdapter 

17from ..utils._type_guards import is_list, is_tuple 

18 

19 

def _get_onnx_providers() -> List[Any]:
    """Return the onnxruntime execution providers to try, in priority order.

    Falls back to ``[None]`` (i.e. let onnxruntime pick its default providers)
    when ``rt.get_available_providers`` does not exist or returns an empty list.
    """
    available_providers: Any = None
    if hasattr(rt, "get_available_providers"):
        available_providers = cast(Any, rt.get_available_providers())

    if not is_list(available_providers):
        return [available_providers]

    # no providers reported: still try once with onnxruntime's default selection
    return list(available_providers) if available_providers else [None]


@contextmanager
def _onnx_source_in_tmp_dir(src_reader: Any, src_data_reader: Any) -> Any:
    """Copy an ONNX model and its external data file into a temporary directory.

    Both files are written side by side under their original file names, so the
    model's reference to its external data file resolves next to the model.
    Yields the path of the copied model file. The temporary directory is
    removed when the context exits.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        logger.debug(
            "Loading ONNX model with external data from {}",
            tmpdir,
        )
        tmp_path = Path(tmpdir)
        src = tmp_path / src_reader.original_file_name
        src_data = tmp_path / src_data_reader.original_file_name
        with src.open("wb") as f:
            shutil.copyfileobj(src_reader, f)
        with src_data.open("wb") as f:
            shutil.copyfileobj(src_data_reader, f)
        yield src


def _get_onnx_source_context(
    onnx_descr: Union[v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr],
) -> Any:
    """Return a context manager yielding the model source for ``rt.InferenceSession``.

    * v0_5 weights with external data, where model and data are local files in
      the same directory: yields the (absolute) model file path.
    * v0_5 weights with external data otherwise: copies both files into a
      temporary directory and yields the copied model file path.
    * all other weights: yields the model file content as ``bytes``.
    """
    if (
        isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
        and onnx_descr.external_data is not None
    ):
        src = onnx_descr.source.absolute()
        src_data = onnx_descr.external_data.source.absolute()
        if (
            isinstance(src, Path)
            and isinstance(src_data, Path)
            and src.parent == src_data.parent
        ):
            logger.debug(
                "Loading ONNX model with external data from {}",
                src.parent,
            )
            return nullcontext(src)

        # readers are created eagerly (before the context is entered)
        return _onnx_source_in_tmp_dir(
            onnx_descr.get_reader(), onnx_descr.external_data.get_reader()
        )

    # load single source file from bytes (without external data, so probably <2GB)
    logger.debug("Loading ONNX model from bytes (read from {})", onnx_descr.source)
    return nullcontext(onnx_descr.get_reader().read())


def _create_inference_session(
    source: Union[bytes, Path], providers: Sequence[Any]
) -> Any:
    """Create an ``rt.InferenceSession`` for *source*, trying *providers* one at a time.

    Each entry of *providers* is tried on its own (``None`` means onnxruntime's
    default selection). Providers that failed before a later one succeeded are
    logged as warnings.

    Raises:
        ExceptionGroup: if no provider could create the session; it contains
            one exception per attempted provider.
    """
    # older onnxruntime versions only accept str or bytes, not os.PathLike
    model_source: Union[bytes, str] = (
        source if isinstance(source, bytes) else str(source)
    )
    if not providers:
        # an empty sequence would otherwise end in an empty (invalid) ExceptionGroup
        providers = [None]

    # try providers in order until one works
    # TODO: check if issue with backup providers is fixed and evaluate handing over all available providers
    # currently (onnxruntime 1.23.2) if a higher priority providers fails a RUNTIME_EXCEPTION may be raised
    # stating 'model_path must not be empty' instead of trying the next provider, see # TODO: reference issue
    provider_exceptions: List[Exception] = []
    for p in providers:
        try:
            session = rt.InferenceSession(
                model_source,
                providers=None if p is None else [p],
            )
        except Exception as e:
            provider_exceptions.append(e)
            continue

        # zip stops after the providers that actually failed
        for bad_p, exc in zip(providers, provider_exceptions):
            logger.warning(
                "Failed to load ONNX model with provider {}: {}",
                bad_p,
                exc,
            )

        return session

    raise ExceptionGroup(
        "Failed to load ONNX model with any of the available providers.",
        provider_exceptions,
    )


class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a model's ONNX weights with onnxruntime."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the ONNX weights of *model_description* into an inference session.

        Args:
            model_description: model whose ``weights.onnx`` entry is loaded.
            devices: not supported for ONNX yet; a warning is emitted if given.

        Raises:
            ValueError: if the model description has no ONNX weights.
            ExceptionGroup: if no execution provider could load the model.
        """
        super().__init__(model_description=model_description)

        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        providers = _get_onnx_providers()
        with _get_onnx_source_context(onnx_descr) as s:
            # internal sanity check: either raw bytes or an existing model file
            assert isinstance(s, bytes) or s.exists()
            self._session = _create_inference_session(s, providers)

        onnx_inputs = self._session.get_inputs()
        # model input names, in the session's input order
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the inference session and return its outputs as a list.

        Arrays are matched to the model inputs by position. ``None`` entries
        are left out of the input feed so optional model inputs can be omitted.
        If a required input is missing, onnxruntime's input validation is
        expected to reject the feed.
        """
        input_feed = {
            name: arr
            for name, arr in zip(self._input_names, input_arrays)
            if arr is not None
        }
        result: Any = self._session.run(None, input_feed)
        if is_list(result) or is_tuple(result):
            return list(result)

        return [result]

    def unload(self) -> None:
        """Not supported for ONNX: only emits a warning; the session is kept."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )