Coverage for bioimageio/core/backends/onnx_backend.py: 88%
26 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1# pyright: reportUnknownVariableType=false
2import warnings
3from typing import Any, List, Optional, Sequence, Union
5import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
6from numpy.typing import NDArray
8from bioimageio.spec.model import v0_4, v0_5
10from ..model_adapters import ModelAdapter
11from ..utils._type_guards import is_list, is_tuple
class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimage.io model's ONNX weights with onnxruntime.

    The inference session is created without an explicit ``providers``
    argument. Device management is not implemented: requested ``devices``
    are ignored (with a warning), and the model cannot be unloaded.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an onnxruntime inference session from the model's ONNX weights.

        Args:
            model_description: bioimage.io model description; it must provide
                ONNX weights.
            devices: Ignored. A warning is emitted when given, because device
                management is not implemented for onnx.

        Raises:
            ValueError: If ``model_description`` has no ONNX weights.
        """
        super().__init__(model_description=model_description)

        if model_description.weights.onnx is None:
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        reader = model_description.weights.onnx.get_reader()
        self._session = rt.InferenceSession(reader.read())
        # Graph input names in the order reported by the session; the
        # positional arrays given to `_forward_impl` are matched to these.
        self._input_names: List[str] = [
            ipt.name for ipt in self._session.get_inputs()
        ]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}",
                stacklevel=2,
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the ONNX session on positional inputs and return all outputs.

        ``input_arrays`` are matched to the graph inputs by position.
        ``None`` entries are left out of the input feed rather than passed to
        onnxruntime, which then either treats the input as optional or reports
        it as a missing required input.

        Returns:
            The session outputs as a list. A non-list, non-tuple result is
            wrapped in a single-element list.
        """
        feed = {
            name: array
            for name, array in zip(self._input_names, input_arrays)
            if array is not None
        }
        # `None` as output names requests every graph output.
        result: Any = self._session.run(None, feed)
        if is_list(result) or is_tuple(result):
            return list(result)

        return [result]

    def unload(self) -> None:
        """Warn that unloading is unsupported; the session is left in place."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model",
            stacklevel=2,
        )