Coverage for src/bioimageio/core/weight_converters/pytorch_to_onnx.py: 90%

10 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:02 +0000

1from pathlib import Path 

2 

3from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

4 

5from ..backends.pytorch_backend import load_torch_model 

6from ._utils_torch_onnx import export_to_onnx 

7 

8 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the PyTorch state_dict format to the ONNX format.

    The PyTorch module is built from the model's ``pytorch_state_dict`` weights
    entry via ``load_torch_model`` and then handed to ``export_to_onnx``.

    Args:
        model_descr:
            The model description object that contains the model and its weights.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, will print out detailed information during the ONNX export process. Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 18.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch state_dict format.

    Returns:
        A description of the exported ONNX weights, as produced by
        ``export_to_onnx`` with ``parent="pytorch_state_dict"``.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # load_state=True: the exported graph should carry the stored parameters,
    # presumably not a freshly initialised architecture.
    model = load_torch_model(state_dict_weights_descr, load_state=True)

    return export_to_onnx(
        model_descr,
        model,
        output_path,
        verbose,
        opset_version,
        # Identifies which weights entry the ONNX weights were derived from.
        parent="pytorch_state_dict",
    )