Coverage for src / bioimageio / core / weight_converters / pytorch_to_torchscript.py: 95%

21 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from pathlib import Path 

2from typing import Any, Optional, Sequence, Tuple, Union 

3 

4import torch 

5from torch.jit import ScriptModule 

6 

7from bioimageio.spec._internal.version_type import Version 

8from bioimageio.spec.model.v0_5 import ModelDescr, TorchscriptWeightsDescr 

9 

10from .. import __version__ 

11from ..backends.pytorch_backend import get_devices, load_torch_model 

12 

13 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    use_tracing: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> TorchscriptWeightsDescr:
    """
    Convert model weights from the PyTorch `state_dict` format to TorchScript.

    Args:
        model_descr:
            The model description object that contains the model and its weights in the PyTorch `state_dict` format.
        output_path:
            The file path where the TorchScript model will be saved.
        use_tracing:
            Whether to use tracing or scripting to export the TorchScript format.
            - `True`: Use tracing, which is recommended for models with straightforward control flow.
            - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals).
        devices:
            Devices to load the model onto; resolved via `get_devices`.
            The test inputs used for tracing are moved to the first resolved device.
            With `None`, the device choice is left to `get_devices`.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch `state_dict` format.
        TypeError:
            If tracing/scripting does not produce a TorchScript module.

    Returns:
        A descriptor object that contains information about the exported TorchScript weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    input_arrays = model_descr.get_input_test_arrays()
    devices = get_devices(devices)

    with torch.no_grad():
        # Tracing executes the model on these examples, so they are placed on
        # the first device the model is loaded onto. `torch.jit.trace` expects
        # a tuple of positional example inputs.
        example_inputs = tuple(
            torch.from_numpy(arr).to(device=devices[0]) for arr in input_arrays
        )
        model = load_torch_model(
            state_dict_weights_descr, load_state=True, devices=devices
        )
        # The exported weights are meant for inference. Switch to eval mode so
        # tracing does not record training-mode behavior such as active
        # dropout or batch-norm batch statistics. This is a no-op if the
        # model is already in eval mode.
        model.eval()
        scripted_model: Union[  # pyright: ignore[reportUnknownVariableType]
            ScriptModule, Tuple[Any, ...]
        ] = (
            torch.jit.trace(model, example_inputs)
            if use_tracing
            else torch.jit.script(model)
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if isinstance(scripted_model, tuple):
            raise TypeError(
                f"Expected a TorchScript module, but got {type(scripted_model)}: {scripted_model!r}"
            )

    scripted_model.save(output_path)

    return TorchscriptWeightsDescr(
        source=output_path.absolute(),
        pytorch_version=Version(torch.__version__),
        parent="pytorch_state_dict",
        comment=(
            f"Converted with bioimageio.core {__version__}"
            + f" with use_tracing={use_tracing}."
        ),
    )