Coverage for src/bioimageio/core/weight_converters/_add_weights.py: 10%
80 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 09:21 +0000
1import traceback
2from typing import Optional, Union
4from loguru import logger
5from pydantic import DirectoryPath
7from bioimageio.spec import (
8 InvalidDescr,
9 load_model_description,
10 save_bioimageio_package_as_folder,
11)
12from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
14from .._resource_tests import load_description_and_test
def _report_conversion_error(error: Exception, verbose: bool) -> None:
    """Log a failed weights conversion; also print the full traceback if `verbose`."""
    if verbose:
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)


def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    The given `model_descr` itself is not modified. It is saved to `output_path`
    and reloaded from there, and all conversions are applied to that reloaded copy.
    Conversions are best-effort. A failing conversion is logged, and the
    remaining conversions are still attempted.

    Args:
        model_descr: Model description whose weights should be converted.
            Must be a `ModelDescr` of the currently implemented format version.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output
        allow_tracing: If scripting 'pytorch_state_dict' weights to 'torchscript'
            did not succeed, retry the conversion by tracing.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added
        weights if any conversion was possible. It is re-saved and tested in that case.
        If no weights were added, the untested copy reloaded from `output_path`
        is returned.

    Raises:
        TypeError: If `model_descr` is not a `ModelDescr`. This covers a model
            description of an older format version, or an `InvalidDescr`.
    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            # a valid model description, but not of the format this function handles
            raise TypeError(
                f"Model format {model_descr.format_version} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # snapshot to detect afterwards whether any conversion succeeded
    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        # imported lazily: the converter module is only loaded when this branch runs
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # Fallback: only reached with 'torchscript' still missing, i.e. if scripting
    # above did not succeed; the traced weights are written to a separate file.
    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # ONNX: try from torchscript first (which may have just been added above) ...
    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # ... and only fall back to converting from pytorch_state_dict if that did not succeed.
    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _report_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        logger.warning("failed to add any converted weights")
        return model_descr
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(
            model_descr, format_version="latest", expected_type="model"
        )
        if not isinstance(tested_model_descr, ModelDescr):
            logger.error(
                f"The updated model description at {output_path} did not pass testing."
            )

        return tested_model_descr