Coverage for src/bioimageio/core/weight_converters/torchscript_to_onnx.py: 94%

16 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

1from pathlib import Path 

2 

3import torch.jit 

4from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

5from torch._export.converter import TS2EPConverter 

6 

7from ._utils_torch_onnx import export_to_onnx, get_torch_sample_inputs 

8 

9 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript weights referenced by `model_descr` are loaded onto the CPU
    and put into eval mode. They are then translated into an exported program
    with `TS2EPConverter`, using sample inputs from
    `get_torch_sample_inputs(model_descr)`. The program's module is written to
    ONNX by `export_to_onnx`.

    Args:
        model_descr (ModelDescr):
            The (v0_5) model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 18.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        A description of the exported ONNX weights.
    """

    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    weight_reader = torchscript_descr.get_reader()
    # Remap every stored tensor to CPU while deserializing. Without this, weights
    # that were saved from a CUDA device fail to load on CPU-only machines,
    # before the `.to("cpu")` below ever gets a chance to run.
    model = torch.jit.load(weight_reader, map_location="cpu")  # pyright: ignore[reportUnknownVariableType]
    model.to("cpu")
    # Export must trace inference behaviour (e.g. dropout/batch-norm in eval mode).
    model = model.eval()  # pyright: ignore[reportUnknownVariableType]

    # TS2EPConverter needs concrete example inputs to translate the TorchScript
    # graph into an ExportedProgram.
    torch_sample_inputs = get_torch_sample_inputs(model_descr)
    exported_program = TS2EPConverter(
        model,  # pyright: ignore[reportUnknownArgumentType]
        torch_sample_inputs,
    ).convert()

    return export_to_onnx(
        model_descr,
        exported_program.module(),
        output_path,
        verbose,
        opset_version,
        parent="torchscript",
    )