Coverage for src/bioimageio/core/backends/pytorch_backend.py: 75%

118 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import gc 

2from abc import abstractmethod 

3from contextlib import nullcontext 

4from io import BytesIO, TextIOWrapper 

5from pathlib import Path 

6from typing import Any, List, Literal, Mapping, Optional, Sequence, Tuple, Union 

7 

8import torch 

9from loguru import logger 

10from numpy.typing import NDArray 

11from torch import nn 

12from typing_extensions import Protocol, Self, assert_never, runtime_checkable 

13 

14from bioimageio.spec._internal.version_type import Version 

15from bioimageio.spec.common import BytesReader, ZipPath 

16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

17from bioimageio.spec.utils import download 

18 

19from .._model_adapter import LocalModelAdapter 

20from ..digest_spec import import_callable 

21from ..utils._type_guards import is_list, is_ndarray, is_tuple 

22 

23 

24@runtime_checkable 

25class TorchNNModuleLike(Protocol): 

26 @abstractmethod 

27 def load_state_dict( 

28 self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False 

29 ) -> Self: ... 

30 

31 @abstractmethod 

32 def to( 

33 self, 

34 *, 

35 device: Optional[torch.device] = None, 

36 dtype: Optional[torch.dtype] = None, 

37 non_blocking: bool = False, 

38 ) -> Self: ... 

39 

40 @abstractmethod 

41 def forward( 

42 self, *input: torch.Tensor 

43 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]: ... 

44 

45 def eval(self) -> Self: 

46 """Set model to eval mode""" 

47 return self 

48 

49 

class PytorchModelAdapter(LocalModelAdapter[torch.device, nn.Module]):
    """Model adapter running `pytorch_state_dict` weights with PyTorch.

    Models are created per torch device by `_init_model_on_device`. That hook
    is presumably invoked by the `LocalModelAdapter` base class, which is
    defined elsewhere. Inputs and outputs are converted between numpy arrays
    and torch tensors around each forward pass.
    """

    def __init__(
        self,
        model_description: AnyModelDescr,
        mode: Literal["eval", "train"] = "eval",
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter.

        Args:
            model_description: bioimageio model description; must provide
                `pytorch_state_dict` weights.
            mode: `"eval"` puts the model in eval mode and runs it under
                `torch.no_grad`. `"train"` puts it in train mode with
                gradients enabled.
            devices: device names forwarded to the base class. It presumably
                parses them with `_parse_devices`; `None` selects devices
                automatically there.

        Raises:
            ValueError: if no `pytorch_state_dict` weights are described.
        """
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        # Set before `super().__init__`, in case the base initializer already
        # calls `_init_model_on_device`, which reads both attributes.
        self._weights = weights
        self._mode: Literal["eval", "train"] = mode
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(
        self, devices: Optional[Sequence[str]]
    ) -> Sequence[torch.device]:
        """Convert device names to `torch.device`s.

        `None` or an empty sequence selects devices automatically
        (see `get_devices`).
        """
        return get_devices(devices)

    def _init_model_on_device(self, device: torch.device) -> nn.Module:
        """Build the model, load its weights onto `device` and set its mode."""
        model = load_torch_model(self._weights, load_state=True, devices=[device])

        if self._mode == "eval":
            model = model.eval()
        elif self._mode == "train":
            model = model.train()
        else:
            assert_never(self._mode)

        return model

    def _forward_impl(
        self,
        device: torch.device,
        model: nn.Module,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> List[Optional[NDArray[Any]]]:
        """Run `model` on `device` and return its outputs as numpy arrays.

        `None` inputs are passed to the model as `None`. A single (non
        tuple/list) model output is treated as a one-element output list.
        `None` outputs are kept as `None`.

        Raises:
            TypeError: if an output is neither `None`, a tensor nor an ndarray.
        """
        tensors = [
            None if a is None else torch.from_numpy(a).to(device) for a in input_arrays
        ]

        if self._mode == "eval":
            ctxt = torch.no_grad  # inference only: skip autograd bookkeeping
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = model(*tensors)

        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                r_np: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    r.detach().cpu().numpy()
                )
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def _cleanup_pre_model_deletion(
        self, device: torch.device, model: nn.Module
    ) -> None:
        """Nothing to do before the model is deleted."""
        return

    def _cleanup_post_model_deletion(self, device: torch.device) -> None:
        """Free memory after the model on `device` has been deleted."""
        _ = gc.collect()  # deallocate memory
        if device.type == "cuda":
            torch.cuda.empty_cache()  # release reserved memory
        elif device.type == "mps":
            # `get_devices` may select MPS. Its cache is released through
            # `torch.mps.empty_cache`, which is absent in older torch releases.
            mps_module = getattr(torch, "mps", None)
            mps_empty_cache = getattr(mps_module, "empty_cache", None)
            if mps_empty_cache is not None:
                mps_empty_cache()  # release reserved memory

132 

133 

def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Instantiate the architecture described by `weight_spec`.

    Args:
        weight_spec: v0.4 or v0.5 `pytorch_state_dict` weights description.
        load_state: if True, download the weights and load them into the model.
        devices: target devices; the model is moved to the first one.
            If `devices` is `None`/empty and `load_state` is False, the model
            is returned as created by the architecture callable. Otherwise
            devices are resolved with `get_devices`.

    Returns:
        The constructed (and optionally weight-loaded) `torch.nn.Module`.

    Raises:
        ValueError: if the architecture callable does not return a
            `torch.nn.Module`.
    """
    legacy_spec = isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
    # v0.4 and v0.5 keep the architecture hash and kwargs in different places
    if legacy_spec:
        arch_sha256 = weight_spec.architecture_sha256
        arch_kwargs = weight_spec.kwargs
    else:
        arch_sha256 = weight_spec.sha256
        arch_kwargs = weight_spec.architecture.kwargs

    factory = import_callable(weight_spec.architecture, sha256=arch_sha256)
    model = factory(**arch_kwargs)

    if not isinstance(model, nn.Module):
        arch = weight_spec.architecture
        legacy_arch = isinstance(
            arch, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)
        )
        name = arch.callable_name if legacy_arch else arch.callable
        raise ValueError(f"Calling {name} did not return a torch.nn.Module.")

    if not (load_state or devices):
        return model

    target_devices = get_devices(devices)
    model = model.to(target_devices[0])
    if not load_state:
        return model

    # only v0.5 descriptions carry a `strict` flag; v0.4 always loads strictly
    if isinstance(weight_spec, v0_5.PytorchStateDictWeightsDescr):
        strict_loading = weight_spec.strict
    else:
        strict_loading = True

    return load_torch_state_dict(
        model,
        path=download(weight_spec),
        devices=target_devices,
        strict=strict_loading,
    )

181 

182 

def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
    strict: bool = True,
) -> nn.Module:
    """Load a pure state dict from `path` into `model`.

    The model is first moved to `devices[0]`, and tensors are mapped onto that
    device. With torch >= 1.13 the file is loaded with `weights_only=True`.

    Args:
        model: module to receive the weights.
        path: file path, zip member path, or an in-memory byte reader.
        devices: target devices; only the first one is used.
        strict: forwarded to `model.load_state_dict`.

    Returns:
        The model (moved to `devices[0]`) with the weights loaded.

    Raises:
        ValueError: if loading with `weights_only=True` fails.
    """
    target = devices[0]
    model = model.to(target)

    if isinstance(path, (Path, ZipPath)):
        source = path.open("rb")
    else:
        # wrap the bytes so both branches can be used as a context manager
        source = nullcontext(BytesIO(path.read()))

    with source as stream:
        assert not isinstance(stream, TextIOWrapper)
        old_torch = Version(str(torch.__version__)) < Version("1.13")
        if old_torch:
            # `weights_only` is not supported by these torch releases
            state = torch.load(stream, map_location=target)
        else:
            try:
                state = torch.load(stream, map_location=target, weights_only=True)
            except Exception as exc:
                raise ValueError(
                    f"Failed to load weights with `weights_only=True`: {exc}\n\n"
                    "This usually means the weights file contains non-tensor objects"
                    " (e.g. numpy arrays, custom classes, or nested dicts with"
                    " metadata). The BioImage.IO spec requires a pure state dict —"
                    " an OrderedDict mapping parameter names to tensors only.\n\n"
                    "To fix this, extract only the state dict from your checkpoint:\n\n"
                    " import torch\n"
                    " checkpoint = torch.load('original.pth', weights_only=False)\n"
                    " # Inspect keys, e.g.: checkpoint.keys()"
                    " -> dict_keys(['model', 'optimizer', ...])\n"
                    " torch.save(checkpoint['model'], 'weights.pt')\n\n"
                    "Then reference 'weights.pt' in your bioimageio.yaml."
                ) from exc

    outcome = model.load_state_dict(state, strict=strict)
    reports_keys = (
        isinstance(outcome, tuple)
        and hasattr(outcome, "missing_keys")
        and hasattr(outcome, "unexpected_keys")
    )
    if not reports_keys:
        shown = str(outcome)
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            shown[:20] + "..." if len(shown) > 20 else shown,
        )
        return model

    if outcome.missing_keys:
        logger.warning("Missing state dict keys: {}", outcome.missing_keys)
    if outcome.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", outcome.unexpected_keys)

    return model

240 

241 

def get_devices(
    devices: Optional[Sequence[Union[torch.device, str]]] = None,
) -> List[torch.device]:
    """Resolve `devices` into a list of `torch.device` objects.

    Explicit entries (names or devices) are converted one-to-one. If `devices`
    is `None` or empty, a default is chosen in this order:
    1. all CUDA devices, if CUDA is available;
    2. MPS, if available;
    3. the current `torch.accelerator`, if that API is usable and reports one;
    4. the CPU.
    """
    if devices:
        return [torch.device(dev) for dev in devices]

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        return [torch.device(f"cuda:{idx}") for idx in range(gpu_count)]

    if torch.backends.mps.is_available():
        return [torch.device("mps")]

    # `torch.accelerator` may be missing in older torch releases; any failure
    # while probing it falls back to the CPU.
    try:
        accelerator = (
            torch.accelerator.current_accelerator()
            if torch.accelerator.is_available()
            else None
        )
    except Exception:
        accelerator = None

    return [torch.device("cpu") if accelerator is None else accelerator]