Coverage for bioimageio/core/model_adapters/_torchscript_model_adapter.py: 86%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import gc 

2import warnings 

3from typing import Any, List, Optional, Sequence, Tuple, Union 

4 

5import numpy as np 

6from numpy.typing import NDArray 

7 

8from bioimageio.spec.model import v0_4, v0_5 

9from bioimageio.spec.utils import download 

10 

11from ..digest_spec import get_axes_infos 

12from ..tensor import Tensor 

13from ._model_adapter import ModelAdapter 

14 

# torch is an optional dependency. Start from "no error" and only record a
# message if the import actually fails, so later code can report the reason.
torch_error = None
try:
    import torch
except Exception as e:
    torch = None
    torch_error = str(e)

22 

23 

class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs a model's TorchScript weights via `torch.jit`."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Download and load the TorchScript weights of `model_description`.

        Args:
            model_description: Model description whose `weights.torchscript`
                entry points to the weights to load.
            devices: Device identifiers, each passed to `torch.device`. Only
                the first entry is used for inference. If `None`, "cuda" is
                selected when `torch.cuda.is_available()` and "cpu" otherwise.

        Raises:
            ImportError: If torch could not be imported.
            ValueError: If the model description has no torchscript weights,
                or if `devices` is an empty sequence.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        super().__init__()
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        if devices is None:
            self.devices = ["cuda" if torch.cuda.is_available() else "cpu"]
        else:
            self.devices = [torch.device(d) for d in devices]

        # Fail early (before downloading anything) with a clear message rather
        # than with an IndexError on `self.devices[0]` further down.
        if not self.devices:
            raise ValueError("`devices` must contain at least one device")

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented"
            )

        weight_path = download(model_description.weights.torchscript.source).path

        # Without `map_location`, torch.jit.load tries to restore the module on
        # the device it was saved from and raises if that device is missing on
        # this machine (e.g. GPU-saved weights on a CPU-only host).
        self._model = torch.jit.load(weight_path, map_location=self.devices[0])
        self._model.to(self.devices[0])
        self._model = self._model.eval()

        # One tuple of axis ids per described model output, in output order.
        self._internal_output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the loaded model on `batch` and wrap its outputs as `Tensor`s.

        Args:
            *batch: Model inputs in positional order; `None` entries are passed
                to the model as `None`.

        Returns:
            One entry per described model output: `None` where the model
            returned `None`, otherwise a `Tensor` whose dims are the axis ids
            of the corresponding output description.
        """
        assert torch is not None
        with torch.no_grad():
            # NOTE(review): `b.data.data` is presumably the numpy array behind
            # the tensor's data container -- confirm against `Tensor`.
            torch_tensor = [
                None if b is None else torch.from_numpy(b.data.data).to(self.devices[0])
                for b in batch
            ]
            _result: Union[  # pyright: ignore[reportUnknownVariableType]
                Tuple[Optional[NDArray[Any]], ...],
                List[Optional[NDArray[Any]]],
                Optional[NDArray[Any]],
            ] = self._model.forward(*torch_tensor)

            # A single (non tuple/list) return value is treated as one output.
            if isinstance(_result, (tuple, list)):
                result: Sequence[Optional[NDArray[Any]]] = _result
            else:
                result = [_result]

            # Anything that is not already a numpy array is moved to the CPU
            # and converted; `None` outputs are passed through.
            result = [
                (
                    None
                    if r is None
                    else r.cpu().numpy() if not isinstance(r, np.ndarray) else r
                )
                for r in result
            ]

        assert len(result) == len(self._internal_output_axes)
        return [
            None if r is None else Tensor(r, dims=axes)
            for r, axes in zip(result, self._internal_output_axes)
        ]

    def unload(self) -> None:
        """Drop the loaded model and release memory held on its behalf.

        Safe to call more than once; later calls only re-run garbage
        collection and the CUDA cache cleanup.
        """
        assert torch is not None
        # Guarded so that a repeated `unload()` does not raise AttributeError.
        if hasattr(self, "_model"):
            del self._model

        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory
96 torch.cuda.empty_cache() # release reserved memory