Coverage for src/bioimageio/core/backends/onnx_backend.py: 60%

85 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1# pyright: reportUnknownVariableType=false 

2import shutil 

3import tempfile 

4from contextlib import contextmanager, nullcontext 

5from pathlib import Path 

6from typing import Any, List, Optional, Sequence, Union, cast 

7 

8import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] 

9from loguru import logger 

10from numpy.typing import NDArray 

11 

12from bioimageio.spec.model import v0_4, v0_5 

13 

14from .._model_adapter import LocalModelAdapter 

15from ..utils._type_guards import is_list, is_tuple 

16 

17 

class ONNXModelAdapter(LocalModelAdapter[Optional[str], rt.InferenceSession]):
    """Model adapter that runs bioimage.io models with ONNX Runtime.

    A "device" here is an ONNX Runtime execution provider name; ``None``
    means "let ONNX Runtime use its default providers".
    """

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter.

        Args:
            model_description: model description; must provide ONNX weights.
            devices: requested ONNX Runtime execution providers
                (see `_parse_devices` for how unavailable ones are handled).

        Raises:
            ValueError: if `model_description` has no ONNX weights.
        """
        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            raise ValueError(f"No ONNX weights specified for {model_description.name}")

        self._onnx_descr = onnx_descr
        # Input names of the first created session; later sessions must match.
        # Set before `super().__init__`, which may already create sessions.
        self._input_names: Optional[List[str]] = None
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(
        self, devices: Optional[Sequence[str]]
    ) -> Sequence[Optional[str]]:
        """Select the ONNX Runtime execution providers to use.

        Args:
            devices: requested provider names, or None to use all reported ones.

        Returns:
            The requested providers that ONNX Runtime reports as available.
            If none were requested, or none of the requested ones is available,
            all reported providers are returned instead. If ONNX Runtime reports
            no providers at all, ``[None]`` (ONNX Runtime defaults) is returned.
        """
        reported: Any = None
        if hasattr(rt, "get_available_providers"):
            reported = cast(Any, rt.get_available_providers())

        # Normalize the reported value into one flat list, wrapping a single
        # non-sequence value exactly once so requested names can match it.
        if reported is None:
            available_providers: List[Any] = []
        elif is_list(reported) or is_tuple(reported):
            available_providers = list(reported)
        else:
            available_providers = [reported]

        providers: List[Optional[str]] = (
            list(available_providers) if available_providers else [None]
        )

        if devices is None:
            return providers

        available_devices = [d for d in devices if d in providers]
        unavailable_devices = [d for d in devices if d not in providers]
        if available_devices:
            if unavailable_devices:
                logger.warning(
                    "The following requested devices are not available for ONNX Runtime and will be ignored: {}.\nSelected available providers/devices are: {}\nOther available providers are: {}",
                    unavailable_devices,
                    available_devices,
                    [p for p in providers if p not in devices],
                )

            return available_devices

        if not available_providers:
            logger.error(
                "ONNX Runtime does not report any available providers. Attempting to load model with default providers, but this will likely fail."
            )
        else:
            logger.warning(
                "None of the requested devices are available for ONNX Runtime, falling back to default, available providers: {}",
                available_providers,
            )

        return providers

    @contextmanager
    def _model_source(self):
        """Yield the ONNX model either as ``bytes`` or as a local file ``Path``.

        Weights without external data are read into memory and yielded as bytes.
        Weights with external data are yielded as a path, because the data file
        must sit next to the model file. If the two are not already local files
        in the same directory, both are copied into a temporary directory. That
        directory is removed when the context exits.
        """
        onnx_descr = self._onnx_descr
        if (
            not isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            or onnx_descr.external_data is None
        ):
            # load single source file from bytes (without external data, so probably <2GB)
            logger.debug(
                "Loading ONNX model from bytes (read from {})", onnx_descr.source
            )
            yield onnx_descr.get_reader().read()
            return

        src = onnx_descr.source.absolute()
        src_data = onnx_descr.external_data.source.absolute()
        if (
            isinstance(src, Path)
            and isinstance(src_data, Path)
            and src.parent == src_data.parent
        ):
            logger.debug(
                "Loading ONNX model with external data from {}",
                src.parent,
            )
            yield src
            return

        src_reader = onnx_descr.get_reader()
        src_data_reader = onnx_descr.external_data.get_reader()
        with tempfile.TemporaryDirectory() as tmpdir:
            logger.debug(
                "Loading ONNX model with external data from {}",
                tmpdir,
            )
            tmp_src = Path(tmpdir) / src_reader.original_file_name
            tmp_src_data = Path(tmpdir) / src_data_reader.original_file_name
            with tmp_src.open("wb") as f:
                shutil.copyfileobj(src_reader, f)
            with tmp_src_data.open("wb") as f:
                shutil.copyfileobj(src_data_reader, f)
            yield tmp_src

    def _init_model_on_device(self, device: Optional[str]) -> rt.InferenceSession:
        """Create an ONNX Runtime inference session for `device`.

        Args:
            device: execution provider name, or None for ONNX Runtime defaults.

        Returns:
            The created inference session.

        Raises:
            FileNotFoundError: if the model file to load does not exist.
            RuntimeError: if the session's input names differ from those of a
                previously created session.
        """
        with self._model_source() as source:
            if isinstance(source, bytes):
                model_arg: Union[bytes, str] = source
            else:
                if not source.exists():
                    raise FileNotFoundError(f"ONNX model file not found: {source}")
                # `str` is accepted by all ONNX Runtime versions (PathLike is not).
                model_arg = str(source)

            # Create the session inside the context: a temporary copy of the
            # model is deleted as soon as the context exits.
            session = rt.InferenceSession(
                model_arg,
                providers=None if device is None else [device],
            )

        onnx_inputs = session.get_inputs()
        onnx_input_names = [str(ipt.name) for ipt in onnx_inputs]  # pyright: ignore[reportUnknownArgumentType]
        if self._input_names is None:
            self._input_names = onnx_input_names
        elif self._input_names != onnx_input_names:
            raise RuntimeError(
                f"Input names of the ONNX model {onnx_input_names} do not match expected input names {self._input_names} from previous model initialization."
            )

        return session

    def _forward_impl(
        self,
        device: Optional[str],
        model: rt.InferenceSession,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> List[Optional[NDArray[Any]]]:
        """Run `model` on `input_arrays` and return all of its outputs.

        Arrays are matched to the ONNX input names by position. ``None``
        entries are left out of the input feed, so ONNX Runtime either uses
        the optional input's default or reports the missing input.

        Raises:
            ValueError: if more arrays are given than the model has inputs.
        """
        assert self._input_names is not None, "set during model initialization"
        if len(input_arrays) > len(self._input_names):
            # `zip` would silently drop the surplus arrays.
            raise ValueError(
                f"Got {len(input_arrays)} input arrays, but the ONNX model only has {len(self._input_names)} inputs: {self._input_names}"
            )

        input_feed = {
            name: array
            for name, array in zip(self._input_names, input_arrays)
            if array is not None
        }
        result: Any = model.run(None, input_feed)
        if is_list(result) or is_tuple(result):
            return list(result)

        return [result]

    def _cleanup_pre_model_deletion(
        self, device: Optional[str], model: rt.InferenceSession
    ) -> None:
        """No backend-specific cleanup is performed before model deletion."""
        return

    def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
        """No backend-specific cleanup is performed after model deletion."""
        return