Coverage for src/bioimageio/core/weight_converters/_add_weights.py: 10%
80 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1import traceback
2from typing import Optional, Union
4from bioimageio.spec import (
5 InvalidDescr,
6 load_model_description,
7 save_bioimageio_package_as_folder,
8)
9from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
10from loguru import logger
11from pydantic import DirectoryPath
13from .._resource_tests import load_description_and_test
def _log_conversion_error(error: Exception, *, verbose: bool) -> None:
    """Report a failed weights conversion without aborting the remaining ones.

    Args:
        error: The exception raised while converting weights.
        verbose: If true, additionally print the full traceback.
    """
    if verbose:
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)


def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    The given `model_descr` is not edited: it is saved to `output_path` first and
    the copy reloaded from there is extended instead.

    Conversions are attempted in this order:
        - 'pytorch_state_dict' -> 'torchscript' (scripting)
        - 'pytorch_state_dict' -> 'torchscript' (tracing; only if `allow_tracing`
          is set and 'torchscript' is still missing)
        - 'pytorch_state_dict' -> 'onnx'
        - 'torchscript' -> 'onnx' (only if 'onnx' is still missing)
    A failed conversion is logged and does not abort the remaining ones.

    Args:
        model_descr: Model description to add converted weights to.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
                       Default: choose automatically from any available.
        target_format: convert to a specific weights format.
                       Default: attempt to convert to any missing format.
        verbose: log more (error) output
        allow_tracing: If 'torchscript' weights could not be created by
            scripting, fall back to creating them by tracing.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added
        weights if any conversion was possible. If no conversion succeeded, the
        reloaded copy is returned without being tested.

    Raises:
        TypeError: If `model_descr` is not a `ModelDescr` of the implemented
            format version (e.g. a model in an older format, or not a model).
    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            # report the model's spec version so the user knows what to update from
            raise TypeError(
                f"Model format {model_descr.format_version} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # snapshot to find out at the end whether any conversion succeeded
    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        torchscript_weights_path = output_path / "weights_torchscript.pt"
        # any failure is deliberately best-effort: log it and try other conversions
        try:
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _log_conversion_error(e, verbose=verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # only reached with 'torchscript' still missing if scripting above failed
    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        torchscript_weights_path = output_path / "weights_torchscript_traced.pt"
        try:
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _log_conversion_error(e, verbose=verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        onnx_weights_path = output_path / "weights.onnx"
        try:
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _log_conversion_error(e, verbose=verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # 'torchscript' may be available here thanks to one of the conversions above
    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        onnx_weights_path = output_path / "weights.onnx"
        try:
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _log_conversion_error(e, verbose=verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        logger.warning("failed to add any converted weights")
        return model_descr

    logger.info("added weights formats {}", originally_missing - missing)
    # resave model with updated rdf.yaml
    _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
    tested_model_descr = load_description_and_test(
        model_descr, format_version="latest", expected_type="model"
    )
    if not isinstance(tested_model_descr, ModelDescr):
        logger.error(
            f"The updated model description at {output_path} did not pass testing."
        )

    return tested_model_descr