Coverage for bioimageio/core/backends/onnx_backend.py: 89%

27 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1# pyright: reportUnknownVariableType=false 

2import warnings 

3from typing import Any, List, Optional, Sequence, Union 

4 

5import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] 

6from numpy.typing import NDArray 

7 

8from bioimageio.spec._internal.type_guards import is_list, is_tuple 

9from bioimageio.spec.model import v0_4, v0_5 

10from bioimageio.spec.utils import download 

11 

12from ..model_adapters import ModelAdapter 

13 

14 

class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a model's ONNX weights through an onnxruntime session."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Download the ONNX weights and create an inference session for them.

        Args:
            model_description: Model description whose `weights.onnx` entry is loaded.
            devices: Currently ignored; if given, a warning is emitted.

        Raises:
            ValueError: If `model_description` specifies no ONNX weights.
        """
        super().__init__(model_description=model_description)

        if model_description.weights.onnx is None:
            # f-string so the model name is actually interpolated into the message
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        local_path = download(model_description.weights.onnx.source).path
        # the session is built from the raw bytes of the downloaded weights file
        self._session = rt.InferenceSession(local_path.read_bytes())
        onnx_inputs = self._session.get_inputs()
        # `_forward_impl` pairs its input arrays with these names by position
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the session on `input_arrays` and return its outputs as a list.

        Args:
            input_arrays: Arrays matched positionally to the session's input names.
                `zip` stops at the shorter sequence, so surplus arrays are dropped.
                NOTE(review): `None` entries are forwarded as-is; confirm that
                onnxruntime accepts them for optional inputs.

        Returns:
            The session outputs. A non-list/tuple result is wrapped in a
            one-element list.
        """
        # passing None as output names requests all model outputs
        result: Any = self._session.run(
            None, dict(zip(self._input_names, input_arrays))
        )
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Emit a warning; unloading is not implemented for this backend."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )