Coverage for bioimageio/core/model_adapters/_onnx_model_adapter.py: 81%
37 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import warnings
2from typing import Any, List, Optional, Sequence, Union
4from numpy.typing import NDArray
6from bioimageio.spec.model import v0_4, v0_5
7from bioimageio.spec.utils import download
9from ..digest_spec import get_axes_infos
10from ..tensor import Tensor
11from ._model_adapter import ModelAdapter
# onnxruntime is an optional dependency. If importing it fails for any reason,
# keep `rt` as None and remember the error text so that ONNXModelAdapter can
# include it in the ImportError it raises on construction.
rt_error = None
try:
    import onnxruntime as rt
except Exception as e:
    rt = None
    rt_error = str(e)
class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimage.io model from its ONNX weights.

    On construction the ONNX weights file is downloaded and loaded into an
    ``onnxruntime.InferenceSession``. Device selection and unloading are not
    implemented; both only emit a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an inference session from the model's ONNX weights.

        Args:
            model_description: model description whose ``weights.onnx`` entry
                provides the source of the ONNX weights to load.
            devices: not supported yet; if given, a warning is emitted and the
                value is ignored.

        Raises:
            ImportError: if ``onnxruntime`` could not be imported.
            ValueError: if the model description has no ONNX weights.
        """
        if rt is None:
            raise ImportError(f"failed to import onnxruntime: {rt_error}")

        super().__init__()
        # One tuple of axis ids per output description; used in forward() to
        # label the raw arrays returned by onnxruntime.
        self._internal_output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]
        if model_description.weights.onnx is None:
            # Was a plain string literal before, so the model name was never
            # interpolated into the message.
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        self._session = rt.InferenceSession(
            str(download(model_description.weights.onnx.source).path)
        )
        onnx_inputs = self._session.get_inputs()  # type: ignore
        # Session input names in session order; forward() maps its positional
        # tensors onto these names.
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]  # type: ignore

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the ONNX session on the given inputs.

        Args:
            *input_tensors: one tensor (or ``None``) per session input, in the
                order reported by ``InferenceSession.get_inputs()``.

        Returns:
            One ``Tensor`` (or ``None``) per output, with dims taken from the
            matching output description.
        """
        assert len(input_tensors) == len(self._input_names), (
            f"expected {len(self._input_names)} input tensors, "
            f"got {len(input_tensors)}"
        )
        # NOTE(review): `ipt.data.data` presumably unwraps the Tensor to its
        # underlying array for onnxruntime -- confirm against Tensor.
        # NOTE(review): None entries are passed into the feed dict as-is;
        # verify how onnxruntime treats a None value for an input.
        input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors]
        result: Union[Sequence[Optional[NDArray[Any]]], Optional[NDArray[Any]]]
        result = self._session.run(  # pyright: ignore[reportUnknownVariableType]
            None, dict(zip(self._input_names, input_arrays))
        )
        # Normalise a single (non-sequence) result into a one-element sequence.
        if isinstance(result, (list, tuple)):
            result_seq: Sequence[Optional[NDArray[Any]]] = result
        else:
            result_seq = [result]  # type: ignore

        # NOTE(review): zip() truncates silently if the session returns a
        # different number of outputs than the model description declares.
        return [
            None if r is None else Tensor(r, dims=axes)
            for r, axes in zip(result_seq, self._internal_output_axes)
        ]

    def unload(self) -> None:
        """Warn that unloading is not supported; the session is kept alive."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )