Coverage for src/bioimageio/core/backends/onnx_backend.py: 73%
52 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1# pyright: reportUnknownVariableType=false
2import shutil
3import tempfile
4import warnings
5from pathlib import Path
6from typing import Any, List, Optional, Sequence, Union
8import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
9from bioimageio.spec.model import v0_4, v0_5
10from loguru import logger
11from numpy.typing import NDArray
13from ..model_adapters import ModelAdapter
14from ..utils._type_guards import is_list, is_tuple
class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimageio model's ONNX weights with onnxruntime.

    The inference session is created eagerly in ``__init__``. Device management
    is not implemented: all providers reported by
    ``onnxruntime.get_available_providers`` (when that function exists) are
    handed to the session, and a ``devices`` argument is ignored with a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an onnxruntime inference session for ``model_description``.

        Args:
            model_description: Model description whose ``weights.onnx`` entry
                is loaded.
            devices: Not supported; if given, a warning is emitted and the
                value is ignored.

        Raises:
            ValueError: If the model description has no ONNX weights.
        """
        super().__init__(model_description=model_description)

        onnx_descr = model_description.weights.onnx
        if onnx_descr is None:
            # f-string so the model name is actually interpolated
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        providers = None
        if hasattr(rt, "get_available_providers"):
            providers = rt.get_available_providers()

        if (
            isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
            and onnx_descr.external_data is not None
        ):
            # Weights with external data are loaded from a file path with the
            # data file in the same directory as the model file (presumably
            # because onnxruntime resolves external data relative to the model).
            src = onnx_descr.source.absolute()
            src_data = onnx_descr.external_data.source.absolute()
            if (
                isinstance(src, Path)
                and isinstance(src_data, Path)
                and src.parent == src_data.parent
            ):
                # Both files are already local and side by side: load in place.
                logger.debug(
                    "Loading ONNX model with external data from {}",
                    src.parent,
                )
                self._session = rt.InferenceSession(
                    # str(): older onnxruntime versions reject PathLike objects
                    str(src),
                    providers=providers,  # pyright: ignore[reportUnknownArgumentType]
                )
            else:
                # Remote or split locations: copy both files into a temporary
                # directory under their original file names and load from there.
                src_reader = onnx_descr.get_reader()
                src_data_reader = onnx_descr.external_data.get_reader()
                with tempfile.TemporaryDirectory() as tmpdir:
                    logger.debug(
                        "Loading ONNX model with external data from {}",
                        tmpdir,
                    )
                    src = Path(tmpdir) / src_reader.original_file_name
                    src_data = Path(tmpdir) / src_data_reader.original_file_name
                    with src.open("wb") as f:
                        shutil.copyfileobj(src_reader, f)
                    with src_data.open("wb") as f:
                        shutil.copyfileobj(src_data_reader, f)

                    # Session is created while the temporary files still exist.
                    self._session = rt.InferenceSession(
                        # str(): older onnxruntime versions reject PathLike objects
                        str(src),
                        providers=providers,  # pyright: ignore[reportUnknownArgumentType]
                    )
        else:
            # load single source file from bytes (without external data, so probably <2GB)
            logger.debug(
                "Loading ONNX model from bytes (read from {})", onnx_descr.source
            )
            reader = onnx_descr.get_reader()
            self._session = rt.InferenceSession(
                reader.read(),
                providers=providers,  # pyright: ignore[reportUnknownArgumentType]
            )

        # Input names in model order; _forward_impl pairs them positionally
        # with the given input arrays.
        onnx_inputs = self._session.get_inputs()
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}",
                stacklevel=2,
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the inference session on ``input_arrays``.

        ``input_arrays`` is matched positionally to the model's input names
        (extra entries on either side are dropped by ``zip``). ``None`` entries
        are left out of the input feed, so optional model inputs can be skipped.
        onnxruntime then reports any *required* input that is missing, instead
        of failing to convert ``None``.

        Returns:
            The session outputs as a list (a single non-sequence result is
            wrapped in a one-element list).
        """
        input_feed = {
            name: array
            for name, array in zip(self._input_names, input_arrays)
            if array is not None
        }
        result: Any = self._session.run(None, input_feed)
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Warn that unloading is unsupported; the session is kept as is."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model",
            stacklevel=2,
        )