Coverage for bioimageio/core/model_adapters/_pytorch_model_adapter.py: 84%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import gc 

2import warnings 

3from typing import Any, List, Optional, Sequence, Tuple, Union 

4 

5from bioimageio.spec.model import v0_4, v0_5 

6from bioimageio.spec.utils import download 

7 

8from ..axis import AxisId 

9from ..digest_spec import get_axes_infos, import_callable 

10from ..tensor import Tensor 

11from ._model_adapter import ModelAdapter 

12 

# torch is an optional dependency: remember why the import failed (if it did)
# so the adapter can report the reason when it is actually used.
torch_error = None
try:
    import torch
except Exception as import_failure:
    torch = None
    torch_error = str(import_failure)

20 

21 

class PytorchModelAdapter(ModelAdapter):
    """Model adapter that runs a PyTorch state-dict model on a single device.

    The network is built from the architecture referenced by the weights
    description, moved to the first selected device, populated with the
    downloaded state dict and put into eval mode.
    """

    def __init__(
        self,
        *,
        outputs: Union[
            Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr]
        ],
        weights: Union[
            v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
        ],
        devices: Optional[Sequence[str]] = None,
    ):
        """Instantiate the network and load its weights.

        Args:
            outputs: Output tensor descriptions; their axis ids become the
                dims of the tensors returned by `forward`.
            weights: PyTorch state-dict weights description (architecture
                plus the weights file to load).
            devices: Device strings understood by `torch.device`. Only the
                first one is used; see `get_devices` for the default.

        Raises:
            ImportError: If torch could not be imported.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        super().__init__()
        self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs]
        self._network = self.get_network(weights)
        self._devices = self.get_devices(devices)
        self._network = self._network.to(self._devices[0])

        self._primary_device = self._devices[0]
        # NOTE(security): torch.load unpickles the file, so weights from an
        # untrusted source can execute arbitrary code. Consider passing
        # `weights_only=True` once the minimum supported torch version allows it.
        state: Any = torch.load(
            download(weights).path,
            map_location=self._primary_device,  # pyright: ignore[reportUnknownArgumentType]
        )
        self._network.load_state_dict(state)

        self._network = self._network.eval()

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the network on the given inputs with gradient tracking disabled.

        Args:
            *input_tensors: Positional network inputs; `None` entries are
                passed through to the network as `None`.

        Returns:
            One entry per network output, paired in order with the output
            descriptions given at construction. `None` outputs stay `None`.
            If the network returns fewer outputs than were described, the
            surplus descriptions are ignored.

        Raises:
            ImportError: If torch could not be imported.
            ValueError: If the network returns more outputs than described.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        with torch.no_grad():
            # assumes `ipt.data.data` is a numpy array, as required by
            # torch.from_numpy -- TODO confirm against `Tensor`
            tensors = [
                (
                    None
                    if ipt is None
                    else torch.from_numpy(ipt.data.data).to(
                        self._primary_device  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
                for ipt in input_tensors
            ]
            result: Union[Tuple[Any, ...], List[Any], Any]
            result = self._network(  # pyright: ignore[reportUnknownVariableType]
                *tensors
            )
            # a network returning a single object is treated as one output
            if not isinstance(result, (tuple, list)):
                result = [result]

            # torch tensors become numpy arrays on the CPU; anything else
            # (including None) is passed on unchanged
            result = [
                (
                    None
                    if r is None
                    else r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r
                )
                for r in result  # pyright: ignore[reportUnknownVariableType]
            ]
            if len(result) > len(self.output_dims):
                raise ValueError(
                    f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}"
                )

        return [
            None if r is None else Tensor(r, dims=out)
            for r, out in zip(result, self.output_dims)
        ]

    def unload(self) -> None:
        """Drop the network and release memory held on its behalf.

        Safe to call more than once: a missing network is not an error.
        """
        if hasattr(self, "_network"):
            del self._network

        _ = gc.collect()  # deallocate memory
        # An explicit check instead of `assert`, which is stripped under `-O`.
        if torch is not None:
            torch.cuda.empty_cache()  # release reserved memory

    @staticmethod
    def get_network(  # pyright: ignore[reportUnknownParameterType]
        weight_spec: Union[
            v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
        ],
    ) -> "torch.nn.Module":  # pyright: ignore[reportInvalidTypeForm]
        """Import the architecture callable and instantiate the network.

        Args:
            weight_spec: Weights description naming the architecture and the
                keyword arguments to call it with.

        Returns:
            The instantiated (not yet weight-loaded) network.

        Raises:
            ImportError: If torch could not be imported.
            ValueError: If the architecture callable does not return a
                `torch.nn.Module`.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        is_v0_4 = isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
        # NOTE(review): for v0_5 this passes the weights entry's own `sha256`;
        # confirm that is the hash `import_callable` expects here.
        arch = import_callable(
            weight_spec.architecture,
            sha256=(
                weight_spec.architecture_sha256
                if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
                else weight_spec.sha256
            ),
        )
        # the two spec versions keep the constructor kwargs in different places
        model_kwargs = (
            weight_spec.kwargs
            if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
            else weight_spec.architecture.kwargs
        )
        network = arch(**model_kwargs)
        if not isinstance(network, torch.nn.Module):
            # NOTE(review): assumes `architecture` exposes `.callable` for both
            # spec versions -- confirm for v0_4.
            raise ValueError(
                f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module"
            )

        return network

    @staticmethod
    def get_devices(  # pyright: ignore[reportUnknownParameterType]
        devices: Optional[Sequence[str]] = None,
    ) -> List["torch.device"]:  # pyright: ignore[reportInvalidTypeForm]
        """Resolve device strings to a single-element list of torch devices.

        Args:
            devices: Device strings. If empty or `None`, "cuda" is chosen
                when `torch.cuda.is_available()` and "cpu" otherwise.

        Returns:
            A list holding exactly one `torch.device`. Additional requested
            devices are dropped with a warning.

        Raises:
            ImportError: If torch could not be imported.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        if not devices:
            default_device = "cuda" if torch.cuda.is_available() else "cpu"
            torch_devices = [torch.device(default_device)]
        else:
            torch_devices = [torch.device(d) for d in devices]

        if len(torch_devices) > 1:
            warnings.warn(
                f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
            )
            torch_devices = torch_devices[:1]

        return torch_devices