Coverage for bioimageio/core/weight_converters/torchscript_to_onnx.py: 0%

26 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 13:47 +0000

1from pathlib import Path 

2 

3import torch.jit 

4 

5from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

6 

7from .. import __version__ 

8from ..digest_spec import get_member_id, get_test_inputs 

9from ..proc_setup import get_pre_and_postprocessing 

10from ._utils_onnx import get_dynamic_axes 

11 

12 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 15,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript weights referenced by `model_descr` are loaded onto the CPU,
    run once on the model's preprocessed test inputs and then exported with
    `torch.onnx.export` to `output_path`.

    Args:
        model_descr (v0_5.ModelDescr):
            The model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            Forwarded to `torch.onnx.export` as its `verbose` argument.
            Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 15.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        v0_5.OnnxWeightsDescr:
            A descriptor object that contains information about the exported ONNX
            weights (source path, parent format, opset version and a comment
            naming the bioimageio.core version used).
    """

    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    # Build the example inputs for the export from the model's test inputs.
    sample = get_test_inputs(model_descr)
    procs = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[sample]
    )
    # NOTE(review): the return value is unused, so `pre` presumably updates
    # `sample` in place -- confirm against `proc_setup`.
    procs.pre(sample)
    inputs_torch = [
        torch.from_numpy(sample.members[get_member_id(ipt)].data.data)
        for ipt in model_descr.inputs
    ]

    weight_reader = torchscript_descr.get_reader()
    model = torch.jit.load(weight_reader)  # type: ignore
    model.to("cpu")
    model = model.eval()  # type: ignore

    with torch.no_grad():
        # Run the model once on the example inputs so that a model that cannot
        # process them fails here, before the export is attempted. The outputs
        # themselves are not needed.
        _ = model(*inputs_torch)  # type: ignore

        _ = torch.onnx.export(
            model,  # type: ignore
            tuple(inputs_torch),
            str(output_path),
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            dynamic_axes=get_dynamic_axes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    return OnnxWeightsDescr(
        source=output_path,
        parent="torchscript",
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}.",
    )