Coverage for src / bioimageio / core / backends / pytorch_backend.py: 77%

108 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1import gc 

2import warnings 

3from contextlib import nullcontext 

4from io import BytesIO, TextIOWrapper 

5from pathlib import Path 

6from typing import Any, List, Literal, Optional, Sequence, Union 

7 

8import torch 

9from loguru import logger 

10from numpy.typing import NDArray 

11from torch import nn 

12from typing_extensions import assert_never 

13 

14from bioimageio.spec._internal.version_type import Version 

15from bioimageio.spec.common import BytesReader, ZipPath 

16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

17from bioimageio.spec.utils import download 

18 

19from ..digest_spec import import_callable 

20from ..utils._type_guards import is_list, is_ndarray, is_tuple 

21from ._model_adapter import ModelAdapter 

22 

23 

class PytorchModelAdapter(ModelAdapter):
    """Run a bioimage.io model using its `pytorch_state_dict` weights."""

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """Load the model's pytorch weights onto the (first) requested device.

        Args:
            model_description: model description with `pytorch_state_dict` weights.
            devices: preferred devices; only the first one is used (multi-device
                execution is not implemented, see `get_devices`).
            mode: put the module in eval or train mode after loading.

        Raises:
            ValueError: if `model_description` has no `pytorch_state_dict` weights.
        """
        super().__init__(model_description=model_description)
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        devices = get_devices(devices)
        self._model = load_torch_model(weights, load_state=True, devices=devices)
        if mode == "eval":
            self._model = self._model.eval()
        elif mode == "train":
            self._model = self._model.train()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        self._primary_device = devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the model on numpy inputs and return numpy outputs.

        `None` entries pass through unchanged in both directions.
        """
        # move inputs to the model's device
        tensors = [
            None if a is None else torch.from_numpy(a).to(self._primary_device)
            for a in input_arrays
        ]

        # disable autograd in eval mode; keep gradients when training
        if self._mode == "eval":
            ctxt = torch.no_grad
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = self._model(*tensors)

        # normalize a single (non-sequence) output to a one-element sequence
        # (fix: the original also rebound `model_out` here for no reason)
        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                r_np: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    r.detach().cpu().numpy()
                )
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def unload(self) -> None:
        """Release the model and free cached GPU memory."""
        del self._model
        _ = gc.collect()  # deallocate memory
        # (fix: dropped dead `assert torch is not None`; torch is imported
        # unconditionally at module level)
        torch.cuda.empty_cache()  # release reserved memory

93 

94 

def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Instantiate the torch module described by *weight_spec*.

    Args:
        weight_spec: pytorch_state_dict weights description (v0.4 or v0.5).
        load_state: whether to download the weights file and load its state dict.
        devices: preferred devices; only the first one is used.

    Returns:
        The instantiated (and optionally state-loaded) `nn.Module`.

    Raises:
        ValueError: if the architecture callable does not return an `nn.Module`.
    """
    # v0.4 and v0.5 store the architecture sha and kwargs in different places
    if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr):
        arch_sha = weight_spec.architecture_sha256
        model_kwargs = weight_spec.kwargs
    else:
        arch_sha = weight_spec.sha256
        model_kwargs = weight_spec.architecture.kwargs

    custom_callable = import_callable(weight_spec.architecture, sha256=arch_sha)
    torch_model = custom_callable(**model_kwargs)

    if not isinstance(torch_model, nn.Module):
        arch = weight_spec.architecture
        if isinstance(arch, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)):
            callable_name = arch.callable_name
        else:
            callable_name = arch.callable

        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    if load_state or devices:
        use_devices = get_devices(devices)
        torch_model = torch_model.to(use_devices[0])
        if load_state:
            torch_model = load_torch_state_dict(
                torch_model,
                path=download(weight_spec),
                devices=use_devices,
            )
    return torch_model

139 

140 

def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
) -> nn.Module:
    """Load the state dict stored at *path* into *model* on ``devices[0]``.

    Args:
        model: module to load the parameters into.
        path: file, zip member, or reader holding the serialized state dict.
        devices: target devices; only ``devices[0]`` is used.

    Returns:
        *model* with the loaded parameters.

    Raises:
        ValueError: if loading with ``weights_only=True`` fails (torch >= 1.13),
            i.e. the file holds more than a pure state dict.
    """
    model = model.to(devices[0])
    # `torch.load` needs a binary file object; wrap non-path readers in BytesIO
    if isinstance(path, (Path, ZipPath)):
        ctxt = path.open("rb")
    else:
        ctxt = nullcontext(BytesIO(path.read()))

    with ctxt as f:
        assert not isinstance(f, TextIOWrapper)
        if Version(str(torch.__version__)) < Version("1.13"):
            # `weights_only` was only introduced in torch 1.13
            state = torch.load(f, map_location=devices[0])
        else:
            try:
                state = torch.load(f, map_location=devices[0], weights_only=True)
            except Exception as e:
                msg = (
                    f"Failed to load weights with `weights_only=True`: {e}\n\n"
                    + "This usually means the weights file contains non-tensor objects"
                    + " (e.g. numpy arrays, custom classes, or nested dicts with"
                    + " metadata). The BioImage.IO spec requires a pure state dict —"
                    + " an OrderedDict mapping parameter names to tensors only.\n\n"
                    + "To fix this, extract only the state dict from your checkpoint:\n\n"
                    + " import torch\n"
                    + " checkpoint = torch.load('original.pth', weights_only=False)\n"
                    + " # Inspect keys, e.g.: checkpoint.keys()"
                    + " -> dict_keys(['model', 'optimizer', ...])\n"
                    + " torch.save(checkpoint['model'], 'weights.pt')\n\n"
                    + "Then reference 'weights.pt' in your bioimageio.yaml."
                )
                raise ValueError(msg) from e

    incompatible = model.load_state_dict(state)
    # torch returns a named tuple (a tuple subclass) with `missing_keys` and
    # `unexpected_keys`; stay defensive in case of an unexpected return value.
    if (
        isinstance(incompatible, tuple)
        and hasattr(incompatible, "missing_keys")
        and hasattr(incompatible, "unexpected_keys")
    ):
        if incompatible.missing_keys:
            logger.warning("Missing state dict keys: {}", incompatible.missing_keys)

        # fix: dropped a redundant second `hasattr(..., "unexpected_keys")` check;
        # the enclosing condition already guarantees the attribute exists
        if incompatible.unexpected_keys:
            logger.warning(
                "Unexpected state dict keys: {}", incompatible.unexpected_keys
            )
    else:
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            (s[:20] + "..." if len(s := str(incompatible)) > 20 else s),
        )

    return model

197 

198 

def get_devices(
    devices: Optional[Sequence[Union[torch.device, str]]] = None,
) -> List[torch.device]:
    """Normalize *devices* to a single-element list of `torch.device`.

    Without an explicit request, the best available backend is picked
    (CUDA, then MPS, then CPU). Any devices beyond the first are dropped
    with a warning, since multi-device execution is not implemented.
    """
    if devices:
        selected = [torch.device(d) for d in devices]
    elif torch.cuda.is_available():
        selected = [torch.device("cuda")]
    elif torch.backends.mps.is_available():
        selected = [torch.device("mps")]
    else:
        selected = [torch.device("cpu")]

    if len(selected) > 1:
        warnings.warn(
            f"Multiple devices for pytorch model not yet implemented; ignoring {selected[1:]}"
        )
        selected = selected[:1]

    return selected