Coverage for src/bioimageio/core/backends/onnx_backend.py: 72%
54 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1# pyright: reportUnknownVariableType=false
2import shutil
3import tempfile
4import warnings
5from pathlib import Path
6from typing import Any, List, Optional, Sequence, Union
8import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
9from loguru import logger
10from numpy.typing import NDArray
12from bioimageio.spec.model import v0_4, v0_5
14from ..model_adapters import ModelAdapter
15from ..utils._type_guards import is_list, is_tuple
class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs bioimage.io models with ONNX weights via onnxruntime.

    The onnxruntime inference session is created eagerly in ``__init__``, using
    every provider reported by ``rt.get_available_providers()`` when that function
    exists. Device selection and unloading are not implemented for this backend;
    both only emit a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an onnxruntime inference session for ``model_description``.

        Args:
            model_description: Model description whose ``weights.onnx`` entry is loaded.
            devices: Not supported by this backend; if given, a warning is emitted
                and the value is otherwise ignored.

        Raises:
            ValueError: If the model description specifies no ONNX weights.
            FileNotFoundError: If an ONNX model file expected on disk does not exist.
        """
        super().__init__(model_description=model_description)

        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            # f-string so the error actually names the offending model
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        # `get_available_providers` is probed with `hasattr`; if it is absent,
        # `providers=None` leaves provider selection to onnxruntime.
        providers: Optional[List[str]] = None
        if hasattr(rt, "get_available_providers"):
            providers = rt.get_available_providers()

        if (
            isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            and onnx_descr.external_data is not None
        ):
            # ONNX external data is referenced relative to the model file, so the
            # model and its data file have to be loaded from the same directory.
            src = onnx_descr.source.absolute()
            src_data = onnx_descr.external_data.source.absolute()
            if (
                isinstance(src, Path)
                and isinstance(src_data, Path)
                and src.parent == src_data.parent
            ):
                # Both are local paths in the same directory: load in place.
                logger.debug(
                    "Loading ONNX model with external data from {}",
                    src.parent,
                )
                self._session = self._session_from_file(src, providers)
            else:
                # Sources are not both local paths, or live in different
                # directories: copy both into one temporary directory under their
                # original file names (presumably matching the external data
                # location recorded in the model).
                src_reader = onnx_descr.get_reader()
                src_data_reader = onnx_descr.external_data.get_reader()
                with tempfile.TemporaryDirectory() as tmpdir:
                    logger.debug(
                        "Loading ONNX model with external data from {}",
                        tmpdir,
                    )
                    tmp_src = Path(tmpdir) / src_reader.original_file_name
                    tmp_src_data = Path(tmpdir) / src_data_reader.original_file_name
                    with tmp_src.open("wb") as f:
                        shutil.copyfileobj(src_reader, f)
                    with tmp_src_data.open("wb") as f:
                        shutil.copyfileobj(src_data_reader, f)

                    # The session is created while the temporary files exist; the
                    # directory is removed right after.
                    # NOTE(review): assumes onnxruntime has fully read both files
                    # by the time the session is constructed.
                    self._session = self._session_from_file(tmp_src, providers)
        else:
            # load single source file from bytes (without external data, so probably <2GB)
            logger.debug(
                "Loading ONNX model from bytes (read from {})", onnx_descr.source
            )
            reader = onnx_descr.get_reader()
            self._session = rt.InferenceSession(
                reader.read(),
                providers=providers,  # pyright: ignore[reportUnknownArgumentType]
            )

        onnx_inputs = self._session.get_inputs()
        # Input names in session order; `_forward_impl` matches arrays to them positionally.
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    @staticmethod
    def _session_from_file(
        src: Path, providers: Optional[List[str]]
    ) -> "rt.InferenceSession":
        """Create an inference session from the ONNX model file at ``src``.

        Raises:
            FileNotFoundError: If ``src`` does not exist. This is an explicit check
                rather than an ``assert``, so it is not stripped under ``python -O``.
        """
        if not src.exists():
            raise FileNotFoundError(f"ONNX model file not found: {src}")

        # `str(...)` for compatibility with onnxruntime versions whose
        # InferenceSession accepts only `str`/`bytes` and not `os.PathLike`.
        return rt.InferenceSession(
            str(src),
            providers=providers,  # pyright: ignore[reportUnknownArgumentType]
        )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the session and return all model outputs as a list.

        ``input_arrays`` is matched positionally to the session's input names.
        Surplus arrays beyond the number of model inputs are ignored by ``zip``.
        ``None`` entries are left out of the feed so that optional model inputs
        can be omitted. If a required input is missing, onnxruntime reports it
        itself instead of failing while converting ``None``.
        """
        feed = {
            name: array
            for name, array in zip(self._input_names, input_arrays)
            if array is not None
        }
        # `None` as output names requests every model output.
        result: Any = self._session.run(None, feed)
        if is_list(result) or is_tuple(result):
            return list(result)

        return [result]

    def unload(self) -> None:
        """Warn that unloading is unsupported; the session is left in place."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )