Coverage for src/bioimageio/core/backends/onnx_backend.py: 60%
85 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1# pyright: reportUnknownVariableType=false
2import shutil
3import tempfile
4from contextlib import contextmanager, nullcontext
5from pathlib import Path
6from typing import Any, List, Optional, Sequence, Union, cast
8import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
9from loguru import logger
10from numpy.typing import NDArray
12from bioimageio.spec.model import v0_4, v0_5
14from .._model_adapter import LocalModelAdapter
15from ..utils._type_guards import is_list, is_tuple
class ONNXModelAdapter(LocalModelAdapter[Optional[str], rt.InferenceSession]):
    """Model adapter that runs bioimage.io models with ONNX Runtime.

    A "device" for this adapter is an ONNX Runtime execution provider name
    (compared against `rt.get_available_providers()`), or `None` to let
    ONNX Runtime pick its default providers.
    """

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an adapter for the ONNX weights of `model_description`.

        Args:
            model_description: bioimage.io model description with ONNX weights.
            devices: requested ONNX Runtime execution providers; resolved by
                `_parse_devices` (which falls back to the providers ONNX
                Runtime reports when `None` or when none is available).

        Raises:
            ValueError: if `model_description` has no ONNX weights entry.
        """
        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        # Set before `super().__init__`, which presumably calls
        # `_parse_devices` / `_init_model_on_device` (NOTE(review): verify).
        self._onnx_descr = onnx_descr
        # Filled from the first created session; maps positional input arrays
        # to named ONNX inputs in `_forward_impl`.
        self._input_names: Optional[List[str]] = None
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(
        self, devices: Optional[Sequence[str]]
    ) -> Sequence[Optional[str]]:
        """Resolve requested devices to ONNX Runtime execution providers.

        Returns the requested devices that ONNX Runtime reports as available.
        If `devices` is `None`, or none of them is available, the available
        providers are returned instead. `[None]` (ONNX Runtime defaults) is
        used when ONNX Runtime reports no providers at all.
        """
        available_providers: Any = None
        if hasattr(rt, "get_available_providers"):
            available_providers = cast(Any, rt.get_available_providers())

        if is_list(available_providers) or is_tuple(available_providers):
            # An empty provider list means: let ONNX Runtime use its defaults.
            providers = list(available_providers) or [None]
        elif available_providers is None:
            # This ONNX Runtime version cannot list its providers.
            providers = [None]
        else:
            # Unexpected scalar result: treat it as a single provider name.
            providers = [available_providers]

        if devices is not None:
            available_devices = [d for d in devices if d in providers]
            unavailable_devices = [d for d in devices if d not in providers]
            if available_devices:
                if unavailable_devices:
                    logger.warning(
                        "The following requested devices are not available for ONNX Runtime and will be ignored: {}.\nSelected available providers/devices are: {}\nOther available providers are: {}",
                        unavailable_devices,
                        available_devices,
                        [p for p in providers if p not in devices],
                    )

                providers = available_devices
            elif not available_providers:
                logger.error(
                    "ONNX Runtime does not report any available providers. Attempting to load model with default providers, but this will likely fail."
                )
            else:
                logger.warning(
                    "None of the requested devices are available for ONNX Runtime, falling back to default, available providers: {}",
                    available_providers,
                )
        return providers

    def _init_model_on_device(self, device: Optional[str]) -> rt.InferenceSession:
        """Create an ONNX Runtime session using execution provider `device`.

        Weights with external data (v0.5 descriptions only) are loaded from a
        file path, with the model file and its external data file in the same
        directory. If they are not already local files in one directory, both
        are copied into a temporary directory first. All other weights are
        loaded from the model's bytes.

        Raises:
            FileNotFoundError: if the resolved model file does not exist.
            RuntimeError: if the session's input names differ from those of a
                previously initialized session.
        """
        onnx_descr = self._onnx_descr
        if (
            isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            and onnx_descr.external_data is not None
        ):
            src = onnx_descr.source.absolute()
            src_data = onnx_descr.external_data.source.absolute()
            if (
                isinstance(src, Path)
                and isinstance(src_data, Path)
                and src.parent == src_data.parent
            ):
                logger.debug(
                    "Loading ONNX model with external data from {}",
                    src.parent,
                )
                source_context = nullcontext(src)
            else:
                src_reader = onnx_descr.get_reader()
                src_data_reader = onnx_descr.external_data.get_reader()

                @contextmanager
                def source_context_func():
                    # Place model and external data side by side in a temporary
                    # directory that lives as long as this context is open.
                    with tempfile.TemporaryDirectory() as tmpdir:
                        logger.debug(
                            "Loading ONNX model with external data from {}",
                            tmpdir,
                        )
                        local_src = Path(tmpdir) / src_reader.original_file_name
                        local_src_data = (
                            Path(tmpdir) / src_data_reader.original_file_name
                        )
                        with local_src.open("wb") as f:
                            shutil.copyfileobj(src_reader, f)
                        with local_src_data.open("wb") as f:
                            shutil.copyfileobj(src_data_reader, f)
                        yield local_src

                source_context = source_context_func()

        else:
            # load single source file from bytes (without external data, so probably <2GB)
            logger.debug(
                "Loading ONNX model from bytes (read from {})", onnx_descr.source
            )
            source_context = nullcontext(onnx_descr.get_reader().read())

        with source_context as s:
            if not isinstance(s, bytes) and not s.exists():
                raise FileNotFoundError(f"ONNX model file not found: {s}")
            # Older onnxruntime releases accept only `str` or `bytes` here,
            # not `os.PathLike`, so convert paths explicitly.
            session = rt.InferenceSession(
                s if isinstance(s, bytes) else str(s),
                providers=None if device is None else [device],
            )

        onnx_inputs = session.get_inputs()
        onnx_input_names = [str(ipt.name) for ipt in onnx_inputs]  # pyright: ignore[reportUnknownArgumentType]
        if self._input_names is None:
            self._input_names = onnx_input_names
        elif self._input_names != onnx_input_names:
            raise RuntimeError(
                f"Input names of the ONNX model {onnx_input_names} do not match expected input names {self._input_names} from previous model initialization."
            )

        return session

    def _forward_impl(
        self,
        device: Optional[str],
        model: rt.InferenceSession,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> List[Optional[NDArray[Any]]]:
        """Run `model` on `input_arrays` and return all outputs as a list.

        Inputs are matched to ONNX input names by position, in the order
        recorded during model initialization.
        """
        assert self._input_names is not None, "set during model initialization"
        # `None` as output names requests all model outputs.
        result: Any = model.run(None, dict(zip(self._input_names, input_arrays)))
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def _cleanup_pre_model_deletion(
        self, device: Optional[str], model: rt.InferenceSession
    ) -> None:
        """No explicit cleanup is performed for ONNX Runtime sessions."""
        return

    def _cleanup_post_model_deletion(self, device: Optional[str]) -> None:
        """No explicit cleanup is performed after session deletion."""
        return