Coverage for src / bioimageio / core / backends / pytorch_backend.py: 77%
108 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
1import gc
2import warnings
3from contextlib import nullcontext
4from io import BytesIO, TextIOWrapper
5from pathlib import Path
6from typing import Any, List, Literal, Optional, Sequence, Union
8import torch
9from loguru import logger
10from numpy.typing import NDArray
11from torch import nn
12from typing_extensions import assert_never
14from bioimageio.spec._internal.version_type import Version
15from bioimageio.spec.common import BytesReader, ZipPath
16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
17from bioimageio.spec.utils import download
19from ..digest_spec import import_callable
20from ..utils._type_guards import is_list, is_ndarray, is_tuple
21from ._model_adapter import ModelAdapter
class PytorchModelAdapter(ModelAdapter):
    """Model adapter running the `pytorch_state_dict` weights of a model description.

    The torch module is built from the weights' architecture description, its
    state dict is loaded onto a single device and the module is switched to
    `eval` or `train` mode.
    """

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """
        Args:
            model_description: model description; must provide
                `weights.pytorch_state_dict`.
            devices: devices to run on. If None/empty a default device is
                selected by `get_devices`; only the first device is used.
            mode: `"eval"` runs forward passes under `torch.no_grad`,
                `"train"` runs them without a grad-disabling context.

        Raises:
            ValueError: if no `pytorch_state_dict` weights are described.
        """
        super().__init__(model_description=model_description)
        weights = model_description.weights.pytorch_state_dict
        if weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        devices = get_devices(devices)
        self._model = load_torch_model(weights, load_state=True, devices=devices)
        if mode == "eval":
            self._model = self._model.eval()
        elif mode == "train":
            self._model = self._model.train()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        # inputs are moved to this device before each forward pass
        self._primary_device = devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the torch model on numpy inputs and return numpy outputs.

        `None` inputs are forwarded to the model as `None`. A model output that
        is neither a tuple nor a list is treated as a single output. Tensor
        outputs are detached and copied to CPU before conversion; numpy arrays
        and `None` outputs are passed through unchanged.

        Raises:
            TypeError: if an output element is not `None`, a `torch.Tensor`
                or a numpy array.
        """
        tensors = [
            None if a is None else torch.from_numpy(a).to(self._primary_device)
            for a in input_arrays
        ]

        if self._mode == "eval":
            ctxt = torch.no_grad
        elif self._mode == "train":
            ctxt = nullcontext
        else:
            assert_never(self._mode)

        with ctxt():
            model_out = self._model(*tensors)

        # normalize a single output to a one-element sequence
        if is_tuple(model_out) or is_list(model_out):
            model_out_seq = model_out
        else:
            model_out_seq = [model_out]

        result: List[Optional[NDArray[Any]]] = []
        for i, r in enumerate(model_out_seq):
            if r is None:
                result.append(None)
            elif isinstance(r, torch.Tensor):
                r_np: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    r.detach().cpu().numpy()
                )
                result.append(r_np)
            elif is_ndarray(r):
                result.append(r)
            else:
                raise TypeError(f"Model output[{i}] has unexpected type {type(r)}.")

        return result

    def unload(self) -> None:
        """Drop the torch model and release cached CUDA memory.

        Calling this more than once is harmless: the model reference is only
        deleted if it is still present.
        """
        # guard so a repeated `unload()` does not raise AttributeError
        if hasattr(self, "_model"):
            del self._model

        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory
def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Build the torch module described by a `pytorch_state_dict` weights entry.

    The architecture callable is imported via `import_callable` and called with
    the architecture kwargs of the description.

    Args:
        weight_spec: `pytorch_state_dict` weights description (format 0.4 or 0.5).
        load_state: if True, download the weights file and load it as state dict.
        devices: target devices; only the first one is used. The module is only
            moved to a device when `load_state` is True or `devices` is non-empty.

    Returns:
        The instantiated (and, if requested, weight-loaded) module.

    Raises:
        ValueError: if calling the architecture does not return an `nn.Module`.
    """
    if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr):
        # format 0.4 keeps the architecture hash and kwargs on the weights entry
        arch_sha256 = weight_spec.architecture_sha256
        arch_kwargs = weight_spec.kwargs
    else:
        # NOTE(review): this passes the hash of the weights entry, not of the
        # architecture source; presumably `import_callable` ignores it for
        # format 0.5 architecture nodes — confirm.
        arch_sha256 = weight_spec.sha256
        arch_kwargs = weight_spec.architecture.kwargs

    model_factory = import_callable(weight_spec.architecture, sha256=arch_sha256)
    torch_model = model_factory(**arch_kwargs)

    if not isinstance(torch_model, nn.Module):
        architecture = weight_spec.architecture
        callable_name = (
            architecture.callable_name
            if isinstance(
                architecture, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)
            )
            else architecture.callable
        )
        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    if not (load_state or devices):
        # nothing to move or load
        return torch_model

    use_devices = get_devices(devices)
    torch_model = torch_model.to(use_devices[0])
    if load_state:
        torch_model = load_torch_state_dict(
            torch_model, path=download(weight_spec), devices=use_devices
        )

    return torch_model
def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
) -> nn.Module:
    """Load a state dict file into `model` on the first of `devices`.

    Args:
        model: module to populate; it is moved to `devices[0]` first.
        path: the weights file as a filesystem path or zip-archive path (opened
            in binary mode), or a byte reader whose content is read fully into
            memory.
        devices: target devices; only `devices[0]` is used (also as
            `map_location` for `torch.load`).

    Returns:
        The module returned by `model.to(devices[0])` with the state loaded.

    Raises:
        ValueError: on torch >= 1.13, if `torch.load(..., weights_only=True)`
            fails; the message explains how to extract a plain state dict.
    """
    target = devices[0]
    model = model.to(target)

    if isinstance(path, (Path, ZipPath)):
        source = path.open("rb")
    else:
        # wrap the in-memory copy so both branches can be used with `with`
        source = nullcontext(BytesIO(path.read()))

    with source as f:
        assert not isinstance(f, TextIOWrapper)
        legacy_torch = Version(str(torch.__version__)) < Version("1.13")
        if legacy_torch:
            # `weights_only` is not passed on torch < 1.13
            state = torch.load(f, map_location=target)
        else:
            try:
                state = torch.load(f, map_location=target, weights_only=True)
            except Exception as e:
                raise ValueError(
                    f"Failed to load weights with `weights_only=True`: {e}\n\n"
                    "This usually means the weights file contains non-tensor objects"
                    " (e.g. numpy arrays, custom classes, or nested dicts with"
                    " metadata). The BioImage.IO spec requires a pure state dict —"
                    " an OrderedDict mapping parameter names to tensors only.\n\n"
                    "To fix this, extract only the state dict from your checkpoint:\n\n"
                    " import torch\n"
                    " checkpoint = torch.load('original.pth', weights_only=False)\n"
                    " # Inspect keys, e.g.: checkpoint.keys()"
                    " -> dict_keys(['model', 'optimizer', ...])\n"
                    " torch.save(checkpoint['model'], 'weights.pt')\n\n"
                    "Then reference 'weights.pt' in your bioimageio.yaml."
                ) from e

    incompatible = model.load_state_dict(state)
    reports_keys = (
        isinstance(incompatible, tuple)
        and hasattr(incompatible, "missing_keys")
        and hasattr(incompatible, "unexpected_keys")
    )
    if not reports_keys:
        # unexpected return value: log a truncated (max 20 chars) representation
        text = str(incompatible)
        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            text if len(text) <= 20 else text[:20] + "...",
        )
        return model

    if incompatible.missing_keys:
        logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
    if incompatible.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)

    return model
def get_devices(
    devices: Optional[Sequence[Union[torch.device, str]]] = None,
) -> List[torch.device]:
    """Resolve the torch device(s) a model should run on.

    Args:
        devices: explicit devices (`torch.device` instances or device strings).
            If None or empty, one default device is chosen: CUDA if available,
            else MPS if available, else CPU.

    Returns:
        A single-element list. If more than one device was requested, only the
        first is kept and a warning is emitted for the rest.
    """
    if not devices:
        # `torch.backends.mps` only exists from torch 1.12 on, while this
        # module still supports older torch versions (see the torch < 1.13
        # branch when loading state dicts) -> probe it defensively.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            torch_devices = [torch.device("cuda")]
        elif mps_backend is not None and mps_backend.is_available():
            torch_devices = [torch.device("mps")]
        else:
            torch_devices = [torch.device("cpu")]
    else:
        torch_devices = [torch.device(d) for d in devices]

    if len(torch_devices) > 1:
        warnings.warn(
            f"Multiple devices for pytorch model not yet implemented; ignoring {torch_devices[1:]}"
        )
        torch_devices = torch_devices[:1]

    return torch_devices