Coverage for src / bioimageio / core / backends / torchscript_backend.py: 97%

31 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1# pyright: reportUnknownVariableType=false 

2import gc 

3from typing import Any, List, Optional, Sequence, Union 

4 

5import torch 

6from numpy.typing import NDArray 

7 

8from bioimageio.spec.model import v0_4, v0_5 

9 

10from ..model_adapters import ModelAdapter 

11from ..utils._type_guards import is_list, is_tuple 

12from .pytorch_backend import get_devices 

13 

14 

class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs the TorchScript weights of a bioimage.io model.

    Only the first entry of ``self.devices`` is used, both for loading the
    module and for running inference.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the TorchScript weights described by ``model_description``.

        Args:
            model_description: model description whose ``weights.torchscript``
                entry provides a reader for the serialized TorchScript module.
            devices: device names passed to ``get_devices``; only the first
                resulting device is used.

        Raises:
            ValueError: if the model description has no torchscript weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        self.devices = get_devices(devices)

        weight_reader = model_description.weights.torchscript.get_reader()
        # map_location remaps tensors that were saved on another device (e.g. a
        # CUDA device that does not exist on this machine) directly onto the
        # target device, instead of failing or loading onto the saved device.
        self._model = torch.jit.load(weight_reader, map_location=self.devices[0])

        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the TorchScript module on ``input_arrays`` without autograd.

        ``None`` entries are forwarded to the module as ``None``. All other
        arrays are converted to tensors on the first configured device.

        If ``is_list``/``is_tuple`` accept the module's output, each element
        becomes one entry of the result. Otherwise the single output is
        wrapped in a one-element list. Tensor entries are moved to the CPU and
        converted to numpy arrays; ``None`` and non-tensor entries are returned
        unchanged.
        """
        with torch.no_grad():
            torch_tensor: List[Optional[torch.Tensor]] = []
            for a in input_arrays:
                if a is None:
                    torch_tensor.append(None)
                    continue
                if any(s < 0 for s in a.strides):
                    # torch.from_numpy rejects arrays with negative strides
                    # (e.g. flipped views); take a C-contiguous copy instead.
                    a = a.copy()
                torch_tensor.append(torch.from_numpy(a).to(self.devices[0]))

            output: Any = self._model.forward(*torch_tensor)

        if is_list(output) or is_tuple(output):
            output_seq: Sequence[Any] = output
        else:
            output_seq = [output]

        return [
            (
                None
                if r is None
                else r.cpu().numpy()
                if isinstance(r, torch.Tensor)
                else r
            )
            for r in output_seq
        ]

    def unload(self) -> None:
        """Drop the loaded module and release cached CUDA memory.

        The adapter cannot run inference after this call, because
        ``self._model`` is deleted.
        """
        # NOTE(review): this assigns `_devices`, but __init__ sets `devices`;
        # confirm against the ModelAdapter base class whether `self.devices`
        # was intended here.
        self._devices = None
        del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory