Coverage for bioimageio/core/model_adapters/_pytorch_model_adapter.py: 84%
67 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import gc
2import warnings
3from typing import Any, List, Optional, Sequence, Tuple, Union
5from bioimageio.spec.model import v0_4, v0_5
6from bioimageio.spec.utils import download
8from ..axis import AxisId
9from ..digest_spec import get_axes_infos, import_callable
10from ..tensor import Tensor
11from ._model_adapter import ModelAdapter
# torch is an optional dependency. When it cannot be imported, keep the
# reason so that code needing torch can report it later. `Exception` is
# caught (not only ImportError) because importing torch can also fail for
# other reasons, e.g. broken native libraries.
torch_error = None
try:
    import torch
except Exception as import_exc:
    torch = None
    torch_error = str(import_exc)
class PytorchModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimageio model from PyTorch state-dict weights.

    The network architecture is instantiated from the weights description, moved
    to a single torch device, loaded with the state dict from the weights file,
    and put into evaluation mode.
    """

    def __init__(
        self,
        *,
        outputs: Union[
            Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr]
        ],
        weights: Union[
            v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
        ],
        devices: Optional[Sequence[str]] = None,
    ):
        """Build the network, place it on a device and load its weights.

        Args:
            outputs: Output tensor descriptions. For each one, the axis ids
                (from `get_axes_infos`) become the ``dims`` of the corresponding
                tensor returned by `forward`.
            weights: PyTorch state-dict weights description. It provides the
                architecture to instantiate and the weights file to load.
            devices: Torch device strings. Only the first one is used (see
                `get_devices`). By default, CUDA is used if available,
                otherwise CPU.

        Raises:
            ImportError: If torch could not be imported.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        super().__init__()
        self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs]
        self._network = self.get_network(weights)
        self._devices = self.get_devices(devices)
        self._network = self._network.to(self._devices[0])

        self._primary_device = self._devices[0]
        # NOTE(security): `torch.load` unpickles the weights file, which may be
        # downloaded from an external source; a malicious file can execute
        # arbitrary code. Consider passing `weights_only=True` (supported since
        # torch 1.13; recent torch releases make it the default) once the
        # minimum supported torch version allows it.
        state: Any = torch.load(
            download(weights).path,
            map_location=self._primary_device,  # pyright: ignore[reportUnknownArgumentType]
        )
        self._network.load_state_dict(state)

        # Switch layers such as dropout / batch-norm to inference behavior.
        self._network = self._network.eval()

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the network on ``input_tensors`` without gradient tracking.

        Each non-``None`` input's underlying array (``ipt.data.data``, which
        must be a numpy array for `torch.from_numpy`) is converted to a torch
        tensor on the primary device. ``None`` inputs are passed to the network
        as ``None``. A network result that is not a tuple or list is treated as
        a single output. Torch tensor outputs are detached and copied back to
        numpy; any other output (including ``None``) is passed through
        unchanged.

        Returns:
            One `Tensor` (or ``None``) per network output, with ``dims`` taken
            from the matching output description. If the network returns fewer
            outputs than described, the returned list is correspondingly
            shorter.

        Raises:
            ImportError: If torch could not be imported.
            ValueError: If the network returns more outputs than described.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")
        with torch.no_grad():
            tensors = [
                (
                    None
                    if ipt is None
                    else torch.from_numpy(ipt.data.data).to(
                        self._primary_device  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
                for ipt in input_tensors
            ]
            result: Union[Tuple[Any, ...], List[Any], Any]
            result = self._network(  # pyright: ignore[reportUnknownVariableType]
                *tensors
            )
            if not isinstance(result, (tuple, list)):
                result = [result]

            # `None` is not a torch.Tensor, so it passes through unchanged here.
            outputs = [
                r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r
                for r in result  # pyright: ignore[reportUnknownVariableType]
            ]
            if len(outputs) > len(self.output_dims):
                raise ValueError(
                    f"Expected at most {len(self.output_dims)} outputs, but got {len(outputs)}"
                )

        return [
            None if r is None else Tensor(r, dims=out)
            for r, out in zip(outputs, self.output_dims)
        ]

    def unload(self) -> None:
        """Drop the network and release cached CUDA memory.

        Note that calling this a second time raises `AttributeError`, because
        ``self._network`` has already been deleted.
        """
        del self._network
        _ = gc.collect()  # deallocate memory
        assert torch is not None
        torch.cuda.empty_cache()  # release reserved memory

    @staticmethod
    def get_network(  # pyright: ignore[reportUnknownParameterType]
        weight_spec: Union[
            v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr
        ],
    ) -> "torch.nn.Module":  # pyright: ignore[reportInvalidTypeForm]
        """Instantiate the network described by ``weight_spec`` (weights not loaded).

        The architecture callable is imported via `import_callable` and called
        with the keyword arguments from the description: ``weight_spec.kwargs``
        for v0.4, and ``weight_spec.architecture.kwargs`` for v0.5.

        Raises:
            ImportError: If torch could not be imported.
            ValueError: If calling the architecture does not return a
                ``torch.nn.Module``.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")
        arch = import_callable(
            weight_spec.architecture,
            # NOTE(review): for v0.5 this passes the *weights* file sha256, not
            # the architecture's; presumably `import_callable` ignores it for
            # v0.5 architecture descriptions -- confirm.
            sha256=(
                weight_spec.architecture_sha256
                if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
                else weight_spec.sha256
            ),
        )
        model_kwargs = (
            weight_spec.kwargs
            if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
            else weight_spec.architecture.kwargs
        )
        network = arch(**model_kwargs)
        if not isinstance(network, torch.nn.Module):
            # v0.5 architecture descriptions name the callable `callable`, while
            # v0.4 ones (`CallableFromFile` / `CallableFromDepencency`) use
            # `callable_name`. Look up both, so that building the error message
            # cannot itself fail with AttributeError.
            arch_descr = weight_spec.architecture
            callable_name = getattr(arch_descr, "callable", None)
            if callable_name is None:
                callable_name = getattr(arch_descr, "callable_name", arch_descr)
            raise ValueError(
                f"calling {callable_name} did not return a torch.nn.Module"
            )

        return network

    @staticmethod
    def get_devices(  # pyright: ignore[reportUnknownParameterType]
        devices: Optional[Sequence[str]] = None,
    ) -> List["torch.device"]:  # pyright: ignore[reportInvalidTypeForm]
        """Resolve device strings to a single-element list of torch devices.

        If ``devices`` is ``None`` or empty, ``cuda`` is chosen when
        `torch.cuda.is_available()`, and ``cpu`` otherwise. Running one model
        on several devices is not implemented, so a warning is emitted and only
        the first device is kept.

        Raises:
            ImportError: If torch could not be imported.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")
        if not devices:
            torch_devices = [
                (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
            ]
        else:
            torch_devices = [torch.device(d) for d in devices]

        if len(torch_devices) > 1:
            warnings.warn(
                f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
            )
            torch_devices = torch_devices[:1]

        return torch_devices