Coverage for bioimageio/core/backends/pytorch_backend.py: 84%
92 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import gc
2import warnings
3from contextlib import nullcontext
4from io import TextIOWrapper
5from pathlib import Path
6from typing import Any, List, Literal, Optional, Sequence, Union
8import torch
9from loguru import logger
10from numpy.typing import NDArray
11from torch import nn
12from typing_extensions import assert_never
14from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple
15from bioimageio.spec.common import ZipPath
16from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
17from bioimageio.spec.utils import download
19from ..digest_spec import import_callable
20from ._model_adapter import ModelAdapter
class PytorchModelAdapter(ModelAdapter):
    """Model adapter backed by the `pytorch_state_dict` weights of a model description.

    The network is instantiated and its state dict loaded on construction.
    All inference happens on a single device (the first resolved one).
    """

    def __init__(
        self,
        *,
        model_description: AnyModelDescr,
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
        mode: Literal["eval", "train"] = "eval",
    ):
        """Load the network for `model_description` and put it into `mode`.

        Args:
            model_description: Model description whose
                `weights.pytorch_state_dict` entry is used.
            devices: Requested devices, resolved via `get_devices`; only the
                first resolved device is used for inputs.
            mode: `"eval"` calls `.eval()` on the network and runs forward
                passes under `torch.no_grad`; `"train"` calls `.train()` and
                runs forward passes without a gradient-disabling context.

        Raises:
            ValueError: If the description has no `pytorch_state_dict` weights.
        """
        super().__init__(model_description=model_description)

        state_dict_weights = model_description.weights.pytorch_state_dict
        if state_dict_weights is None:
            raise ValueError("No `pytorch_state_dict` weights found")

        resolved_devices = get_devices(devices)
        self._model = load_torch_model(
            state_dict_weights, load_state=True, devices=resolved_devices
        )

        # `eval()`/`train()` return the module itself; re-assigning keeps the
        # attribute's static type intact.
        if mode == "train":
            self._model = self._model.train()
        elif mode == "eval":
            self._model = self._model.eval()
        else:
            assert_never(mode)

        self._mode: Literal["eval", "train"] = mode
        self._primary_device = resolved_devices[0]

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the network on `input_arrays` and return its outputs as numpy arrays.

        Args:
            input_arrays: One entry per model input. Arrays are converted with
                `torch.from_numpy` and moved to the primary device; `None`
                entries are passed to the network as `None`.

        Returns:
            One entry per model output. Tensors are detached, moved to the CPU
            and converted to numpy; numpy arrays and `None` pass through.

        Raises:
            TypeError: If an output is neither `None`, a tensor nor an ndarray.
        """

        def as_numpy(index: int, value: Any) -> Optional[NDArray[Any]]:
            # Convert a single model output; `index` is only used for the error.
            if value is None:
                return None
            if isinstance(value, torch.Tensor):
                converted: NDArray[Any] = value.detach().cpu().numpy()
                return converted
            if is_ndarray(value):
                return value
            raise TypeError(
                f"Model output[{index}] has unexpected type {type(value)}."
            )

        device = self._primary_device
        tensors = [
            torch.from_numpy(arr).to(device) if arr is not None else None
            for arr in input_arrays
        ]

        # Gradients are only tracked in "train" mode.
        if self._mode == "train":
            grad_context = nullcontext
        elif self._mode == "eval":
            grad_context = torch.no_grad
        else:
            assert_never(self._mode)

        with grad_context():
            model_out = self._model(*tensors)

        # A network with a single output may return it bare; wrap it so the
        # conversion below always sees a sequence.
        if not (is_tuple(model_out) or is_list(model_out)):
            model_out = [model_out]

        return [as_numpy(i, out) for i, out in enumerate(model_out)]

    def unload(self) -> None:
        """Drop the network and ask Python and CUDA to release its memory."""
        del self._model
        _ = gc.collect()  # deallocate memory
        # Always true while `torch` is imported at module level; presumably a
        # leftover from an optional-import variant -- TODO confirm and remove.
        assert torch is not None
        torch.cuda.empty_cache()  # release reserved memory
def load_torch_model(
    weight_spec: Union[
        v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
    ],
    *,
    load_state: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> nn.Module:
    """Instantiate the network described by `weight_spec`.

    Args:
        weight_spec: `pytorch_state_dict` weights entry (format 0.4 or 0.5)
            naming the architecture callable and its keyword arguments.
        load_state: If true, also load the state dict from the downloaded
            weights file.
        devices: Requested devices, resolved via `get_devices`. The network is
            moved to the first resolved device when `load_state` is true or
            `devices` is non-empty; otherwise it stays where the architecture
            callable created it.

    Returns:
        The instantiated (and optionally state-loaded) network.

    Raises:
        ValueError: If the architecture callable does not return an
            `nn.Module`.
    """
    is_v0_4 = isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)

    # NOTE(review): for 0.5 specs this forwards the weights entry's own
    # `sha256`, not an architecture-specific hash -- confirm this is what
    # `import_callable` expects.
    if is_v0_4:
        architecture_sha256 = weight_spec.architecture_sha256
    else:
        architecture_sha256 = weight_spec.sha256

    build_network = import_callable(
        weight_spec.architecture, sha256=architecture_sha256
    )

    # 0.4 keeps the constructor kwargs on the weights entry, 0.5 on the
    # architecture description.
    if is_v0_4:
        init_kwargs = weight_spec.kwargs
    else:
        init_kwargs = weight_spec.architecture.kwargs

    torch_model = build_network(**init_kwargs)

    if not isinstance(torch_model, nn.Module):
        architecture = weight_spec.architecture
        callable_name = (
            architecture.callable_name
            if isinstance(
                architecture,
                (v0_4.CallableFromFile, v0_4.CallableFromDepencency),
            )
            else architecture.callable
        )
        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

    if not (load_state or devices):
        # Nothing requested beyond instantiation.
        return torch_model

    use_devices = get_devices(devices)
    torch_model = torch_model.to(use_devices[0])
    if not load_state:
        return torch_model

    return load_torch_state_dict(
        torch_model,
        path=download(weight_spec).path,
        devices=use_devices,
    )
def load_torch_state_dict(
    model: nn.Module,
    path: Union[Path, ZipPath],
    devices: Sequence[torch.device],
) -> nn.Module:
    """Load the state dict stored at `path` into `model`.

    Args:
        model: Network to receive the state; it is first moved to `devices[0]`.
        path: Location of the serialized state dict (opened in binary mode).
        devices: Target devices; only the first entry is used, both for the
            network and as `map_location` when deserializing.

    Returns:
        The network on `devices[0]` with the state loaded. Missing or
        unexpected keys reported by `load_state_dict` are logged as warnings.
    """
    target_device = devices[0]
    model = model.to(target_device)

    with path.open("rb") as weights_file:
        # Opened with "rb", so this is never a text wrapper; the assert
        # narrows the type for static checkers.
        assert not isinstance(weights_file, TextIOWrapper)
        # `weights_only=True`: see the torch.load documentation for the
        # restrictions this places on what may be unpickled.
        state = torch.load(
            weights_file, map_location=target_device, weights_only=True
        )

    incompatible = model.load_state_dict(state)
    if incompatible is None:  # pyright: ignore[reportUnnecessaryComparison]
        return model

    if incompatible.missing_keys:
        logger.warning("Missing state dict keys: {}", incompatible.missing_keys)

    if incompatible.unexpected_keys:
        logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)

    return model
def get_devices(
    devices: Optional[
        Union[torch.device, str, Sequence[Union[torch.device, str]]]
    ] = None,
) -> List[torch.device]:
    """Resolve a device request to a one-element list of `torch.device`.

    Args:
        devices: Requested devices. Accepts a sequence of device strings or
            `torch.device` objects, or a single string / `torch.device`.
            Any falsy value (`None`, an empty sequence, an empty string)
            selects `"cuda"` when CUDA is available and `"cpu"` otherwise.

    Returns:
        A list holding exactly one device. If more than one device is
        requested, a warning is issued and all but the first are dropped.
    """
    if not devices:
        torch_devices = [
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        ]
    elif isinstance(devices, (str, torch.device)):
        # A bare string is itself a `Sequence[str]`; iterating it would hand
        # single characters to `torch.device`. Treat it as one device.
        torch_devices = [torch.device(devices)]
    else:
        torch_devices = [torch.device(d) for d in devices]

    if len(torch_devices) > 1:
        warnings.warn(
            f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
        )
        torch_devices = torch_devices[:1]

    return torch_devices