Coverage for bioimageio/core/weight_converters/torchscript_to_onnx.py: 0%

27 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from pathlib import Path 

2 

3import torch.jit 

4 

5from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

6from bioimageio.spec.utils import download 

7 

8from .. import __version__ 

9from ..digest_spec import get_member_id, get_test_inputs 

10from ..proc_setup import get_pre_and_postprocessing 

11from ._utils_onnx import get_dynamic_axes 

12 

13 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 15,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the TorchScript format to the ONNX format.

    The TorchScript weights of `model_descr` are downloaded and loaded on the
    CPU in eval mode. The model's test inputs (with the model's preprocessing
    applied) are run through the model once and then used as the example
    arguments for `torch.onnx.export`.

    Args:
        model_descr (v0_5.ModelDescr):
            The model description object that contains the model and its weights.
        output_path (Path):
            The file path where the ONNX model will be saved.
        verbose (bool, optional):
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version (int, optional):
            The ONNX opset version to use for the export. Defaults to 15.

    Raises:
        ValueError:
            If the provided model does not have weights in the torchscript format.

    Returns:
        v0_5.OnnxWeightsDescr:
            A descriptor object that contains information about the exported ONNX weights.
    """
    # `torch.onnx` is used below; import the submodule explicitly instead of
    # relying on it having been loaded as a side effect of `import torch.jit`.
    import torch.onnx

    torchscript_descr = model_descr.weights.torchscript
    if torchscript_descr is None:
        raise ValueError(
            "The provided model does not have weights in the torchscript format"
        )

    # Example inputs for the export: the model's test inputs with its
    # preprocessing applied. Statistics needed by the preprocessing are
    # initialized from this same sample. `procs.pre` is presumably in-place
    # (its return value is not used) — the members are read afterwards.
    sample = get_test_inputs(model_descr)
    procs = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[sample]
    )
    procs.pre(sample)
    inputs_torch = [
        torch.from_numpy(sample.members[get_member_id(ipt)].data.data)
        for ipt in model_descr.inputs
    ]

    weight_path = download(torchscript_descr).path
    model = torch.jit.load(weight_path)  # type: ignore
    model.to("cpu")
    model = model.eval()  # type: ignore

    with torch.no_grad():
        # Run the model once so that a model/input mismatch raises here,
        # before the export starts; the outputs themselves are not needed.
        _ = model(*inputs_torch)  # type: ignore

        # Tensor names in the ONNX graph follow the input/output ids of the
        # model description; dynamic axes are derived from the description.
        _ = torch.onnx.export(
            model,  # type: ignore
            tuple(inputs_torch),
            str(output_path),
            input_names=[str(d.id) for d in model_descr.inputs],
            output_names=[str(d.id) for d in model_descr.outputs],
            dynamic_axes=get_dynamic_axes(model_descr),
            verbose=verbose,
            opset_version=opset_version,
        )

    return OnnxWeightsDescr(
        source=output_path,
        parent="torchscript",
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}.",
    )