Coverage for src/bioimageio/core/backends/pytorch_backend.py: 75%
118 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import gc
2from abc import abstractmethod
3from contextlib import nullcontext
4from io import BytesIO, TextIOWrapper
5from pathlib import Path
6from typing import Any, List, Literal, Mapping, Optional, Sequence, Tuple, Union
8import torch
9from loguru import logger
10from numpy.typing import NDArray
11from torch import nn
12from typing_extensions import Protocol, Self, assert_never, runtime_checkable
14from bioimageio.spec._internal.version_type import Version
15from bioimageio.spec.common import BytesReader, ZipPath
16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
17from bioimageio.spec.utils import download
19from .._model_adapter import LocalModelAdapter
20from ..digest_spec import import_callable
21from ..utils._type_guards import is_list, is_ndarray, is_tuple
24@runtime_checkable
25class TorchNNModuleLike(Protocol):
26 @abstractmethod
27 def load_state_dict(
28 self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
29 ) -> Self: ...
31 @abstractmethod
32 def to(
33 self,
34 *,
35 device: Optional[torch.device] = None,
36 dtype: Optional[torch.dtype] = None,
37 non_blocking: bool = False,
38 ) -> Self: ...
40 @abstractmethod
41 def forward(
42 self, *input: torch.Tensor
43 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]: ...
45 def eval(self) -> Self:
46 """Set model to eval mode"""
47 return self
class PytorchModelAdapter(LocalModelAdapter[torch.device, nn.Module]):
    """Model adapter that runs a model described with ``pytorch_state_dict`` weights.

    ``mode`` decides whether loaded models are switched to eval or train mode.
    It also decides whether the forward pass runs under ``torch.no_grad`` (eval)
    or with autograd enabled (train).
    """

    def __init__(
        self,
        model_description: AnyModelDescr,
        mode: Literal["eval", "train"] = "eval",
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter.

        Raises:
            ValueError: if ``model_description`` has no ``pytorch_state_dict`` weights.
        """
        state_dict_weights = model_description.weights.pytorch_state_dict
        if state_dict_weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        self._weights = state_dict_weights
        self._mode: Literal["eval", "train"] = mode
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(
        self, devices: Optional[Sequence[str]]
    ) -> Sequence[torch.device]:
        """Turn the requested device names into torch devices (via ``get_devices``)."""
        return get_devices(devices)

    def _init_model_on_device(self, device: torch.device) -> nn.Module:
        """Build the model with its weights loaded onto ``device`` and set its mode."""
        loaded = load_torch_model(self._weights, load_state=True, devices=[device])

        if self._mode == "train":
            return loaded.train()
        if self._mode == "eval":
            return loaded.eval()

        assert_never(self._mode)

    def _forward_impl(
        self,
        device: torch.device,
        model: nn.Module,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> List[Optional[NDArray[Any]]]:
        """Run ``model`` on ``device`` and return its outputs as numpy arrays.

        ``None`` inputs are passed to the model as ``None``, and ``None`` outputs
        are passed through unchanged. An output that is neither a tuple nor a
        list is treated as a single-element output sequence.

        Raises:
            TypeError: if an output is not ``None``, a ``torch.Tensor`` or a numpy array.
        """
        inputs = [
            torch.from_numpy(arr).to(device) if arr is not None else None
            for arr in input_arrays
        ]

        # eval: disable gradient tracking; train: keep autograd active
        if self._mode == "eval":
            grad_ctxt = torch.no_grad
        elif self._mode == "train":
            grad_ctxt = nullcontext
        else:
            assert_never(self._mode)

        with grad_ctxt():
            raw_output = model(*inputs)

        outputs = (
            raw_output if is_tuple(raw_output) or is_list(raw_output) else [raw_output]
        )

        converted: List[Optional[NDArray[Any]]] = []
        for idx, out in enumerate(outputs):
            if isinstance(out, torch.Tensor):
                # detach from the graph and copy to host memory before converting
                as_numpy: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    out.detach().cpu().numpy()
                )
                converted.append(as_numpy)
            elif out is None or is_ndarray(out):
                converted.append(out)
            else:
                raise TypeError(f"Model output[{idx}] has unexpected type {type(out)}.")

        return converted

    def _cleanup_pre_model_deletion(
        self, device: torch.device, model: nn.Module
    ) -> None:
        """Nothing needs releasing before the model is deleted."""
        return

    def _cleanup_post_model_deletion(self, device: torch.device) -> None:
        """Reclaim memory once the model has been deleted."""
        _ = gc.collect()  # deallocate memory of now-unreferenced objects
        if device.type != "cuda":
            return
        torch.cuda.empty_cache()  # release memory reserved by the CUDA caching allocator
def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Instantiate the architecture described by ``weight_spec``.

    Args:
        weight_spec: ``pytorch_state_dict`` weights description (format 0.4 or 0.5).
        load_state: if True, fetch the weights file via ``download`` and load it
            into the model with ``load_torch_state_dict``.
        devices: target devices. They are resolved with ``get_devices`` when
            given, or when ``load_state`` is True. The model is then moved to
            the first resolved device.

    Returns:
        The instantiated ``torch.nn.Module``.

    Raises:
        ValueError: if the architecture callable does not return a
            ``torch.nn.Module``, or if loading the state dict fails in
            ``load_torch_state_dict``.
    """
    legacy_format = isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)

    if legacy_format:
        arch_sha256 = weight_spec.architecture_sha256
    else:
        # NOTE(review): for format 0.5 this forwards the weights entry's `sha256`;
        # confirm this is the hash `import_callable` expects here.
        arch_sha256 = weight_spec.sha256
    custom_callable = import_callable(weight_spec.architecture, sha256=arch_sha256)

    model_kwargs = (
        weight_spec.kwargs if legacy_format else weight_spec.architecture.kwargs
    )
    torch_model = custom_callable(**model_kwargs)

    if not isinstance(torch_model, nn.Module):
        arch = weight_spec.architecture
        callable_name = (
            arch.callable_name
            if isinstance(arch, (v0_4.CallableFromFile, v0_4.CallableFromDepencency))
            else arch.callable
        )
        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    # nothing to move or load
    if not (load_state or devices):
        return torch_model

    target_devices = get_devices(devices)
    torch_model = torch_model.to(target_devices[0])
    if load_state:
        torch_model = load_torch_state_dict(
            torch_model,
            path=download(weight_spec),
            devices=target_devices,
            strict=(
                weight_spec.strict
                if isinstance(weight_spec, v0_5.PytorchStateDictWeightsDescr)
                else True
            ),
        )
    return torch_model
def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
    strict: bool = True,
) -> nn.Module:
    """Load a serialized state dict into ``model``.

    Args:
        model: the module to populate. It is first moved to ``devices[0]``.
        path: a ``Path``/``ZipPath`` (opened in binary mode), or a reader whose
            whole content is read into memory.
        devices: only ``devices[0]`` is used, as the move target and as the
            ``torch.load`` map location.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        ``model`` after moving it and loading the state.

    Raises:
        ValueError: on torch >= 1.13, if ``torch.load(..., weights_only=True)``
            fails. The message explains how to extract a pure state dict.

    Missing or unexpected keys reported by ``load_state_dict`` are logged as
    warnings.
    """
    target_device = devices[0]
    model = model.to(target_device)

    opened = (
        path.open("rb")
        if isinstance(path, (Path, ZipPath))
        else nullcontext(BytesIO(path.read()))
    )

    with opened as stream:
        # narrows the stream type for static checkers; it must be a binary stream
        assert not isinstance(stream, TextIOWrapper)
        if Version(str(torch.__version__)) < Version("1.13"):
            # `weights_only` is not available on these torch versions
            state = torch.load(stream, map_location=target_device)
        else:
            try:
                state = torch.load(
                    stream, map_location=target_device, weights_only=True
                )
            except Exception as e:
                msg = (
                    f"Failed to load weights with `weights_only=True`: {e}\n\n"
                    "This usually means the weights file contains non-tensor objects"
                    " (e.g. numpy arrays, custom classes, or nested dicts with"
                    " metadata). The BioImage.IO spec requires a pure state dict —"
                    " an OrderedDict mapping parameter names to tensors only.\n\n"
                    "To fix this, extract only the state dict from your checkpoint:\n\n"
                    " import torch\n"
                    " checkpoint = torch.load('original.pth', weights_only=False)\n"
                    " # Inspect keys, e.g.: checkpoint.keys()"
                    " -> dict_keys(['model', 'optimizer', ...])\n"
                    " torch.save(checkpoint['model'], 'weights.pt')\n\n"
                    "Then reference 'weights.pt' in your bioimageio.yaml."
                )
                raise ValueError(msg) from e

    key_report = model.load_state_dict(state, strict=strict)

    is_key_report = isinstance(key_report, tuple) and all(
        hasattr(key_report, attr) for attr in ("missing_keys", "unexpected_keys")
    )
    if not is_key_report:
        report_text = str(key_report)
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            report_text[:20] + "..." if len(report_text) > 20 else report_text,
        )
        return model

    if key_report.missing_keys:
        logger.warning("Missing state dict keys: {}", key_report.missing_keys)
    if key_report.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", key_report.unexpected_keys)

    return model
def get_devices(
    devices: Optional[Sequence[Union[torch.device, str]]] = None,
) -> List[torch.device]:
    """Resolve the torch devices to run on.

    Args:
        devices: explicit devices, as ``torch.device`` objects or device strings
            such as ``"cpu"`` or ``"cuda:0"``. If ``None`` or empty, devices
            are auto-detected.

    Returns:
        If ``devices`` is given, those devices converted to ``torch.device``.
        Otherwise the first match of:

        - all visible CUDA devices;
        - the ``mps`` device;
        - the current ``torch.accelerator`` device;
        - CPU.
    """
    if devices:
        return [torch.device(d) for d in devices]

    if torch.cuda.is_available():
        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]

    # Older torch releases have no `torch.backends.mps` submodule; probe for it
    # instead of letting an AttributeError escape from device auto-detection.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return [torch.device("mps")]

    # `torch.accelerator` only exists in recent torch releases. Treat any failure
    # while probing it as "no accelerator" and fall back to the CPU.
    try:
        accelerator = (
            torch.accelerator.current_accelerator()
            if torch.accelerator.is_available()
            else None
        )
    except Exception:
        accelerator = None

    return [accelerator if accelerator is not None else torch.device("cpu")]