Coverage for src / bioimageio / core / backends / onnx_backend.py: 72%

54 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1# pyright: reportUnknownVariableType=false 

2import shutil 

3import tempfile 

4import warnings 

5from pathlib import Path 

6from typing import Any, List, Optional, Sequence, Union 

7 

8import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] 

9from loguru import logger 

10from numpy.typing import NDArray 

11 

12from bioimageio.spec.model import v0_4, v0_5 

13 

14from ..model_adapters import ModelAdapter 

15from ..utils._type_guards import is_list, is_tuple 

16 

17 

class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs ONNX weights through an onnxruntime session.

    The constructor builds an ``rt.InferenceSession`` from the ONNX weights
    entry of the model description and records the session's input names.
    Device selection is not implemented: a ``devices`` argument is ignored
    (with a warning) and ``unload`` only emits a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the inference session for the ONNX weights.

        Args:
            model_description: Model description whose ``weights.onnx`` entry
                is loaded.
            devices: Currently unsupported; if given, a warning is issued and
                the value is ignored.

        Raises:
            ValueError: If the model description has no ONNX weights entry.
        """
        super().__init__(model_description=model_description)

        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            # f-string so that the model name is actually interpolated
            raise ValueError(f"No ONNX weights specified for {model_description.name}")

        # Offer every provider this onnxruntime build reports; fall back to
        # None when the installed onnxruntime lacks `get_available_providers`.
        providers = None
        if hasattr(rt, "get_available_providers"):
            providers = rt.get_available_providers()

        if (
            isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            and onnx_descr.external_data is not None
        ):
            src = onnx_descr.source.absolute()
            src_data = onnx_descr.external_data.source.absolute()
            if (
                isinstance(src, Path)
                and isinstance(src_data, Path)
                and src.parent == src_data.parent
            ):
                # Model file and external data are local files in the same
                # folder: load the model directly from its path.
                logger.debug(
                    "Loading ONNX model with external data from {}",
                    src.parent,
                )
                assert src.exists()
                self._session = rt.InferenceSession(
                    src,
                    providers=providers,  # pyright: ignore[reportUnknownArgumentType]
                )
            else:
                # Otherwise copy both files side by side into a temporary
                # folder and load the model from there.
                # NOTE(review): presumably onnxruntime resolves the external
                # data file relative to the model file -- confirm.
                src_reader = onnx_descr.get_reader()
                src_data_reader = onnx_descr.external_data.get_reader()
                with tempfile.TemporaryDirectory() as tmpdir:
                    logger.debug(
                        "Loading ONNX model with external data from {}",
                        tmpdir,
                    )
                    src = Path(tmpdir) / src_reader.original_file_name
                    src_data = Path(tmpdir) / src_data_reader.original_file_name
                    with src.open("wb") as f:
                        shutil.copyfileobj(src_reader, f)
                    with src_data.open("wb") as f:
                        shutil.copyfileobj(src_data_reader, f)

                    # The session has to be created while the temporary
                    # directory (and thus both copied files) still exists.
                    assert src.exists()
                    self._session = rt.InferenceSession(
                        src,
                        providers=providers,  # pyright: ignore[reportUnknownArgumentType]
                    )
        else:
            # load single source file from bytes (without external data, so probably <2GB)
            logger.debug(
                "Loading ONNX model from bytes (read from {})", onnx_descr.source
            )
            reader = onnx_descr.get_reader()
            self._session = rt.InferenceSession(
                reader.read(),
                providers=providers,  # pyright: ignore[reportUnknownArgumentType]
            )

        # Input names, in session order, are used to key the arrays passed to
        # `_forward_impl`.
        onnx_inputs = self._session.get_inputs()
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the session on `input_arrays` and return its outputs as a list.

        The arrays are paired positionally with the session's input names
        (``zip`` stops at the shorter of the two sequences).

        Args:
            input_arrays: Input arrays in the order of the session's inputs.

        Returns:
            The session outputs as a list; a result that is neither a list
            nor a tuple is wrapped in a single-element list.
        """
        result: Any = self._session.run(
            None, dict(zip(self._input_names, input_arrays))
        )
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Warn that unloading is not supported; the session is left untouched."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )