Coverage for src / bioimageio / core / weight_converters / _add_weights.py: 66%
80 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1import traceback
2from typing import Optional, Union
4from loguru import logger
5from pydantic import DirectoryPath
7from bioimageio.spec import (
8 InvalidDescr,
9 load_model_description,
10 save_bioimageio_package_as_folder,
11)
12from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
14from .._resource_tests import load_description_and_test
def _report_conversion_failure(error: Exception, *, verbose: bool) -> None:
    """Report a failed weights conversion without aborting further attempts.

    Args:
        error: The exception raised by the converter.
        verbose: If true, print the full traceback before logging the error.
    """
    if verbose:
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)


def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    The given model is first saved as a package folder at `output_path` and
    reloaded from there, so that the description passed in is not edited.
    Conversions are attempted in a fixed order: pytorch_state_dict to
    torchscript (scripting, then optionally tracing), pytorch_state_dict to
    onnx, and torchscript to onnx. A failing conversion is logged and the
    remaining conversions are still attempted.

    Args:
        model_descr: The model description to add weights to.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output
        allow_tracing: allow conversion to torchscript by tracing if scripting fails.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added weights if any conversion was possible.
        If no conversion succeeded, the reloaded description is returned as is
        (without running the resource tests); otherwise the result of
        `load_description_and_test` on the updated description is returned.

    Raises:
        TypeError: If `model_descr` is not an instance of `ModelDescr`.
    """
    if not isinstance(model_descr, ModelDescr):
        # Use getattr so that arbitrary objects (e.g. a str or a path) end up
        # in the intended TypeError below instead of an AttributeError.
        is_model_of_other_format = getattr(
            model_descr, "type", None
        ) == "model" and not isinstance(model_descr, InvalidDescr)
        if is_model_of_other_format:
            # NOTE(review): this file does not show whether descriptions expose
            # `format`; fall back to `format_version` (then "unknown") so the
            # lookup cannot mask the TypeError -- confirm against bioimageio.spec.
            found_format = getattr(
                model_descr,
                "format",
                getattr(model_descr, "format_version", "unknown"),
            )
            raise TypeError(
                f"Model format {found_format} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    # Formats we may convert from: everything present, or only the requested one.
    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    # Formats we try to create: everything absent, or only the requested one.
    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # Snapshot to determine at the end which formats were actually added.
    originally_missing = set(missing)

    # The converter modules are imported lazily (presumably to keep their
    # dependencies optional -- confirm). The imports sit outside the `try`
    # blocks, so a failing import propagates to the caller.
    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _report_conversion_failure(e, verbose=verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # Only reached with "torchscript" still missing if scripting above failed
    # or was not attempted.
    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _report_conversion_failure(e, verbose=verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _report_conversion_failure(e, verbose=verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # Second route to onnx; this includes torchscript weights that were
    # created by one of the conversions above.
    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _report_conversion_failure(e, verbose=verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        # Nothing was added: return the reloaded copy without re-saving or testing.
        logger.warning("failed to add any converted weights")
        return model_descr
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(
            model_descr, format_version="latest", expected_type="model"
        )
        if not isinstance(tested_model_descr, ModelDescr):
            logger.error(
                f"The updated model description at {output_path} did not pass testing."
            )

        return tested_model_descr