Coverage for src/bioimageio/core/weight_converters/pytorch_to_onnx.py: 91%

11 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from pathlib import Path 

2from typing import Optional, Sequence 

3 

4from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr 

5 

6from ..backends.pytorch_backend import load_torch_model 

7from ._utils_torch_onnx import export_to_onnx 

8 

9 

def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    verbose: bool = False,
    opset_version: int = 18,
    devices: Optional[Sequence[str]] = None,
) -> OnnxWeightsDescr:
    """
    Convert model weights from the PyTorch state_dict format to the ONNX format.

    Args:
        model_descr:
            The model description object that contains the model and its weights.
        output_path:
            The file path where the ONNX model will be saved.
        verbose:
            If True, will print out detailed information during the ONNX export
            process. Defaults to False.
        opset_version:
            The ONNX opset version to use for the export. Defaults to 18.
        devices:
            Optional sequence of device names forwarded to `load_torch_model`
            when loading the state dict model. Defaults to None (device choice
            is then presumably left to `load_torch_model` — confirm there).

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch
            state_dict format.

    Returns:
        A description of the exported ONNX weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # Instantiate the architecture and load the state dict, then switch the
    # module to evaluation mode so the exported graph reflects inference
    # behaviour rather than training behaviour.
    model = load_torch_model(
        state_dict_weights_descr, load_state=True, devices=devices
    ).eval()

    # `parent` names the weights entry these ONNX weights are derived from
    # (presumably recorded on the returned OnnxWeightsDescr — verify in
    # export_to_onnx).
    return export_to_onnx(
        model_descr,
        model,
        output_path,
        verbose,
        opset_version,
        parent="pytorch_state_dict",
    )