Coverage for bioimageio/core/backends/pytorch_backend.py: 81%
100 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1import gc
2import warnings
3from contextlib import nullcontext
4from io import BytesIO, TextIOWrapper
5from pathlib import Path
6from typing import Any, List, Literal, Optional, Sequence, Union
8import torch
9from loguru import logger
10from numpy.typing import NDArray
11from torch import nn
12from typing_extensions import assert_never
14from bioimageio.spec._internal.version_type import Version
15from bioimageio.spec.common import BytesReader, ZipPath
16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
17from bioimageio.spec.utils import download
19from ..digest_spec import import_callable
20from ..utils._type_guards import is_list, is_ndarray, is_tuple
21from ._model_adapter import ModelAdapter
class PytorchModelAdapter(ModelAdapter):
    """Model adapter backed by a torch module built from `pytorch_state_dict` weights.

    The module is created (and its state loaded) via `load_torch_model` and is
    placed on the first device returned by `get_devices`.
    """

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """Build the torch module described by `model_description`.

        Args:
            model_description: Model description whose
                `weights.pytorch_state_dict` entry is used.
            devices: Devices to run on; resolved with `get_devices`.
            mode: "eval" puts the module into evaluation mode, "train" into
                training mode.

        Raises:
            ValueError: If the description has no `pytorch_state_dict` weights.
        """
        super().__init__(model_description=model_description)
        state_dict_weights = model_description.weights.pytorch_state_dict
        if state_dict_weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        torch_devices = get_devices(devices)
        self._model = load_torch_model(
            state_dict_weights, load_state=True, devices=torch_devices
        )
        if mode == "train":
            self._model = self._model.train()
        elif mode == "eval":
            self._model = self._model.eval()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        self._primary_device = torch_devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the wrapped module on `input_arrays` and return numpy outputs.

        `None` inputs are forwarded as `None`; all other inputs are converted
        to tensors on the primary device. In "eval" mode the call is made
        under `torch.no_grad`.

        Raises:
            TypeError: If an output is neither `None`, a tensor nor an ndarray.
        """
        model_inputs: List[Optional[torch.Tensor]] = []
        for array in input_arrays:
            if array is None:
                model_inputs.append(None)
            else:
                model_inputs.append(
                    torch.from_numpy(array).to(self._primary_device)
                )

        # gradient tracking is only kept when the adapter is in "train" mode
        if self._mode == "train":
            grad_context = nullcontext
        elif self._mode == "eval":
            grad_context = torch.no_grad
        else:
            assert_never(self._mode)

        with grad_context():
            raw_output = self._model(*model_inputs)

        # a single (non tuple/list) output is treated as a one-element sequence
        if is_tuple(raw_output) or is_list(raw_output):
            outputs = raw_output
        else:
            outputs = [raw_output]

        converted: List[Optional[NDArray[Any]]] = []
        for index, output in enumerate(outputs):
            if output is None:
                converted.append(None)
                continue

            if isinstance(output, torch.Tensor):
                as_numpy: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
                    output.detach().cpu().numpy()
                )
                converted.append(as_numpy)
                continue

            if is_ndarray(output):
                converted.append(output)
                continue

            raise TypeError(
                f"Model output[{index}] has unexpected type {type(output)}."
            )

        return converted

    def unload(self) -> None:
        """Drop the model reference and ask Python and torch to free memory."""
        del self._model
        # deallocate memory
        _ = gc.collect()
        assert torch is not None
        # release reserved memory
        torch.cuda.empty_cache()
def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Instantiate the torch module described by `weight_spec`.

    Args:
        weight_spec: `pytorch_state_dict` weights description (v0_4 or v0_5).
        load_state: If true, the weights are fetched with `download` and
            loaded into the module via `load_torch_state_dict`.
        devices: Requested devices, resolved with `get_devices`. The module
            is only moved to a device if `load_state` is true or `devices`
            is non-empty.

    Returns:
        The instantiated module.

    Raises:
        ValueError: If the imported callable does not return an `nn.Module`.
    """
    is_v0_4 = isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)

    # NOTE(review): for non-v0_4 specs the hash forwarded here is
    # `weight_spec.sha256`, not a field of `weight_spec.architecture` --
    # confirm this is the value `import_callable` expects.
    if is_v0_4:
        architecture_hash = weight_spec.architecture_sha256
    else:
        architecture_hash = weight_spec.sha256

    architecture_factory = import_callable(
        weight_spec.architecture, sha256=architecture_hash
    )

    if is_v0_4:
        architecture_kwargs = weight_spec.kwargs
    else:
        architecture_kwargs = weight_spec.architecture.kwargs

    torch_model = architecture_factory(**architecture_kwargs)

    if not isinstance(torch_model, nn.Module):
        architecture = weight_spec.architecture
        # v0_4 callables and other architecture descriptions name the
        # callable under different attributes
        callable_name = (
            architecture.callable_name
            if isinstance(
                architecture,
                (v0_4.CallableFromFile, v0_4.CallableFromDepencency),
            )
            else architecture.callable
        )
        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    if not (load_state or devices):
        # nothing to load and no device requested: return the module as built
        return torch_model

    target_devices = get_devices(devices)
    torch_model = torch_model.to(target_devices[0])
    if not load_state:
        return torch_model

    return load_torch_state_dict(
        torch_model,
        path=download(weight_spec),
        devices=target_devices,
    )
def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath, BytesReader],
    devices: Sequence[torch.device],
) -> nn.Module:
    """Load a state dict from `path` into `model` on the first of `devices`.

    Args:
        model: Module to receive the state; it is moved to `devices[0]`.
        path: Location of the serialized state dict. `Path`/`ZipPath` inputs
            are opened in binary mode; any other input is read fully into an
            in-memory buffer.
        devices: Target devices; only the first entry is used.

    Returns:
        The module after `load_state_dict` was applied.
    """
    target_device = devices[0]
    model = model.to(target_device)

    if isinstance(path, (Path, ZipPath)):
        source = path.open("rb")
    else:
        source = nullcontext(BytesIO(path.read()))

    with source as stream:
        assert not isinstance(stream, TextIOWrapper)
        # `weights_only` is only passed for torch versions from 1.13 on
        legacy_torch = Version(str(torch.__version__)) < Version("1.13")
        state = (
            torch.load(stream, map_location=target_device)
            if legacy_torch
            else torch.load(stream, map_location=target_device, weights_only=True)
        )

    load_result = model.load_state_dict(state)
    has_expected_shape = (
        isinstance(load_result, tuple)
        and hasattr(load_result, "missing_keys")
        and hasattr(load_result, "unexpected_keys")
    )
    if not has_expected_shape:
        # only log an abbreviated representation of the unexpected return value
        summary = str(load_result)
        if len(summary) > 20:
            summary = summary[:20] + "..."

        logger.warning(
            "`model.load_state_dict()` unexpectedly returned: {} "
            + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
            summary,
        )
        return model

    if load_result.missing_keys:
        logger.warning("Missing state dict keys: {}", load_result.missing_keys)

    if hasattr(load_result, "unexpected_keys") and load_result.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", load_result.unexpected_keys)

    return model
def get_devices(
    devices: Optional[Sequence[Union[torch.device, str]]] = None,
) -> List[torch.device]:
    """Resolve `devices` to a single-element list of `torch.device`.

    Args:
        devices: Requested devices as `torch.device` objects or strings.
            If `None` or empty, "cuda" is chosen when
            `torch.cuda.is_available()` and "cpu" otherwise.

    Returns:
        A list holding exactly one device. If more than one device was
        requested, a warning is issued and only the first one is kept.
    """
    if devices:
        resolved = [torch.device(requested) for requested in devices]
    else:
        default_name = "cuda" if torch.cuda.is_available() else "cpu"
        resolved = [torch.device(default_name)]

    ignored = resolved[1:]
    if ignored:
        warnings.warn(
            f"Multiple devices for single pytorch model not yet implemented; ignoring {ignored}"
        )
        resolved = resolved[:1]

    return resolved