Coverage for src / bioimageio / core / backends / onnx_backend.py: 68%
73 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1# pyright: reportUnknownVariableType=false
2import shutil
3import tempfile
4import warnings
5from contextlib import contextmanager, nullcontext
6from pathlib import Path
7from typing import Any, List, Optional, Sequence, Union, cast
9import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
10from exceptiongroup import ExceptionGroup
11from loguru import logger
12from numpy.typing import NDArray
14from bioimageio.spec.model import v0_4, v0_5
16from ..model_adapters import ModelAdapter
17from ..utils._type_guards import is_list, is_tuple
class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a model's ONNX weights with `onnxruntime`.

    On construction the ONNX weights are loaded into an
    `onnxruntime.InferenceSession`. The candidate execution providers are
    tried one at a time, in order, until one of them can load the model.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the ONNX weights of `model_description` into an inference session.

        Args:
            model_description: Model description that must provide ONNX weights.
            devices: Not supported by this backend. If given, a warning is
                emitted and the value is otherwise ignored.

        Raises:
            ValueError: If `model_description` has no ONNX weights.
            ExceptionGroup: If no candidate execution provider could load the
                model. It contains one exception per provider tried.
        """
        super().__init__(model_description=model_description)

        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        providers = self._get_provider_candidates()

        with self._model_source_context(onnx_descr) as source:
            # sanity check: the session is built either from in-memory bytes
            # or from a model file that exists on disk
            assert isinstance(source, bytes) or source.exists()
            self._session = self._create_session(source, providers)

        onnx_inputs = self._session.get_inputs()
        # input names in session order; `_forward_impl` matches arrays to
        # these names by position
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    @staticmethod
    def _get_provider_candidates() -> List[Any]:
        """Return the execution providers to try, in priority order.

        A `None` entry is later passed to `InferenceSession` as
        `providers=None`. `None` is the only candidate when this onnxruntime
        lacks `get_available_providers` or reports an empty provider list.
        """
        available_providers: Any = None
        if hasattr(rt, "get_available_providers"):
            available_providers = cast(Any, rt.get_available_providers())

        if is_list(available_providers):
            if len(available_providers) == 0:
                return [None]
            return available_providers

        return [available_providers]

    @staticmethod
    def _model_source_context(
        onnx_descr: Union[v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr],
    ):
        """Return a context manager that yields the model source for `InferenceSession`.

        - External data, with model and data being local files in the same
          directory: yields the model file path as is.
        - External data, any other case: copies both files into a temporary
          directory, yields the path of the copied model file, and removes
          the directory on exit.
        - No external data: yields the model file content as `bytes`.
        """
        if (
            isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            and onnx_descr.external_data is not None
        ):
            src = onnx_descr.source.absolute()
            src_data = onnx_descr.external_data.source.absolute()
            if (
                isinstance(src, Path)
                and isinstance(src_data, Path)
                and src.parent == src_data.parent
            ):
                logger.debug(
                    "Loading ONNX model with external data from {}",
                    src.parent,
                )
                return nullcontext(src)

            src_reader = onnx_descr.get_reader()
            src_data_reader = onnx_descr.external_data.get_reader()

            @contextmanager
            def copy_to_tmpdir():
                # Place the model and its external data side by side under
                # their original file names. This mirrors the same-directory
                # requirement checked above.
                with tempfile.TemporaryDirectory() as tmpdir:
                    logger.debug(
                        "Loading ONNX model with external data from {}",
                        tmpdir,
                    )
                    tmp_src = Path(tmpdir) / src_reader.original_file_name
                    tmp_src_data = Path(tmpdir) / src_data_reader.original_file_name
                    with tmp_src.open("wb") as f:
                        shutil.copyfileobj(src_reader, f)
                    with tmp_src_data.open("wb") as f:
                        shutil.copyfileobj(src_data_reader, f)
                    yield tmp_src

            return copy_to_tmpdir()

        # load single source file from bytes (without external data, so probably <2GB)
        logger.debug(
            "Loading ONNX model from bytes (read from {})", onnx_descr.source
        )
        return nullcontext(onnx_descr.get_reader().read())

    @staticmethod
    def _create_session(
        source: Union[bytes, Path], providers: Sequence[Any]
    ) -> rt.InferenceSession:
        """Create an inference session, trying one provider at a time.

        Returns the session from the first provider that succeeds. Failures of
        providers tried before it are logged as warnings.

        Raises:
            ExceptionGroup: If every provider fails. It contains the
                exception of each attempt, in order.
        """
        # try providers in order until one works
        # TODO: check if issue with backup providers is fixed and evaluate handing over all available providers
        # currently (onnxruntime 1.23.2) if a higher priority providers fails a RUNTIME_EXCEPTION may be raised
        # stating 'model_path must not be empty' instead of trying the next provider, see # TODO: reference issue
        provider_exceptions: List[Exception] = []
        for p in providers:
            try:
                session = rt.InferenceSession(
                    source,
                    providers=None if p is None else [p],
                )
            except Exception as e:
                provider_exceptions.append(e)
            else:
                # zip stops after the failed attempts, which are exactly the
                # providers that came before the successful one
                for bad_p, exc in zip(providers, provider_exceptions):
                    logger.warning(
                        "Failed to load ONNX model with provider {}: {}",
                        bad_p,
                        exc,
                    )
                return session

        raise ExceptionGroup(
            "Failed to load ONNX model with any of the available providers.",
            provider_exceptions,
        )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the session and return all of its outputs as a list.

        `input_arrays` are matched to the session's input names by position.
        A non-sequence result from `run` is wrapped in a one-element list.
        """
        result: Any = self._session.run(
            None, dict(zip(self._input_names, input_arrays))
        )
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Emit a warning only; unloading is not implemented and the session is not released."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )