Coverage for bioimageio/core/backends/torchscript_backend.py: 91%

35 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 15:20 +0000

1# pyright: reportUnknownVariableType=false 

2import gc 

3import warnings 

4from typing import Any, List, Optional, Sequence, Union 

5 

6import torch 

7from numpy.typing import NDArray 

8 

9from bioimageio.spec.model import v0_4, v0_5 

10 

11from ..model_adapters import ModelAdapter 

12from ..utils._type_guards import is_list, is_tuple 

13 

14 

class TorchscriptModelAdapter(ModelAdapter):
    """Model adapter that runs a model's TorchScript weights via `torch.jit`.

    Only the first configured device is used for inference; configuring more
    than one device emits a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[Union[str, torch.device]]] = None,
    ):
        """Load the TorchScript weights of `model_description`.

        Args:
            model_description: Model description whose `weights.torchscript`
                entry is loaded.
            devices: Torch device specifiers (e.g. "cpu", "cuda:0"). If None or
                empty, "cuda" is used when available, otherwise "cpu".

        Raises:
            ValueError: If the model description has no TorchScript weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        # An empty sequence also falls back to the default device; otherwise
        # `self.devices[0]` below would raise an opaque IndexError.
        if not devices:
            self.devices = [
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
            ]
        else:
            self.devices = [torch.device(d) for d in devices]

        if len(self.devices) > 1:
            warnings.warn(
                "Multiple devices for single torchscript model not yet implemented",
                stacklevel=2,
            )

        weight_reader = model_description.weights.torchscript.get_reader()
        # Without `map_location`, torch.jit.load restores tensors onto the
        # device they were saved from, which fails e.g. for CUDA-saved weights
        # on a CPU-only host. Load directly onto the target device instead.
        self._model = torch.jit.load(weight_reader, map_location=self.devices[0])

        self._model.to(self.devices[0])
        self._model = self._model.eval()

    def _forward_impl(
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ) -> List[Optional[NDArray[Any]]]:
        """Run the TorchScript model on `input_arrays`.

        `None` inputs are passed to the model as `None`. All other arrays are
        wrapped with `torch.from_numpy` and moved to the first device.

        A list or tuple model output is treated as multiple outputs; anything
        else is treated as a single output. Tensor outputs are returned as numpy
        arrays on the CPU. `None` and other non-tensor outputs are returned
        unchanged.
        """
        device = self.devices[0]
        with torch.no_grad():
            torch_inputs = [
                None if a is None else torch.from_numpy(a).to(device)
                for a in input_arrays
            ]
            output: Any = self._model.forward(*torch_inputs)

        if is_list(output) or is_tuple(output):
            output_seq: Sequence[Any] = output
        else:
            output_seq = [output]

        # `None` is not a Tensor, so it falls through the else-branch unchanged.
        return [
            r.cpu().numpy() if isinstance(r, torch.Tensor) else r
            for r in output_seq
        ]

    def unload(self) -> None:
        """Release the loaded model and free cached GPU memory.

        Safe to call more than once.
        """
        # NOTE(review): this assigns `_devices`, not the `devices` attribute set
        # in __init__; kept as-is in case the base class reads `_devices` —
        # confirm.
        self._devices = None
        if hasattr(self, "_model"):
            del self._model
        _ = gc.collect()  # deallocate memory
        torch.cuda.empty_cache()  # release reserved memory