Coverage for src/bioimageio/core/weight_converters/torchscript_to_onnx.py: 0%

25 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from pathlib import Path 

2from typing import Optional, Sequence, Union 

3 

4import torch.jit 

5from loguru import logger 

6from torch._export.converter import TS2EPConverter 

7 

8from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

9 

10from ..backends.pytorch_backend import get_devices 

11from ._utils_torch_onnx import export_to_onnx, get_torch_sample_inputs 

12 

13 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript module is loaded, converted to an exported program with
    `TS2EPConverter` (traced with the model's sample inputs) and the resulting
    module is handed to `export_to_onnx`.

    Args:
        model_descr:
            The model description object that contains the model and its weights.
            It must provide torchscript weights.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, will print out detailed information during the ONNX export process.
            Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 18.
        devices:
            Candidate devices, resolved through `get_devices`; only the first
            resolved device is used for loading the model and for the sample
            inputs. Defaults to None, in which case the choice is left to
            `get_devices`.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        A description of the exported ONNX weights.
    """

    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    # Resolve the target device before loading so it can serve as map_location.
    device = get_devices(devices)[0]

    weight_reader = torchscript_descr.get_reader()
    # Without map_location, torch.jit.load tries to move the module back to the
    # device it was saved from and raises if that device is unavailable here.
    model = torch.jit.load(  # pyright: ignore[reportUnknownVariableType]
        weight_reader, map_location=device
    )
    # Kept as a safeguard so the module is on the same device as the inputs.
    model.to(device)
    model = model.eval()  # pyright: ignore[reportUnknownVariableType]

    torch_sample_inputs = tuple(
        t.to(device=device) for t in get_torch_sample_inputs(model_descr)
    )
    exported_program = TS2EPConverter(
        model,  # pyright: ignore[reportUnknownArgumentType]
        torch_sample_inputs,
    ).convert()
    exported_module = exported_program.module()

    # Best effort: a failure here is logged and the export proceeds with the
    # module as returned by the exported program.
    try:
        exported_module = exported_module.eval()
    except Exception as e:
        logger.warning("Failed to set TS2EPConverter program to evaluation mode: {}", e)

    return export_to_onnx(
        model_descr,
        exported_module,
        output_path,
        verbose,
        opset_version,
        parent="torchscript",
    )

75 )