Coverage for src/bioimageio/core/backends/torchscript_backend.py: 87%

38 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1# pyright: reportUnknownVariableType=false 

2import gc 

3from typing import Any, List, Optional, Sequence, Union 

4 

5import torch 

6from loguru import logger 

7from numpy.typing import NDArray 

8 

9from bioimageio.spec.model import v0_4, v0_5 

10 

11from .._model_adapter import LocalModelAdapter 

12from ..utils._type_guards import is_list, is_tuple 

13from .pytorch_backend import get_devices 

14 

15 

class TorchscriptModelAdapter(LocalModelAdapter[torch.device, Any]):
    """Model adapter that executes a model's TorchScript weights with PyTorch.

    Device parsing is delegated to the PyTorch backend's `get_devices`.
    `_init_model_on_device` loads the TorchScript module onto a single
    device; presumably `LocalModelAdapter` calls it once per parsed device
    (NOTE(review): confirm against the base class).
    """

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Create the adapter.

        Args:
            model_description: Model description whose `weights.torchscript`
                entry is used to load the model.
            devices: Optional device names, forwarded to the base class.

        Raises:
            ValueError: If the model description has no TorchScript weights.
        """
        if model_description.weights.torchscript is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        # Stored before calling the base initializer, which may already need
        # it (e.g. to load models) — NOTE(review): ordering assumed deliberate.
        self._weight_descr = model_description.weights.torchscript
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(
        self, devices: Optional[Sequence[str]]
    ) -> Sequence[torch.device]:
        """Translate device names into `torch.device`s via the PyTorch backend."""
        return get_devices(devices)

    def _init_model_on_device(self, device: torch.device) -> Any:
        """Load the TorchScript module with its tensors mapped onto `device`.

        Switching to evaluation mode is best-effort: a failure is logged as
        a warning and the module is still returned.
        """
        # NOTE(review): the reader returned by `get_reader()` is not
        # explicitly closed here; confirm whether it needs to be.
        model = torch.jit.load(self._weight_descr.get_reader(), map_location=device)
        try:
            model.eval()
        except Exception as e:
            logger.warning(
                f"Failed to set model to evaluation mode for torchscript model on {device}: {e}"
            )
        return model

    def _forward_impl(
        self,
        device: torch.device,
        model: Any,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> List[Optional[NDArray[Any]]]:
        """Run `model` on `input_arrays` and return its outputs as numpy arrays.

        Each input array is wrapped with `torch.from_numpy` and moved to
        `device`; `None` inputs are passed to `forward` as `None`. A list or
        tuple result is treated as multiple outputs, anything else as a
        single output. Tensor outputs are detached and copied to CPU numpy
        arrays; `None` and non-tensor outputs are returned unchanged.
        """
        with torch.no_grad():
            torch_tensor = [
                None if a is None else torch.from_numpy(a).to(device)
                for a in input_arrays
            ]
            output: Any = model.forward(*torch_tensor)

        if is_list(output) or is_tuple(output):
            output_seq: Sequence[Any] = output
        else:
            output_seq = [output]

        return [
            (
                None
                if r is None
                # detach(): `no_grad` only affects newly computed tensors; a
                # returned parameter/leaf that requires grad would otherwise
                # make `.numpy()` raise a RuntimeError.
                else r.detach().cpu().numpy()
                if isinstance(r, torch.Tensor)
                else r
            )
            for r in output_seq
        ]

    def _cleanup_pre_model_deletion(self, device: torch.device, model: Any) -> None:
        """No per-model cleanup is needed before deletion for TorchScript models."""
        return

    def _cleanup_post_model_deletion(self, device: torch.device) -> None:
        """Reclaim memory after a model has been deleted.

        Runs the Python garbage collector and, for CUDA devices, releases
        PyTorch's cached-but-unused GPU memory.
        """
        _ = gc.collect()  # deallocate memory
        if device.type == "cuda":
            torch.cuda.empty_cache()  # release reserved memory