Coverage for bioimageio/core/backends/torchscript_backend.py: 92%

37 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1# pyright: reportUnknownVariableType=false 

2import gc 

3import warnings 

4from typing import Any, List, Optional, Sequence, Union 

5 

6import torch 

7from numpy.typing import NDArray 

8 

9from bioimageio.spec._internal.type_guards import is_list, is_tuple 

10from bioimageio.spec.model import v0_4, v0_5 

11from bioimageio.spec.utils import download 

12 

13from ..model_adapters import ModelAdapter 

14 

15 

class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs the TorchScript weights of a model description.

    The serialized module is loaded with ``torch.jit.load`` during
    construction, moved to the first selected device and put into eval mode.

    Attributes set by this class:
        devices: list of ``torch.device``; only ``devices[0]`` is used.
        _model: the loaded TorchScript module (removed again by `unload`).
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Load the TorchScript weights referenced by *model_description*.

        Args:
            model_description: Model description whose
                ``weights.torchscript`` entry points to the weights file.
            devices: Device names understood by ``torch.device``. If ``None``
                or empty, ``"cuda"`` is selected when
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
                Only the first device is used; a warning is issued if more
                than one is given.

        Raises:
            ValueError: If the model description has no torchscript weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        weight_path = download(model_description.weights.torchscript.source).path

        # An empty sequence is treated like `None`; otherwise `devices[0]`
        # below would fail with an unhelpful IndexError.
        if not devices:
            # Always store `torch.device` objects so the element type of
            # `self.devices` does not depend on how the adapter was created.
            self.devices = [
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ]
        else:
            self.devices = [torch.device(d) for d in devices]

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented"
            )

        with weight_path.open("rb") as f:
            # Without `map_location`, torch.jit.load tries to restore tensors
            # to the device they were saved from, which fails if that device
            # does not exist on this machine (e.g. CUDA archive on a CPU host).
            self._model = torch.jit.load(f, map_location=self.devices[0])

        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the model on *input_arrays* and return its outputs as arrays.

        Args:
            input_arrays: Positional model inputs. ``None`` entries are
                passed to the model as ``None``; arrays are converted with
                ``torch.from_numpy`` and moved to ``self.devices[0]``.

        Returns:
            One entry per model output. Tensors are copied to the CPU and
            converted to numpy arrays; ``None`` and any non-tensor output
            are returned unchanged. A single (non list/tuple) model output
            is wrapped in a one-element list.
        """
        device = self.devices[0]
        with torch.no_grad():
            torch_tensor = [
                None if a is None else torch.from_numpy(a).to(device)
                for a in input_arrays
            ]
            output: Any = self._model.forward(*torch_tensor)
            if is_list(output) or is_tuple(output):
                output_seq: Sequence[Any] = output
            else:
                output_seq = [output]

            return [
                (
                    None
                    if r is None
                    # `detach()` so that a returned tensor that still requires
                    # grad (e.g. a parameter returned as-is) can be converted.
                    else (
                        r.detach().cpu().numpy()
                        if isinstance(r, torch.Tensor)
                        else r
                    )
                )
                for r in output_seq
            ]

    def unload(self) -> None:
        """Drop the loaded model and ask Python and CUDA to free its memory.

        Safe to call more than once. ``self.devices`` is left untouched.
        """
        # Guarded so that a repeated call (or a call after a failed
        # `__init__`) does not raise AttributeError.
        if hasattr(self, "_model"):
            del self._model

        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory

74 torch.cuda.empty_cache() # release reserved memory