Coverage for src/bioimageio/core/backends/onnx_backend.py: 73%

52 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

1# pyright: reportUnknownVariableType=false 

2import shutil 

3import tempfile 

4import warnings 

5from pathlib import Path 

6from typing import Any, List, Optional, Sequence, Union 

7 

8import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] 

9from bioimageio.spec.model import v0_4, v0_5 

10from loguru import logger 

11from numpy.typing import NDArray 

12 

13from ..model_adapters import ModelAdapter 

14from ..utils._type_guards import is_list, is_tuple 

15 

16 

class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimageio model via an onnxruntime session."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the onnxruntime ``InferenceSession`` for the ONNX weights.

        Args:
            model_description: Model description whose ``weights.onnx`` entry
                is loaded.
            devices: Currently ignored; a warning is emitted if it is not
                ``None``.

        Raises:
            ValueError: If ``model_description.weights.onnx`` is ``None``.
        """
        super().__init__(model_description=model_description)

        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            # f-string so the model name is actually interpolated into the message
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        # Only query the providers if this onnxruntime exposes the function;
        # otherwise `None` is passed on to `InferenceSession`.
        providers = None
        if hasattr(rt, "get_available_providers"):
            providers = rt.get_available_providers()

        if (
            isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            and onnx_descr.external_data is not None
        ):
            src = onnx_descr.source.absolute()
            src_data = onnx_descr.external_data.source.absolute()
            if (
                isinstance(src, Path)
                and isinstance(src_data, Path)
                and src.parent == src_data.parent
            ):
                # Model file and external data are local siblings:
                # load the model directly from its path.
                logger.debug(
                    "Loading ONNX model with external data from {}",
                    src.parent,
                )
                self._session = rt.InferenceSession(
                    src,
                    providers=providers,  # pyright: ignore[reportUnknownArgumentType]
                )
            else:
                # Otherwise copy both files next to each other into a
                # temporary directory and load the model from there.
                src_reader = onnx_descr.get_reader()
                src_data_reader = onnx_descr.external_data.get_reader()
                with tempfile.TemporaryDirectory() as tmpdir:
                    logger.debug(
                        "Loading ONNX model with external data from {}",
                        tmpdir,
                    )
                    src = Path(tmpdir) / src_reader.original_file_name
                    src_data = Path(tmpdir) / src_data_reader.original_file_name
                    with src.open("wb") as f:
                        shutil.copyfileobj(src_reader, f)
                    with src_data.open("wb") as f:
                        shutil.copyfileobj(src_data_reader, f)

                    # The session must be created while the temporary files
                    # still exist, i.e. before the directory is cleaned up.
                    self._session = rt.InferenceSession(
                        src,
                        providers=providers,  # pyright: ignore[reportUnknownArgumentType]
                    )
        else:
            # load single source file from bytes (without external data, so probably <2GB)
            logger.debug(
                "Loading ONNX model from bytes (read from {})", onnx_descr.source
            )
            reader = onnx_descr.get_reader()
            self._session = rt.InferenceSession(
                reader.read(),
                providers=providers,  # pyright: ignore[reportUnknownArgumentType]
            )

        # Session input names, in the order reported by the session; used to
        # pair positional input arrays with names in `_forward_impl`.
        onnx_inputs = self._session.get_inputs()
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the session on ``input_arrays`` and return its outputs as a list.

        Args:
            input_arrays: Input arrays, paired positionally with the session's
                input names collected in ``__init__``.

        Returns:
            The session outputs as a list; a result that is neither a list nor
            a tuple is wrapped in a single-element list.
        """
        # `None` as first argument requests all outputs of the session.
        result: Any = self._session.run(
            None, dict(zip(self._input_names, input_arrays))
        )
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Warn that unloading is not supported; no resources are released."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )