Coverage for bioimageio/core/backends/onnx_backend.py: 88%

26 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 15:20 +0000

1# pyright: reportUnknownVariableType=false 

2import warnings 

3from typing import Any, List, Optional, Sequence, Union 

4 

5import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs] 

6from numpy.typing import NDArray 

7 

8from bioimageio.spec.model import v0_4, v0_5 

9 

10from ..model_adapters import ModelAdapter 

11from ..utils._type_guards import is_list, is_tuple 

12 

13 

class ONNXModelAdapter(ModelAdapter):
    """Model adapter that runs a model's ONNX weights with onnxruntime.

    The ONNX weights are read from the model description's ``weights.onnx``
    entry and handed to an ``onnxruntime.InferenceSession``. Device
    management is not implemented: requested devices are ignored with a
    warning, and ``unload`` only emits a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create an inference session from the model's ONNX weights.

        Args:
            model_description: Model description whose ``weights.onnx`` entry
                provides the ONNX model via ``get_reader()``.
            devices: Requested devices. Currently ignored; a warning is
                emitted if this is not ``None``.

        Raises:
            ValueError: If ``model_description.weights.onnx`` is ``None``.
        """
        super().__init__(model_description=model_description)

        if model_description.weights.onnx is None:
            # f-string so the model's name is actually interpolated
            # (previously the placeholder text was emitted verbatim).
            raise ValueError(
                f"No ONNX weights specified for {model_description.name}"
            )

        reader = model_description.weights.onnx.get_reader()
        # NOTE(review): reader.read() presumably yields the serialized model
        # bytes, which InferenceSession accepts in place of a file path.
        self._session = rt.InferenceSession(reader.read())
        onnx_inputs = self._session.get_inputs()
        # Input names in session order; _forward_impl pairs them positionally
        # with the arrays it receives.
        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]

        if devices is not None:
            warnings.warn(
                f"Device management is not implemented for onnx yet, ignoring the devices {devices}"
            )

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the ONNX session on ``input_arrays`` and return its outputs.

        Each array is fed to the model input at the same position in
        ``self._input_names``. Because ``zip`` stops at the shorter sequence,
        surplus arrays are dropped without error, and missing trailing inputs
        are simply not fed to the session.

        Returns:
            The session outputs as a list. A single non-sequence result is
            wrapped in a one-element list.
        """
        # Output names of None asks onnxruntime for all model outputs.
        result: Any = self._session.run(
            None, dict(zip(self._input_names, input_arrays))
        )
        # Normalize to a list regardless of the container the session returns.
        if is_list(result) or is_tuple(result):
            result_seq = list(result)
        else:
            result_seq = [result]

        return result_seq

    def unload(self) -> None:
        """Warn that unloading is not supported; the session is left in place."""
        warnings.warn(
            "Device management is not implemented for onnx yet, cannot unload model"
        )