Coverage for bioimageio/core/backends/pytorch_backend.py: 84%

92 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import gc 

2import warnings 

3from contextlib import nullcontext 

4from io import TextIOWrapper 

5from pathlib import Path 

6from typing import Any, List, Literal, Optional, Sequence, Union 

7 

8import torch 

9from loguru import logger 

10from numpy.typing import NDArray 

11from torch import nn 

12from typing_extensions import assert_never 

13 

14from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple 

15from bioimageio.spec.common import ZipPath 

16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

17from bioimageio.spec.utils import download 

18 

19from ..digest_spec import import_callable 

20from ._model_adapter import ModelAdapter 

21 

22 

class PytorchModelAdapter(ModelAdapter):
    """Model adapter that runs `pytorch_state_dict` weights with PyTorch.

    The architecture is instantiated, its state dict is loaded onto a single
    device (see `get_devices`), and the module is put into eval or train mode
    according to `mode`.
    """

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """
        Args:
            model_description: model description whose
                `weights.pytorch_state_dict` entry is loaded.
            devices: devices to run on; only the first one is used
                (CUDA if available, else CPU, when not given).
            mode: `"eval"` puts the module in eval mode and runs forward
                passes under `torch.no_grad`; `"train"` puts it in train mode
                and keeps autograd enabled.

        Raises:
            ValueError: if the description has no `pytorch_state_dict` weights.
        """
        super().__init__(model_description=model_description)
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        devices = get_devices(devices)
        self._model = load_torch_model(weights, load_state=True, devices=devices)
        if mode == "eval":
            self._model = self._model.eval()
        elif mode == "train":
            self._model = self._model.train()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        # inputs are moved to this device before each forward pass
        self._primary_device = devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the model on numpy inputs and return numpy outputs.

        `None` inputs are forwarded to the model as `None`. Every other array
        is wrapped with `torch.from_numpy` and moved to the primary device.
        A single model output that is not a tuple or list is treated as a
        one-element output sequence.

        Raises:
            TypeError: if an output is neither `None`, a tensor nor an ndarray.
        """
        tensors = [
            None if a is None else torch.from_numpy(a).to(self._primary_device)
            for a in input_arrays
        ]

        if self._mode == "eval":
            # inference only: skip building the autograd graph
            ctxt = torch.no_grad
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = self._model(*tensors)

        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                # detach: outputs produced in train mode may require grad
                r_np: NDArray[Any] = r.detach().cpu().numpy()
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def unload(self) -> None:
        """Drop the model and free cached memory.

        Safe to call more than once. The adapter cannot run forward passes
        afterwards.
        """
        # Drop our reference so the module can be garbage-collected.
        # Tolerate repeated calls instead of raising AttributeError.
        if hasattr(self, "_model"):
            del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory

90 

91 

def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Build the architecture described by `weight_spec`, optionally with weights.

    Args:
        weight_spec: `pytorch_state_dict` weights entry (format 0.4 or 0.5).
        load_state: if True, download the state dict and load it into the
            freshly built model.
        devices: target devices; only the first is used. The model is moved
            to a device whenever `load_state` is True or `devices` is given.

    Returns:
        The constructed `torch.nn.Module`.

    Raises:
        ValueError: if the architecture callable does not return an
            `nn.Module`.
    """
    legacy = isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)

    # NOTE(review): for format 0.5 this forwards the hash of the *weights*
    # file, not of the architecture source; presumably `import_callable`
    # uses the architecture descriptor's own hash there — confirm.
    if legacy:
        arch_hash = weight_spec.architecture_sha256
    else:
        arch_hash = weight_spec.sha256
    build = import_callable(weight_spec.architecture, sha256=arch_hash)

    # format 0.4 keeps constructor kwargs on the weights entry,
    # format 0.5 keeps them on the architecture descriptor
    if legacy:
        build_kwargs = weight_spec.kwargs
    else:
        build_kwargs = weight_spec.architecture.kwargs
    torch_model = build(**build_kwargs)

    if not isinstance(torch_model, nn.Module):
        arch = weight_spec.architecture
        if isinstance(arch, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)):
            callable_name = arch.callable_name
        else:
            callable_name = arch.callable
        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    if not (load_state or devices):
        # nothing to load and no device requested: leave the model where it is
        return torch_model

    use_devices = get_devices(devices)
    torch_model = torch_model.to(use_devices[0])
    if load_state:
        torch_model = load_torch_state_dict(
            torch_model,
            path=download(weight_spec).path,
            devices=use_devices,
        )
    return torch_model

136 

137 

def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath],
    devices: Sequence[torch.device],
) -> nn.Module:
    """Load the state dict stored at `path` into `model`.

    Only `devices[0]` is used. The model is moved there, and stored tensors
    are mapped there while loading. `torch.load` is called with
    `weights_only=True`, which restricts unpickling to weights-like data.

    Returns:
        The moved model with the loaded state.
    """
    target = devices[0]
    model = model.to(target)
    with path.open("rb") as stream:
        # narrows the type for checkers: binary mode is not a text wrapper
        assert not isinstance(stream, TextIOWrapper)
        state_dict = torch.load(stream, map_location=target, weights_only=True)

    # NOTE(review): with PyTorch's default strict=True, mismatched keys raise
    # RuntimeError inside `load_state_dict`. The warnings below therefore
    # appear to fire only for modules overriding `load_state_dict`; confirm
    # whether strict=False was intended.
    outcome = model.load_state_dict(state_dict)
    if outcome is None:  # pyright: ignore[reportUnnecessaryComparison]
        # an overridden `load_state_dict` may return nothing to report
        return model

    if outcome.missing_keys:
        logger.warning("Missing state dict keys: {}", outcome.missing_keys)
    if outcome.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", outcome.unexpected_keys)

    return model

162 

163 

def get_devices(
    devices: Optional[
        Union[str, torch.device, Sequence[Union[torch.device, str]]]
    ] = None,
) -> List[torch.device]:
    """Normalize a device specification to a one-element list of `torch.device`.

    Args:
        devices: a single device (string or `torch.device`) or a sequence of
            them. When `None` or empty, CUDA is used if available, else CPU.

    Returns:
        A list holding exactly one `torch.device`. Multi-device execution is
        not implemented: extra devices are dropped with a warning.
    """
    if not devices:
        torch_devices = [
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        ]
    elif isinstance(devices, (str, torch.device)):
        # A bare string is itself a `Sequence[str]`. Iterating it would yield
        # single characters ("c", "p", "u"), which are not valid devices.
        torch_devices = [torch.device(devices)]
    else:
        torch_devices = [torch.device(d) for d in devices]

    if len(torch_devices) > 1:
        warnings.warn(
            f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}",
            stacklevel=2,
        )
        torch_devices = torch_devices[:1]

    return torch_devices