Coverage for src / bioimageio / core / backends / pytorch_backend.py: 78%
113 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import gc
2import warnings
3from abc import abstractmethod
4from contextlib import nullcontext
5from io import BytesIO, TextIOWrapper
6from pathlib import Path
7from typing import Any, List, Literal, Mapping, Optional, Sequence, Tuple, Union
9import torch
10from loguru import logger
11from numpy.typing import NDArray
12from torch import nn
13from typing_extensions import Protocol, Self, assert_never, runtime_checkable
15from bioimageio.spec._internal.version_type import Version
16from bioimageio.spec.common import BytesReader, ZipPath
17from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
18from bioimageio.spec.utils import download
20from ..digest_spec import import_callable
21from ..utils._type_guards import is_list, is_ndarray, is_tuple
22from ._model_adapter import ModelAdapter
@runtime_checkable
class TorchNNModuleLike(Protocol):
    """Structural type describing the `torch.nn.Module` methods used by this module.

    The protocol is `runtime_checkable`, so `isinstance(obj, TorchNNModuleLike)`
    only verifies that the methods below exist on `obj`; their signatures are not
    checked at runtime.
    """

    @abstractmethod
    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ) -> Self:
        """Load the entries of `state_dict` into this object."""
        ...

    @abstractmethod
    def to(
        self,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> Self:
        """Move and/or cast this object to the given `device` / `dtype`."""
        ...

    @abstractmethod
    def forward(
        self, *input: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]:
        """Compute the output tensor(s) for the given input tensor(s)."""
        ...

    def eval(self) -> Self:
        """Switch to evaluation mode; this default implementation returns `self` unchanged."""
        return self
class PytorchModelAdapter(ModelAdapter):
    """Model adapter that runs `pytorch_state_dict` weights with PyTorch."""

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """Build the network, load its state dict and put it into `mode`.

        Args:
            model_description: Model description whose
                `weights.pytorch_state_dict` entry is loaded.
            devices: Requested devices, resolved via `get_devices`. The first
                resolved device is used for the model and all inputs.
            mode: `"eval"` calls `.eval()` on the model and runs forward passes
                under `torch.no_grad`; `"train"` calls `.train()` and keeps
                gradient tracking enabled.

        Raises:
            ValueError: If `model_description` has no `pytorch_state_dict` weights.
        """
        super().__init__(model_description=model_description)
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        devices = get_devices(devices)
        self._model = load_torch_model(weights, load_state=True, devices=devices)
        if mode == "eval":
            self._model = self._model.eval()
        elif mode == "train":
            self._model = self._model.train()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        self._primary_device = devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the model on `input_arrays` and return its outputs as numpy arrays.

        Args:
            input_arrays: One entry per model input. Arrays are converted to
                tensors on the primary device; `None` entries are passed to the
                model as `None`.

        Returns:
            One entry per model output. A single (non tuple/list) model output is
            wrapped in a one-element list. Tensors are detached, moved to the CPU
            and converted to numpy; `None` and numpy outputs are passed through.

        Raises:
            TypeError: If an output is neither `None`, a tensor nor a numpy array.
        """
        tensors = [
            None if a is None else torch.from_numpy(a).to(self._primary_device)
            for a in input_arrays
        ]

        # Gradients are only tracked in "train" mode.
        if self._mode == "eval":
            ctxt = torch.no_grad
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = self._model(*tensors)

        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                r_np: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    r.detach().cpu().numpy()
                )
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def unload(self) -> None:
        """Drop the model reference and free the memory it occupied.

        Safe to call repeatedly: once the model is gone, further calls only
        repeat the garbage collection and cache clearing.
        """
        # A plain `del self._model` would raise AttributeError on a second call.
        if hasattr(self, "_model"):
            del self._model

        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory
def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Instantiate the architecture described by `weight_spec`.

    Args:
        weight_spec: Description of the pytorch state dict weights
            (format 0.4 or 0.5).
        load_state: If true, download the weights and load them into the model.
        devices: Requested devices. If given (or if `load_state` is true) the
            model is moved to the first device returned by `get_devices`.

    Returns:
        The instantiated (and optionally moved / loaded) model.

    Raises:
        ValueError: If calling the architecture callable does not return a
            `torch.nn.Module`.
    """
    architecture = weight_spec.architecture

    # Format 0.4 and 0.5 store the hash and the constructor kwargs in different places.
    if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr):
        expected_sha256 = weight_spec.architecture_sha256
        model_kwargs = weight_spec.kwargs
    else:
        expected_sha256 = weight_spec.sha256
        model_kwargs = weight_spec.architecture.kwargs

    custom_callable = import_callable(architecture, sha256=expected_sha256)
    torch_model = custom_callable(**model_kwargs)

    if not isinstance(torch_model, nn.Module):
        callable_name = (
            architecture.callable_name
            if isinstance(
                architecture, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)
            )
            else architecture.callable
        )
        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    # Nothing to move or load: return the freshly constructed model as is.
    if not (load_state or devices):
        return torch_model

    use_devices = get_devices(devices)
    torch_model = torch_model.to(use_devices[0])
    if not load_state:
        return torch_model

    weights_source = download(weight_spec)
    # Only format 0.5 weight descriptions carry a `strict` setting.
    if isinstance(weight_spec, v0_5.PytorchStateDictWeightsDescr):
        strict = weight_spec.strict
    else:
        strict = True

    return load_torch_state_dict(
        torch_model,
        path=weights_source,
        devices=use_devices,
        strict=strict,
    )
def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
    strict: bool = True,
) -> nn.Module:
    """Load the state dict stored at `path` into `model`.

    Args:
        model: Model to load into; it is first moved to `devices[0]`.
        path: Source of the serialized state dict. Paths are opened in binary
            mode; any other reader is read fully into memory.
        devices: Target devices; only `devices[0]` is used (also as
            `map_location` for `torch.load`).
        strict: Forwarded to `model.load_state_dict`.

    Returns:
        The model (on `devices[0]`) after loading the state dict.

    Raises:
        ValueError: If `torch.load(..., weights_only=True)` fails
            (torch 1.13 or newer only).
    """
    target_device = devices[0]
    model = model.to(target_device)

    stream_cm = (
        path.open("rb")
        if isinstance(path, (Path, ZipPath))
        else nullcontext(BytesIO(path.read()))
    )

    with stream_cm as stream:
        assert not isinstance(stream, TextIOWrapper)
        if Version(str(torch.__version__)) < Version("1.13"):
            # `weights_only` is not passed for these older torch versions
            state_dict = torch.load(stream, map_location=target_device)
        else:
            try:
                state_dict = torch.load(
                    stream, map_location=target_device, weights_only=True
                )
            except Exception as e:
                msg = (
                    f"Failed to load weights with `weights_only=True`: {e}\n\n"
                    "This usually means the weights file contains non-tensor objects"
                    " (e.g. numpy arrays, custom classes, or nested dicts with"
                    " metadata). The BioImage.IO spec requires a pure state dict —"
                    " an OrderedDict mapping parameter names to tensors only.\n\n"
                    "To fix this, extract only the state dict from your checkpoint:\n\n"
                    " import torch\n"
                    " checkpoint = torch.load('original.pth', weights_only=False)\n"
                    " # Inspect keys, e.g.: checkpoint.keys()"
                    " -> dict_keys(['model', 'optimizer', ...])\n"
                    " torch.save(checkpoint['model'], 'weights.pt')\n\n"
                    "Then reference 'weights.pt' in your bioimageio.yaml."
                )
                raise ValueError(msg) from e

    load_result = model.load_state_dict(state_dict, strict=strict)

    looks_like_incompatible_keys = (
        isinstance(load_result, tuple)
        and hasattr(load_result, "missing_keys")
        and hasattr(load_result, "unexpected_keys")
    )
    if not looks_like_incompatible_keys:
        # Log a shortened representation (at most 20 characters plus "...").
        text = str(load_result)
        shown = text[:20] + "..." if len(text) > 20 else text
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            shown,
        )
        return model

    if load_result.missing_keys:
        logger.warning("Missing state dict keys: {}", load_result.missing_keys)

    if load_result.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", load_result.unexpected_keys)

    return model
def get_devices(
    devices: Optional[
        Union[torch.device, str, Sequence[Union[torch.device, str]]]
    ] = None,
) -> List[torch.device]:
    """Resolve `devices` to a list holding exactly one `torch.device`.

    Args:
        devices: Requested device(s). A single device name or `torch.device` is
            treated like a one-element sequence. If `None` or empty, a default
            is chosen: "cuda" if available, else "mps" if available, else "cpu".

    Returns:
        A one-element list with the device to use. Using several devices is not
        implemented; additional entries are dropped with a warning.
    """
    # A bare string is itself a sequence (of characters) and a bare
    # `torch.device` is not iterable; treat both as a single requested device.
    if isinstance(devices, torch.device) or (isinstance(devices, str) and devices):
        devices = [devices]

    if devices:
        torch_devices = [torch.device(d) for d in devices]
    elif torch.cuda.is_available():
        torch_devices = [torch.device("cuda")]
    else:
        # `torch.backends.mps` does not exist in older torch releases.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            torch_devices = [torch.device("mps")]
        else:
            torch_devices = [torch.device("cpu")]

    if len(torch_devices) > 1:
        warnings.warn(
            f"Multiple devices for pytorch model not yet implemented; ignoring {torch_devices[1:]}"
        )
        torch_devices = torch_devices[:1]

    return torch_devices