Coverage for src/bioimageio/core/backends/pytorch_backend.py: 78%

113 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import gc 

2import warnings 

3from abc import abstractmethod 

4from contextlib import nullcontext 

5from io import BytesIO, TextIOWrapper 

6from pathlib import Path 

7from typing import Any, List, Literal, Mapping, Optional, Sequence, Tuple, Union 

8 

9import torch 

10from loguru import logger 

11from numpy.typing import NDArray 

12from torch import nn 

13from typing_extensions import Protocol, Self, assert_never, runtime_checkable 

14 

15from bioimageio.spec._internal.version_type import Version 

16from bioimageio.spec.common import BytesReader, ZipPath 

17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

18from bioimageio.spec.utils import download 

19 

20from ..digest_spec import import_callable 

21from ..utils._type_guards import is_list, is_ndarray, is_tuple 

22from ._model_adapter import ModelAdapter 

23 

24 

25@runtime_checkable 

26class TorchNNModuleLike(Protocol): 

27 @abstractmethod 

28 def load_state_dict( 

29 self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False 

30 ) -> Self: ... 

31 

32 @abstractmethod 

33 def to( 

34 self, 

35 *, 

36 device: Optional[torch.device] = None, 

37 dtype: Optional[torch.dtype] = None, 

38 non_blocking: bool = False, 

39 ) -> Self: ... 

40 

41 @abstractmethod 

42 def forward( 

43 self, *input: torch.Tensor 

44 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]: ... 

45 

46 def eval(self) -> Self: 

47 """Set model to eval mode""" 

48 return self 

49 

50 

class PytorchModelAdapter(ModelAdapter):
    """Model adapter that runs `pytorch_state_dict` weights with PyTorch.

    The architecture is built and its weights are loaded by `load_torch_model`.
    Inputs are moved to, and the model runs on, the first device returned by
    `get_devices`.
    """

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """
        Args:
            model_description: Model description whose
                `weights.pytorch_state_dict` entry is loaded.
            devices: Requested devices, resolved with `get_devices`. Only the
                first resolved device is used.
            mode: `"eval"` puts the model in eval mode and runs forward passes
                under `torch.no_grad()`. `"train"` puts it in train mode and
                keeps autograd enabled.

        Raises:
            ValueError: If the description has no `pytorch_state_dict` weights.
        """
        super().__init__(model_description=model_description)
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        devices = get_devices(devices)
        self._model = load_torch_model(weights, load_state=True, devices=devices)
        if mode == "eval":
            self._model = self._model.eval()
        elif mode == "train":
            self._model = self._model.train()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        # `get_devices` returns a non-empty list; inputs are sent to its first entry
        self._primary_device = devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the model on numpy inputs and return numpy outputs.

        `None` inputs are passed to the model as `None`. A model output that
        is neither a tuple nor a list is treated as a single output. Tensor
        outputs are detached and copied to CPU. ndarray and `None` outputs
        are passed through unchanged.

        Raises:
            TypeError: If an output element is not a tensor, ndarray, or None.
        """
        tensors = [
            None if a is None else torch.from_numpy(a).to(self._primary_device)
            for a in input_arrays
        ]

        if self._mode == "eval":
            ctxt = torch.no_grad
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = self._model(*tensors)

        # normalize to a sequence of outputs
        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                r_np: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    r.detach().cpu().numpy()
                )
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def unload(self) -> None:
        """Drop the model and ask PyTorch to release cached device memory."""
        del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory
        # `get_devices` may have picked MPS, whose cache is separate from CUDA's.
        # `torch.mps.empty_cache` is missing in older torch versions, hence the lookup.
        if self._primary_device.type == "mps":
            empty_mps_cache = getattr(getattr(torch, "mps", None), "empty_cache", None)
            if empty_mps_cache is not None:
                empty_mps_cache()

120 

121 

def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Build the architecture of a `pytorch_state_dict` weights entry.

    Args:
        weight_spec: v0.4 or v0.5 weights description. It provides the
            architecture callable, its keyword arguments and the sha256
            passed to `import_callable`.
        load_state: If true, the weights file is downloaded and loaded
            into the model. v0.5 descriptions control strictness via
            `strict`; v0.4 descriptions always load strictly.
        devices: Target devices, resolved with `get_devices`. The model is
            moved to the first one. If neither `load_state` nor `devices`
            is given, the model is returned as created by the callable.

    Returns:
        The instantiated `nn.Module`.

    Raises:
        ValueError: If the architecture callable does not return an `nn.Module`.
    """
    if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr):
        arch_sha256 = weight_spec.architecture_sha256
        arch_kwargs = weight_spec.kwargs
    else:
        arch_sha256 = weight_spec.sha256
        arch_kwargs = weight_spec.architecture.kwargs

    build_model = import_callable(weight_spec.architecture, sha256=arch_sha256)
    module = build_model(**arch_kwargs)

    if not isinstance(module, nn.Module):
        arch = weight_spec.architecture
        if isinstance(arch, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)):
            name = arch.callable_name
        else:
            name = arch.callable
        raise ValueError(f"Calling {name} did not return a torch.nn.Module.")

    if not load_state and not devices:
        return module  # nothing to move, nothing to load

    target_devices = get_devices(devices)
    module = module.to(target_devices[0])
    if not load_state:
        return module

    strict = (
        weight_spec.strict
        if isinstance(weight_spec, v0_5.PytorchStateDictWeightsDescr)
        else True
    )
    return load_torch_state_dict(
        module,
        path=download(weight_spec),
        devices=target_devices,
        strict=strict,
    )

169 

170 

def _log_incompatible_keys(incompatible: Any) -> None:
    """Warn about missing/unexpected keys reported by `load_state_dict`.

    A result that is not a tuple exposing `missing_keys` and `unexpected_keys`
    is itself reported (truncated to 20 characters) as unexpected.
    """
    is_keys_tuple = (
        isinstance(incompatible, tuple)
        and hasattr(incompatible, "missing_keys")
        and hasattr(incompatible, "unexpected_keys")
    )
    if not is_keys_tuple:
        text = str(incompatible)
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            text if len(text) <= 20 else text[:20] + "...",
        )
        return

    if incompatible.missing_keys:
        logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
    if incompatible.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)


def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
    strict: bool = True,
) -> nn.Module:
    """Load a state dict file into `model`.

    The model is moved to `devices[0]`, which is also the `map_location` of
    the loaded tensors. On torch >= 1.13 the file is read with
    `weights_only=True`. Older versions call `torch.load` without that flag.

    Args:
        model: Module to receive the parameters.
        path: Local/zip path, or a reader whose full content is buffered in memory.
        devices: Target devices; only the first is used.
        strict: Forwarded to `model.load_state_dict`.

    Returns:
        The moved model with the state loaded. Missing/unexpected keys are
        logged as warnings.

    Raises:
        ValueError: If loading with `weights_only=True` fails. The error
            chains the original exception and includes a hint on extracting
            a pure state dict.
    """
    target = devices[0]
    model = model.to(target)

    if isinstance(path, (Path, ZipPath)):
        source = path.open("rb")
    else:
        source = nullcontext(BytesIO(path.read()))

    with source as f:
        assert not isinstance(f, TextIOWrapper)
        torch_lacks_weights_only = Version(str(torch.__version__)) < Version("1.13")
        if torch_lacks_weights_only:
            state = torch.load(f, map_location=target)
        else:
            try:
                state = torch.load(f, map_location=target, weights_only=True)
            except Exception as e:
                raise ValueError(
                    f"Failed to load weights with `weights_only=True`: {e}\n\n"
                    + "This usually means the weights file contains non-tensor objects"
                    + " (e.g. numpy arrays, custom classes, or nested dicts with"
                    + " metadata). The BioImage.IO spec requires a pure state dict —"
                    + " an OrderedDict mapping parameter names to tensors only.\n\n"
                    + "To fix this, extract only the state dict from your checkpoint:\n\n"
                    + " import torch\n"
                    + " checkpoint = torch.load('original.pth', weights_only=False)\n"
                    + " # Inspect keys, e.g.: checkpoint.keys()"
                    + " -> dict_keys(['model', 'optimizer', ...])\n"
                    + " torch.save(checkpoint['model'], 'weights.pt')\n\n"
                    + "Then reference 'weights.pt' in your bioimageio.yaml."
                ) from e

    _log_incompatible_keys(model.load_state_dict(state, strict=strict))
    return model

228 

229 

def get_devices(
    devices: Optional[
        Union[Sequence[Union[torch.device, str]], torch.device, str]
    ] = None,
) -> List[torch.device]:
    """Resolve the torch device to run a model on.

    Args:
        devices: Requested devices.
            - `None`, an empty sequence or an empty string selects a default:
              CUDA if available, else MPS if available, else CPU.
            - A single device (a `str` such as ``"cuda:0"`` or a
              `torch.device`) is treated as a one-element sequence.

    Returns:
        A list with exactly one `torch.device`. Multi-device execution is not
        implemented, so extra devices are dropped with a warning.
    """
    if not devices:
        # `torch.backends.mps` does not exist in older torch releases
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            torch_devices = [torch.device("cuda")]
        elif mps_backend is not None and mps_backend.is_available():
            torch_devices = [torch.device("mps")]
        else:
            torch_devices = [torch.device("cpu")]
    elif isinstance(devices, (str, torch.device)):
        # a bare string is itself a Sequence[str]; iterating it would split
        # e.g. "cuda" into "c", "u", ... which are not valid devices
        torch_devices = [torch.device(devices)]
    else:
        torch_devices = [torch.device(d) for d in devices]

    if len(torch_devices) > 1:
        warnings.warn(
            f"Multiple devices for pytorch model not yet implemented; ignoring {torch_devices[1:]}",
            stacklevel=2,
        )
        torch_devices = torch_devices[:1]

    return torch_devices