Coverage for src/bioimageio/core/backends/torchscript_backend.py: 87%
38 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1# pyright: reportUnknownVariableType=false
2import gc
3from typing import Any, List, Optional, Sequence, Union
5import torch
6from loguru import logger
7from numpy.typing import NDArray
9from bioimageio.spec.model import v0_4, v0_5
11from .._model_adapter import LocalModelAdapter
12from ..utils._type_guards import is_list, is_tuple
13from .pytorch_backend import get_devices
class TorchscriptModelAdapter(LocalModelAdapter[torch.device, Any]):
    """Model adapter that loads torchscript weights with `torch.jit.load`."""

    def __init__(
        self,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ):
        """Pick the torchscript weights entry and initialize the base adapter.

        Args:
            model_description: Model description whose `weights.torchscript`
                entry is used by this adapter.
            devices: Optional device names, forwarded to the base class.

        Raises:
            ValueError: If `model_description.weights.torchscript` is None.
        """
        torchscript_weights = model_description.weights.torchscript
        if torchscript_weights is None:
            raise ValueError(
                f"No torchscript weights found for model {model_description.name}"
            )

        # Stored before the base initializer runs; `_init_model_on_device`
        # reads this attribute (presumably the base class may call it during
        # initialization -- verify against `LocalModelAdapter`).
        self._weight_descr = torchscript_weights
        super().__init__(model_description=model_description, devices=devices)

    def _parse_devices(
        self, devices: Optional[Sequence[str]]
    ) -> Sequence[torch.device]:
        """Delegate device resolution to `get_devices` of the pytorch backend."""
        return get_devices(devices)

    def _init_model_on_device(self, device: torch.device) -> Any:
        """Load the torchscript model onto `device` and try to put it in eval mode.

        A failing `eval()` call is logged as a warning and otherwise ignored;
        the loaded model is returned either way.
        """
        weights_reader = self._weight_descr.get_reader()
        scripted = torch.jit.load(weights_reader, map_location=device)
        try:
            scripted.eval()
        except Exception as e:
            # best effort only: log and keep using the model as loaded
            logger.warning(
                f"Failed to set model to evaluation mode for torchscript model on {device}: {e}"
            )

        return scripted

    def _forward_impl(
        self,
        device: torch.device,
        model: Any,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ) -> List[Optional[NDArray[Any]]]:
        """Run `model.forward` on `input_arrays` with gradients disabled.

        Args:
            device: Device the input tensors are moved to.
            model: The loaded model; its `forward` is called with one
                positional argument per entry of `input_arrays`.
            input_arrays: Input arrays; `None` entries are passed on as `None`.

        Returns:
            One entry per model output (a non-list/tuple output is treated as
            a single output). `torch.Tensor` outputs are moved to the CPU and
            converted to numpy arrays; any other value (including `None`) is
            returned unchanged.
        """

        def _as_result(value: Any) -> Any:
            # Only tensors are converted; None and foreign values pass through.
            if value is None:
                return None
            if isinstance(value, torch.Tensor):
                return value.cpu().numpy()
            return value

        with torch.no_grad():
            tensors: List[Any] = []
            for array in input_arrays:
                if array is None:
                    tensors.append(None)
                else:
                    tensors.append(torch.from_numpy(array).to(device))

            raw: Any = model.forward(*tensors)
            outputs: Sequence[Any] = (
                raw if (is_list(raw) or is_tuple(raw)) else [raw]
            )

            results: List[Optional[NDArray[Any]]] = []
            for item in outputs:
                results.append(_as_result(item))

            return results

    def _cleanup_pre_model_deletion(self, device: torch.device, model: Any) -> None:
        """Do nothing; no cleanup is performed before the model is deleted."""
        return None

    def _cleanup_post_model_deletion(self, device: torch.device) -> None:
        """Run garbage collection and, for CUDA devices, empty the CUDA cache."""
        _ = gc.collect()  # deallocate memory
        if device.type != "cuda":
            return

        torch.cuda.empty_cache()  # release reserved memory