Coverage for bioimageio/core/model_adapters/_onnx_model_adapter.py: 81%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import warnings 

2from typing import Any, List, Optional, Sequence, Union 

3 

4from numpy.typing import NDArray 

5 

6from bioimageio.spec.model import v0_4, v0_5 

7from bioimageio.spec.utils import download 

8 

9from ..digest_spec import get_axes_infos 

10from ..tensor import Tensor 

11from ._model_adapter import ModelAdapter 

12 

# onnxruntime is an optional dependency. Record why the import failed (if it
# did) so the adapter can report the reason when it is instantiated.
rt_error = None  # stays None when the import succeeds
try:
    import onnxruntime as rt
except Exception as e:
    rt = None
    rt_error = str(e)

20 

21 

class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs the ONNX weights of a model description
    through an onnxruntime ``InferenceSession``."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an inference session for the model's ONNX weights.

        Args:
            model_description: Model description; its ``weights.onnx.source``
                is downloaded and loaded into the session.
            devices: Currently ignored; a warning is emitted if it is not None.

        Raises:
            ImportError: If onnxruntime could not be imported.
            ValueError: If the model description has no ONNX weights.
        """
        if rt is None:
            raise ImportError(f"failed to import onnxruntime: {rt_error}")

        super().__init__()
        # One tuple of axis ids per model output; used as `dims` for the
        # output tensors created in `forward`.
        self._internal_output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]
        if model_description.weights.onnx is None:
            # NOTE: this message was previously missing the f-string prefix
            # and printed the placeholder literally.
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        self._session = rt.InferenceSession(
            str(download(model_description.weights.onnx.source).path)
        )
        onnx_inputs = self._session.get_inputs()  # type: ignore
        # Input names as declared by the ONNX graph; `forward` feeds its
        # positional inputs to the session in this order.
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]  # type: ignore

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the session on the given inputs and wrap the outputs as tensors.

        Args:
            *input_tensors: One entry per session input, in the order of the
                session's input names. ``None`` entries are passed on as None.

        Returns:
            One entry per session result paired with the output axes recorded
            at construction (pairing stops at the shorter of the two);
            ``None`` results stay ``None``.
        """
        assert len(input_tensors) == len(self._input_names), (
            f"expected {len(self._input_names)} input tensors,"
            f" got {len(input_tensors)}"
        )
        # `ipt.data.data` presumably unwraps the Tensor down to its raw array
        # -- TODO confirm against the Tensor implementation.
        input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors]
        result: Union[Sequence[Optional[NDArray[Any]]], Optional[NDArray[Any]]]
        result = self._session.run(  # pyright: ignore[reportUnknownVariableType]
            None, dict(zip(self._input_names, input_arrays))
        )
        # Normalize a single (non-sequence) result to a one-element list.
        if isinstance(result, (list, tuple)):
            result_seq: Sequence[Optional[NDArray[Any]]] = result
        else:
            result_seq = [result]  # type: ignore

        return [
            None if r is None else Tensor(r, dims=axes)
            for r, axes in zip(result_seq, self._internal_output_axes)
        ]

    def unload(self) -> None:
        """Emit a warning only; unloading is not implemented for ONNX."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )