Coverage for bioimageio/core/backends/pytorch_backend.py: 81%

import gc
import warnings
from contextlib import nullcontext
from io import BytesIO, TextIOWrapper
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union

import torch
from loguru import logger
from numpy.typing import NDArray
from torch import nn
from typing_extensions import assert_never

from bioimageio.spec._internal.version_type import Version
from bioimageio.spec.common import BytesReader, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.utils import download

from ..digest_spec import import_callable
from ..utils._type_guards import is_list, is_ndarray, is_tuple
from ._model_adapter import ModelAdapter

class PytorchModelAdapter(ModelAdapter):
    """Model adapter for bioimage.io models with `pytorch_state_dict` weights."""

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        super().__init__(model_description=model_description)
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        devices = get_devices(devices)
        self._model = load_torch_model(weights, load_state=True, devices=devices)
        if mode == "eval":
            self._model = self._model.eval()
        elif mode == "train":
            self._model = self._model.train()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        self._primary_device = devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        tensors = [
            None if a is None else torch.from_numpy(a).to(self._primary_device)
            for a in input_arrays
        ]

        if self._mode == "eval":
            ctxt = torch.no_grad
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = self._model(*tensors)

        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = model_out = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                r_np: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    r.detach().cpu().numpy()
                )
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def unload(self) -> None:
        del self._model
        _ = gc.collect()  # deallocate memory
        assert torch is not None
        torch.cuda.empty_cache()  # release reserved memory
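
# Illustrative usage sketch (not part of this module; `model_descr` and `input_array`
# are hypothetical placeholders). The adapter wraps the torch module and converts
# between numpy arrays and torch tensors:
#
#     adapter = PytorchModelAdapter(model_description=model_descr, devices=["cpu"])
#     outputs = adapter._forward_impl([input_array])  # list of numpy arrays (or None)
#     adapter.unload()
#
# Real callers typically go through the public `ModelAdapter` interface rather than
# calling `_forward_impl` directly.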

def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Create the torch model described by `weight_spec`, optionally loading its state dict."""
    custom_callable = import_callable(
        weight_spec.architecture,
        sha256=(
            weight_spec.architecture_sha256
            if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
            else weight_spec.sha256
        ),
    )
    model_kwargs = (
        weight_spec.kwargs
        if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
        else weight_spec.architecture.kwargs
    )
    torch_model = custom_callable(**model_kwargs)

    if not isinstance(torch_model, nn.Module):
        if isinstance(
            weight_spec.architecture,
            (v0_4.CallableFromFile, v0_4.CallableFromDepencency),
        ):
            callable_name = weight_spec.architecture.callable_name
        else:
            callable_name = weight_spec.architecture.callable

        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    if load_state or devices:
        use_devices = get_devices(devices)
        torch_model = torch_model.to(use_devices[0])
        if load_state:
            torch_model = load_torch_state_dict(
                torch_model,
                path=download(weight_spec),
                devices=use_devices,
            )
    return torch_model
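
# Illustrative sketch (assumption; `model_descr` is a hypothetical, already loaded
# model description):
#
#     weights = model_descr.weights.pytorch_state_dict
#     module = load_torch_model(weights, load_state=True, devices=["cuda:0"])
#
# With `load_state=False` only the architecture is instantiated; no weights are loaded.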

def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
) -> nn.Module:
    """Load the state dict at `path` into `model` on the first of `devices`."""
    model = model.to(devices[0])
    if isinstance(path, (Path, ZipPath)):
        ctxt = path.open("rb")
    else:
        ctxt = nullcontext(BytesIO(path.read()))

    with ctxt as f:
        assert not isinstance(f, TextIOWrapper)
        if Version(str(torch.__version__)) < Version("1.13"):
            state = torch.load(f, map_location=devices[0])
        else:
            state = torch.load(f, map_location=devices[0], weights_only=True)

    incompatible = model.load_state_dict(state)
    if (
        isinstance(incompatible, tuple)
        and hasattr(incompatible, "missing_keys")
        and hasattr(incompatible, "unexpected_keys")
    ):
        if incompatible.missing_keys:
            logger.warning("Missing state dict keys: {}", incompatible.missing_keys)

        if incompatible.unexpected_keys:
            logger.warning(
                "Unexpected state dict keys: {}", incompatible.unexpected_keys
            )
    else:
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            (s[:20] + "..." if len(s := str(incompatible)) > 20 else s),
        )

    return model
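
# Illustrative sketch (assumption; `Path("weights.pt")` is a hypothetical local file):
#
#     module = load_torch_state_dict(
#         module, path=Path("weights.pt"), devices=get_devices(["cpu"])
#     )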

def get_devices(
    devices: Optional[Sequence[Union[torch.device, str]]] = None,
) -> List[torch.device]:
    """Normalize `devices` to a single-element list of torch.device; extra devices are dropped."""
    if not devices:
        torch_devices = [
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        ]
    else:
        torch_devices = [torch.device(d) for d in devices]

    if len(torch_devices) > 1:
        warnings.warn(
            f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
        )
        torch_devices = torch_devices[:1]

    return torch_devices
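
# Illustrative examples (assumption):
#
#     get_devices(None)             # [torch.device("cuda")] if CUDA is available, else [torch.device("cpu")]
#     get_devices(["cpu", "cuda"])  # [torch.device("cpu")]; the extra device is dropped with a warning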