Coverage for bioimageio/core/weight_converters/pytorch_to_onnx.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from pathlib import Path 

2 

3import torch 

4 

5from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

6 

7from .. import __version__ 

8from ..backends.pytorch_backend import load_torch_model 

9from ..digest_spec import get_member_id, get_test_inputs 

10from ..proc_setup import get_pre_and_postprocessing 

11from ._utils_onnx import get_dynamic_axes 

12 

13 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 15,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the PyTorch state_dict format to the ONNX format.

    The model's test inputs are preprocessed and passed to `torch.onnx.export`
    as example inputs.

    Args:
        model_descr:
            The model description object that contains the model and its weights.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 15.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch state_dict format.

    Returns:
        A descriptor object that contains information about the exported ONNX weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # Build the example inputs from the model's test inputs.
    sample = get_test_inputs(model_descr)
    procs = get_pre_and_postprocessing(
        model_descr, dataset_for_initial_statistics=[sample]
    )
    # NOTE(review): the return value is unused, so preprocessing presumably
    # updates `sample` in place -- confirm against `get_pre_and_postprocessing`.
    procs.pre(sample)
    inputs_torch = [
        torch.from_numpy(sample.members[get_member_id(ipt)].data.data)
        for ipt in model_descr.inputs
    ]

    model = load_torch_model(state_dict_weights_descr, load_state=True)

    # Run one forward pass without gradient tracking before exporting. The
    # outputs are not used afterwards, so they are discarded; any error raised
    # by the model on the test inputs surfaces here, ahead of the export call.
    with torch.no_grad():
        _ = model(*inputs_torch)

    _ = torch.onnx.export(
        model,
        tuple(inputs_torch),
        str(output_path),
        input_names=[str(d.id) for d in model_descr.inputs],
        output_names=[str(d.id) for d in model_descr.outputs],
        dynamic_axes=get_dynamic_axes(model_descr),
        verbose=verbose,
        opset_version=opset_version,
    )

    return OnnxWeightsDescr(
        source=output_path,
        parent="pytorch_state_dict",
        opset_version=opset_version,
        comment=f"Converted with bioimageio.core {__version__}.",
    )

79 )