Coverage for bioimageio/core/weight_converters/pytorch_to_torchscript.py: 95%

20 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from pathlib import Path 

2from typing import Any, Tuple, Union 

3 

4import torch 

5from torch.jit import ScriptModule 

6 

7from bioimageio.spec._internal.version_type import Version 

8from bioimageio.spec.model.v0_5 import ModelDescr, TorchscriptWeightsDescr 

9 

10from .. import __version__ 

11from ..backends.pytorch_backend import load_torch_model 

12 

13 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    use_tracing: bool = True,
) -> TorchscriptWeightsDescr:
    """
    Convert model weights from the PyTorch `state_dict` format to TorchScript.

    The loaded model is switched to evaluation mode (`model.eval()`) before
    export. A traced module always replays the training/eval-dependent
    operations (e.g. dropout, batch normalization) of the mode it was in at
    trace time, so the exported weights must reflect inference behavior.

    Args:
        model_descr:
            The model description object that contains the model and its weights in the PyTorch `state_dict` format.
        output_path:
            The file path where the TorchScript model will be saved.
        use_tracing:
            Whether to use tracing or scripting to export the TorchScript format.
            - `True`: Use tracing, with the model's input test arrays as example inputs.
              Recommended for models with straightforward control flow.
            - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals).
              No test inputs are loaded in this case.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch `state_dict` format.
        TypeError:
            If the TorchScript export does not produce a `ScriptModule`.

    Returns:
        A descriptor object that contains information about the exported TorchScript weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    with torch.no_grad():
        model = load_torch_model(state_dict_weights_descr, load_state=True)
        # Export inference behavior: tracing records train/eval-dependent ops
        # (dropout, batch norm) exactly as they run in the current mode.
        model.eval()

        scripted_model: Union[  # pyright: ignore[reportUnknownVariableType]
            ScriptModule, Tuple[Any, ...]
        ]
        if use_tracing:
            # Example inputs are only needed for tracing; pass them as a tuple
            # of positional arguments, the form documented by torch.jit.trace.
            example_inputs = tuple(
                torch.from_numpy(inp) for inp in model_descr.get_input_test_arrays()
            )
            scripted_model = torch.jit.trace(model, example_inputs)
        else:
            scripted_model = torch.jit.script(model)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if isinstance(scripted_model, tuple):
        raise TypeError(
            "Expected TorchScript export to produce a ScriptModule, "
            + f"got {type(scripted_model)}: {scripted_model}"
        )

    scripted_model.save(output_path)

    return TorchscriptWeightsDescr(
        source=output_path,
        pytorch_version=Version(torch.__version__),
        parent="pytorch_state_dict",
        comment=(
            f"Converted with bioimageio.core {__version__}"
            + f" with use_tracing={use_tracing}."
        ),
    )