Coverage for src/bioimageio/core/backends/torchscript_backend.py: 91%
35 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-14 08:35 +0000
1# pyright: reportUnknownVariableType=false
2import gc
3import warnings
4from typing import Any, List, Optional, Sequence, Union
6import torch
7from numpy.typing import NDArray
9from bioimageio.spec.model import v0_4, v0_5
11from ..model_adapters import ModelAdapter
12from ..utils._type_guards import is_list, is_tuple
class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs inference with torchscript (``torch.jit``) weights."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the model's torchscript weights onto the first requested device.

        Args:
            model_description: a bioimage.io model description with torchscript weights.
            devices: optional device identifiers (e.g. ``"cpu"``, ``"cuda:0"``);
                defaults to ``"cuda"`` if available, else ``"cpu"``.

        Raises:
            ValueError: if `model_description` has no torchscript weights entry.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        # Normalize both branches to `torch.device` (the original left the
        # auto-detected default as a plain string, giving `self.devices` an
        # inconsistent element type).
        if devices is None:
            self.devices = [
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ]
        else:
            self.devices = [torch.device(d) for d in devices]

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented"
            )

        weight_reader = model_description.weights.torchscript.get_reader()
        self._model = torch.jit.load(weight_reader)
        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run a forward pass on the first device.

        ``None`` entries in `input_arrays` are forwarded as-is; tensor outputs
        are moved to CPU and converted back to numpy arrays.
        """
        with torch.no_grad():
            torch_tensors = [
                None if a is None else torch.from_numpy(a).to(self.devices[0])
                for a in input_arrays
            ]
            output: Any = self._model.forward(*torch_tensors)
            # a torchscript model may return a single tensor or a list/tuple
            if is_list(output) or is_tuple(output):
                output_seq: Sequence[Any] = output
            else:
                output_seq = [output]

            # `None` (and any non-tensor) falls through the isinstance check
            # unchanged, so no separate None arm is needed.
            return [
                r.cpu().numpy() if isinstance(r, torch.Tensor) else r
                for r in output_seq
            ]

    def unload(self) -> None:
        """Drop the loaded model and release (GPU) memory."""
        # fix: the original assigned to a never-read `self._devices` (typo);
        # clear the actual `self.devices` attribute instead.
        self.devices = []
        del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory