Coverage for bioimageio/core/weight_converters/pytorch_to_torchscript.py: 95%
20 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from pathlib import Path
2from typing import Any, Tuple, Union
4import torch
5from torch.jit import ScriptModule
7from bioimageio.spec._internal.version_type import Version
8from bioimageio.spec.model.v0_5 import ModelDescr, TorchscriptWeightsDescr
10from .. import __version__
11from ..backends.pytorch_backend import load_torch_model
def convert(
    model_descr: ModelDescr,
    output_path: Path,
    *,
    use_tracing: bool = True,
) -> TorchscriptWeightsDescr:
    """
    Convert model weights from the PyTorch `state_dict` format to TorchScript.

    Args:
        model_descr:
            The model description object that contains the model and its weights in the PyTorch `state_dict` format.
        output_path:
            The file path where the TorchScript model will be saved.
        use_tracing:
            Whether to use tracing or scripting to export the TorchScript format.
            - `True`: Use tracing, which is recommended for models with straightforward control flow.
            - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals).

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch `state_dict` format.
        TypeError:
            If the TorchScript export returned something other than a `ScriptModule`.

    Returns:
        A descriptor object that contains information about the exported TorchScript weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    # Test inputs from the model description; used as example inputs for tracing.
    test_arrays = model_descr.get_input_test_arrays()

    with torch.no_grad():
        model = load_torch_model(state_dict_weights_descr, load_state=True)
        # Export the inference graph: tracing records the code path that actually
        # runs, so training-mode behavior (e.g. dropout, batch-norm batch
        # statistics) would otherwise be baked into the TorchScript module.
        model.eval()

        # Example inputs must live on the same device as the model, otherwise
        # tracing fails with a device mismatch.
        # NOTE(review): load_torch_model may place the model on an accelerator;
        # parameter-free models fall back to CPU.
        first_param = next(model.parameters(), None)
        device = torch.device("cpu") if first_param is None else first_param.device
        input_tensors = [torch.from_numpy(arr).to(device) for arr in test_arrays]

        scripted_model: Union[  # pyright: ignore[reportUnknownVariableType]
            ScriptModule, Tuple[Any, ...]
        ] = (
            torch.jit.trace(model, input_tensors)
            if use_tracing
            else torch.jit.script(model)
        )

    # Explicit check (not `assert`) so it is still enforced under `python -O`;
    # it also narrows the type for the `.save` call below.
    if isinstance(scripted_model, tuple):
        raise TypeError(
            f"Expected TorchScript export to produce a ScriptModule, got {scripted_model!r}"
        )

    scripted_model.save(output_path)

    return TorchscriptWeightsDescr(
        source=output_path,
        pytorch_version=Version(torch.__version__),
        parent="pytorch_state_dict",
        comment=(
            f"Converted with bioimageio.core {__version__}"
            + f" with use_tracing={use_tracing}."
        ),
    )