Coverage for src / bioimageio / core / weight_converters / torchscript_to_onnx.py: 95%

19 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from pathlib import Path 

2from typing import Optional, Sequence, Union 

3 

4import torch.jit 

5from torch._export.converter import TS2EPConverter 

6 

7from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

8 

9from ..backends.pytorch_backend import get_devices 

10from ._utils_torch_onnx import export_to_onnx, get_torch_sample_inputs 

11 

12 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript module is first converted to a `torch.export` program
    (via `TS2EPConverter`, using sample inputs derived from `model_descr`),
    and that program is then exported to ONNX.

    Args:
        model_descr (v0_5.ModelDescr):
            The model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 18.
        devices (Sequence[str | torch.device], optional):
            Candidate devices, resolved with `get_devices`. The model and the
            sample inputs are placed on the first resolved device. When None,
            the choice is left to `get_devices`. Defaults to None.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        A description of the exported ONNX weights.
    """

    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    # Only the first resolved device is used for the conversion.
    device = get_devices(devices)[0]

    # Remap all stored tensors onto the target device while loading. Without
    # `map_location`, weights that were saved on a device that is not
    # available here (e.g. CUDA weights on a CPU-only host) fail to load.
    weight_reader = torchscript_descr.get_reader()
    model = torch.jit.load(  # pyright: ignore[reportUnknownVariableType]
        weight_reader, map_location=device
    )
    model = model.eval()  # pyright: ignore[reportUnknownVariableType]

    # Sample inputs must live on the same device as the model for tracing.
    torch_sample_inputs = tuple(
        t.to(device=device) for t in get_torch_sample_inputs(model_descr)
    )
    exported_program = TS2EPConverter(
        model,  # pyright: ignore[reportUnknownArgumentType]
        torch_sample_inputs,
    ).convert()

    return export_to_onnx(
        model_descr,
        exported_program.module(),
        output_path,
        verbose,
        opset_version,
        parent="torchscript",
    )