Coverage for bioimageio/core/backends/torchscript_backend.py: 91%
35 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1# pyright: reportUnknownVariableType=false
2import gc
3import warnings
4from typing import Any, List, Optional, Sequence, Union
6import torch
7from numpy.typing import NDArray
9from bioimageio.spec.model import v0_4, v0_5
11from ..model_adapters import ModelAdapter
12from ..utils._type_guards import is_list, is_tuple
class TorchscriptModelAdapter(ModelAdapter):
    """Run inference with a model's torchscript weights via ``torch.jit``."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the model's torchscript weights onto the first selected device.

        Args:
            model_description: bioimageio model description; must include
                torchscript weights.
            devices: devices to consider for inference. Defaults to
                ``["cuda"]`` if CUDA is available, else ``["cpu"]``.
                Only the first device is currently used.

        Raises:
            ValueError: if `model_description` has no torchscript weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        # Normalize to a list of torch.device on both paths (the original
        # default path stored a plain string, the explicit path torch.device).
        if devices is None:
            self.devices = [
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ]
        else:
            self.devices = [torch.device(d) for d in devices]

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented"
            )

        weight_reader = model_description.weights.torchscript.get_reader()
        self._model = torch.jit.load(weight_reader)
        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run a forward pass; ``None`` inputs/outputs are passed through.

        Non-tensor outputs are returned unchanged; tensor outputs are moved
        to CPU and converted to numpy arrays.
        """
        with torch.no_grad():
            torch_tensor = [
                None if a is None else torch.from_numpy(a).to(self.devices[0])
                for a in input_arrays
            ]
            output: Any = self._model.forward(*torch_tensor)
            # A model may return a single tensor or a list/tuple of tensors.
            if is_list(output) or is_tuple(output):
                output_seq: Sequence[Any] = output
            else:
                output_seq = [output]

            return [
                (
                    None
                    if r is None
                    else r.cpu().numpy() if isinstance(r, torch.Tensor) else r
                )
                for r in output_seq
            ]

    def unload(self) -> None:
        """Release the model and free (reserved) device memory."""
        # was `self._devices = None` — a typo that created a new attribute
        # and left `self.devices` populated
        self.devices = None
        del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory