Coverage for bioimageio/core/model_adapters/_torchscript_model_adapter.py: 86%
50 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import gc
2import warnings
3from typing import Any, List, Optional, Sequence, Tuple, Union
5import numpy as np
6from numpy.typing import NDArray
8from bioimageio.spec.model import v0_4, v0_5
9from bioimageio.spec.utils import download
11from ..digest_spec import get_axes_infos
12from ..tensor import Tensor
13from ._model_adapter import ModelAdapter
# torch is an optional dependency. When importing it fails, `torch` is bound to
# None and the failure text is kept in `torch_error` so it can be reported later
# (TorchscriptModelAdapter raises an ImportError that includes it).
torch_error = None
try:
    import torch
except Exception as e:
    torch = None
    torch_error = str(e)
class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs the torchscript weights of a model description.

    The weights are loaded with `torch.jit.load`, moved to the first selected
    device and switched to eval mode.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Download and load the torchscript weights of `model_description`.

        Args:
            model_description: Model description to load; it must provide
                torchscript weights.
            devices: Device identifiers accepted by `torch.device`. Only the
                first entry is used to run the model. If None or empty,
                "cuda" is selected when available, otherwise "cpu".

        Raises:
            ImportError: If torch could not be imported.
            ValueError: If `model_description` has no torchscript weights.
        """
        if torch is None:
            raise ImportError(f"failed to import torch: {torch_error}")

        super().__init__()
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        weight_path = download(model_description.weights.torchscript.source).path
        if devices is None or len(devices) == 0:
            # Always store `torch.device` objects so `self.devices` has the same
            # element type no matter how the devices were selected. An empty
            # sequence falls back to the default instead of failing later with
            # an IndexError on `self.devices[0]`.
            self.devices = [
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ]
        else:
            self.devices = [torch.device(d) for d in devices]

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented",
                stacklevel=2,  # attribute the warning to the code creating the adapter
            )

        self._model = torch.jit.load(weight_path)
        self._model.to(self.devices[0])
        self._model = self._model.eval()
        # One tuple of axis ids per output of the model description; used as the
        # `dims` of the tensors returned by `forward`.
        self._internal_output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the model on `batch` without gradient tracking.

        Args:
            *batch: Input tensors, passed positionally to the model's
                `forward`. None entries are forwarded as None.

        Returns:
            One entry per output of the model description: a `Tensor` whose
            dims are that output's axis ids, or None where the model returned
            None.

        Raises:
            ValueError: If the number of values returned by the model differs
                from the number of outputs in the model description.
        """
        assert torch is not None  # torch is None only if its import failed
        device = self.devices[0]
        with torch.no_grad():
            # NOTE(review): `b.data.data` is presumably the numpy array backing
            # the Tensor -- confirm against the Tensor class.
            torch_tensor = [
                None if b is None else torch.from_numpy(b.data.data).to(device)
                for b in batch
            ]
            _result: Union[  # pyright: ignore[reportUnknownVariableType]
                Tuple[Optional[NDArray[Any]], ...],
                List[Optional[NDArray[Any]]],
                Optional[NDArray[Any]],
            ] = self._model.forward(*torch_tensor)
            # A model may return a single value or a tuple/list of values.
            if isinstance(_result, (tuple, list)):
                result: Sequence[Optional[NDArray[Any]]] = _result
            else:
                result = [_result]

            # Bring everything that is not already a numpy array back to the
            # CPU as a numpy array.
            result = [
                (
                    None
                    if r is None
                    else r.cpu().numpy() if not isinstance(r, np.ndarray) else r
                )
                for r in result
            ]

        # Explicit check instead of `assert`: asserts are stripped under `-O`,
        # and `zip` below would then silently drop the surplus entries.
        if len(result) != len(self._internal_output_axes):
            raise ValueError(
                f"Model returned {len(result)} outputs, but the model description"
                f" specifies {len(self._internal_output_axes)}"
            )

        return [
            None if r is None else Tensor(r, dims=axes)
            for r, axes in zip(result, self._internal_output_axes)
        ]

    def unload(self) -> None:
        """Drop the loaded model and release memory.

        Calling this more than once is safe.
        """
        assert torch is not None  # torch is None only if its import failed
        # NOTE(review): `_devices` is not read anywhere in this class (the
        # selected devices live in `self.devices`); kept unchanged in case
        # other code relies on it -- confirm and clean up.
        self._devices = None
        if hasattr(self, "_model"):
            del self._model

        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory