Coverage for src / bioimageio / core / weight_converters / pytorch_to_torchscript.py: 95%
21 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from pathlib import Path
2from typing import Any, Optional, Sequence, Tuple, Union
4import torch
5from torch.jit import ScriptModule
7from bioimageio.spec._internal.version_type import Version
8from bioimageio.spec.model.v0_5 import ModelDescr, TorchscriptWeightsDescr
10from .. import __version__
11from ..backends.pytorch_backend import get_devices, load_torch_model
def convert(
    model_descr: ModelDescr,
    output_path: Union[Path, str],
    *,
    use_tracing: bool = True,
    devices: Optional[Sequence[Union[str, torch.device]]] = None,
) -> TorchscriptWeightsDescr:
    """
    Convert model weights from the PyTorch `state_dict` format to TorchScript.

    Args:
        model_descr:
            The model description object that contains the model and its weights in the PyTorch `state_dict` format.
        output_path:
            The file path where the TorchScript model will be saved.
            A `str` is accepted and converted to a `Path`.
        use_tracing:
            Whether to use tracing or scripting to export the TorchScript format.
            - `True`: Use tracing, which is recommended for models with straightforward control flow.
            - `False`: Use scripting, which is better for models with dynamic control flow (e.g., loops, conditionals).
        devices:
            Devices passed to `get_devices` / `load_torch_model`. The model's test
            input arrays (used as example inputs for tracing) are moved to the
            first of the resolved devices.

    Raises:
        ValueError:
            If the provided model does not have weights in the PyTorch `state_dict` format.
        TypeError:
            If the TorchScript conversion does not yield a `ScriptModule`.

    Returns:
        A descriptor object that contains information about the exported TorchScript weights.
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    output_path = Path(output_path)
    test_arrays = model_descr.get_input_test_arrays()
    resolved_devices = get_devices(devices)

    with torch.no_grad():
        # `torch.jit.trace` documents `example_inputs` as a tuple (unpacked as
        # positional arguments), so build a tuple rather than a list.
        example_inputs = tuple(
            torch.from_numpy(arr).to(device=resolved_devices[0])
            for arr in test_arrays
        )
        model = load_torch_model(
            state_dict_weights_descr, load_state=True, devices=resolved_devices
        )
        # Export with inference behavior (e.g. dropout / batch norm); tracing
        # records whatever mode the module is in. `eval()` is a no-op if the
        # module is already in eval mode.
        model.eval()

        scripted_model: Union[  # pyright: ignore[reportUnknownVariableType]
            ScriptModule, Tuple[Any, ...]
        ] = (
            torch.jit.trace(model, example_inputs)
            if use_tracing
            else torch.jit.script(model)
        )

    # Explicit check instead of `assert` so it also runs under `python -O`.
    if isinstance(scripted_model, tuple):
        raise TypeError(
            "Expected TorchScript conversion to produce a ScriptModule, "
            f"but got {type(scripted_model)}"
        )

    scripted_model.save(output_path)

    return TorchscriptWeightsDescr(
        source=output_path.absolute(),
        pytorch_version=Version(torch.__version__),
        parent="pytorch_state_dict",
        comment=(
            f"Converted with bioimageio.core {__version__}"
            f" with use_tracing={use_tracing}."
        ),
    )