Coverage for src / bioimageio / core / backends / torchscript_backend.py: 97%
31 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
1# pyright: reportUnknownVariableType=false
2import gc
3from typing import Any, List, Optional, Sequence, Union
5import torch
6from numpy.typing import NDArray
8from bioimageio.spec.model import v0_4, v0_5
10from ..model_adapters import ModelAdapter
11from ..utils._type_guards import is_list, is_tuple
12from .pytorch_backend import get_devices
class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs the TorchScript weights of a bioimage.io model.

    The TorchScript module is loaded once at construction time, placed on the
    first of the resolved devices and switched to evaluation mode.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the TorchScript weights referenced by `model_description`.

        Args:
            model_description: model description whose `weights.torchscript`
                entry is loaded.
            devices: device names forwarded to `get_devices`; only the first
                resolved device is used for the module and for the inputs.

        Raises:
            ValueError: if the model description has no torchscript weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        self.devices = get_devices(devices)

        weight_reader = model_description.weights.torchscript.get_reader()
        # Remap storages onto the target device while deserializing. Without
        # `map_location`, weights saved from a GPU fail to load on a CPU-only
        # host and are otherwise materialized on their original device first.
        self._model = torch.jit.load(weight_reader, map_location=self.devices[0])

        # Make sure all parameters/buffers end up on the target device.
        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the TorchScript module on `input_arrays`.

        Each non-None array is wrapped with `torch.from_numpy` and moved to the
        first device; `None` entries are passed to the module as `None`.
        A list/tuple output is used as the output sequence, any other output is
        wrapped in a one-element list. Tensor outputs are returned as numpy
        arrays on the CPU; `None` and non-tensor outputs are returned unchanged.
        """
        device = self.devices[0]
        with torch.no_grad():
            torch_tensor = [
                None if a is None else torch.from_numpy(a).to(device)
                for a in input_arrays
            ]
            output: Any = self._model.forward(*torch_tensor)
            if is_list(output) or is_tuple(output):
                output_seq: Sequence[Any] = output
            else:
                output_seq = [output]

            # `.detach()` guards against a scripted module that re-enables
            # autograd internally: `Tensor.numpy()` refuses tensors that
            # require grad.
            return [
                (
                    None
                    if r is None
                    else r.detach().cpu().numpy()
                    if isinstance(r, torch.Tensor)
                    else r
                )
                for r in output_seq
            ]

    def unload(self) -> None:
        """Release the TorchScript module and the memory it holds.

        Drops the reference to the module, runs the garbage collector and
        releases cached CUDA memory. The adapter cannot run inference
        afterwards. Calling this more than once is harmless.
        """
        # Clear the device list set in `__init__` (stored as `self.devices`).
        self.devices = []
        # Tolerate repeated calls instead of raising AttributeError.
        if hasattr(self, "_model"):
            del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory