Coverage for bioimageio/core/backends/onnx_backend.py: 89%
27 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1# pyright: reportUnknownVariableType=false
2import warnings
3from typing import Any, List, Optional, Sequence, Union
5import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
6from numpy.typing import NDArray
8from bioimageio.spec._internal.type_guards import is_list, is_tuple
9from bioimageio.spec.model import v0_4, v0_5
10from bioimageio.spec.utils import download
12from ..model_adapters import ModelAdapter
class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a model's ONNX weights with onnxruntime.

    The ONNX weights referenced by the model description are downloaded
    (or resolved locally) and loaded into an ``onnxruntime.InferenceSession``.
    Device selection is not supported: a ``devices`` argument only triggers a
    warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an inference session from the model's ONNX weights.

        Args:
            model_description: Model description whose ``weights.onnx`` entry
                points to the ONNX file to load.
            devices: Ignored. A warning is emitted when it is not ``None``.

        Raises:
            ValueError: If the model description has no ONNX weights.
        """
        super().__init__(model_description=model_description)

        if model_description.weights.onnx is None:
            # f-string so the model name is actually interpolated into the message.
            raise ValueError(f"No ONNX weights specified for {model_description.name}")

        local_path = download(model_description.weights.onnx.source).path
        # The session is built from the raw model bytes rather than a file path.
        self._session = rt.InferenceSession(local_path.read_bytes())
        onnx_inputs = self._session.get_inputs()
        # Graph input names, in session order; inputs are mapped to them positionally.
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the ONNX session on ``input_arrays`` and return all outputs.

        Each input array is paired with the session input name at the same
        position. Passing ``None`` as the first argument of ``run`` requests
        every model output.

        Returns:
            The session outputs as a list. A non-sequence result is wrapped in
            a single-element list.
        """
        # NOTE(review): `zip` silently drops extra arrays or names when the
        # lengths differ, and `None` entries are forwarded to onnxruntime
        # unchanged. Confirm this is the intended handling of optional inputs.
        result: Any = self._session.run(
            None, dict(zip(self._input_names, input_arrays))
        )
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Emit a warning; the onnxruntime session is not released."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )