Coverage for bioimageio/core/backends/torchscript_backend.py: 92%
37 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1# pyright: reportUnknownVariableType=false
2import gc
3import warnings
4from typing import Any, List, Optional, Sequence, Union
6import torch
7from numpy.typing import NDArray
9from bioimageio.spec._internal.type_guards import is_list, is_tuple
10from bioimageio.spec.model import v0_4, v0_5
11from bioimageio.spec.utils import download
13from ..model_adapters import ModelAdapter
class TorchscriptModelAdapter(ModelAdapter):
    """Run a bioimage.io model via its torchscript weights using PyTorch JIT."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Download and load the torchscript weights onto the first device.

        Args:
            model_description: bioimage.io model description (v0.4 or v0.5);
                must provide torchscript weights.
            devices: optional device identifiers (e.g. "cpu", "cuda:0").
                Defaults to a single CUDA device if available, else CPU.
                Only the first device is used; more than one triggers a warning.

        Raises:
            ValueError: if the model description has no torchscript weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        weight_path = download(model_description.weights.torchscript.source).path
        if devices is None:
            # fix: wrap the default in torch.device for consistency with the
            # explicit branch below (previously a bare string was stored)
            self.devices = [
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ]
        else:
            self.devices = [torch.device(d) for d in devices]

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented"
            )

        with weight_path.open("rb") as f:
            self._model = torch.jit.load(f)

        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run inference on *input_arrays*.

        `None` entries are passed through unchanged; tensor outputs are moved
        to CPU and converted to numpy arrays, other outputs are returned as-is.
        """
        with torch.no_grad():
            torch_tensors = [
                None if a is None else torch.from_numpy(a).to(self.devices[0])
                for a in input_arrays
            ]
            output: Any = self._model.forward(*torch_tensors)
            if is_list(output) or is_tuple(output):
                output_seq: Sequence[Any] = output
            else:
                output_seq = [output]

            return [
                (
                    None
                    if r is None
                    else r.cpu().numpy() if isinstance(r, torch.Tensor) else r
                )
                for r in output_seq
            ]

    def unload(self) -> None:
        """Release the model and any reserved (GPU) memory."""
        # fix: previously assigned `self._devices`, a typo that left
        # `self.devices` populated after unload
        self.devices = None
        del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory