Coverage for bioimageio/core/weight_converters/_add_weights.py: 34%
82 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import traceback
2from typing import Optional
4from loguru import logger
5from pydantic import DirectoryPath
7from bioimageio.spec import (
8 InvalidDescr,
9 load_model_description,
10 save_bioimageio_package_as_folder,
11)
12from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
14from .._resource_tests import load_description_and_test
def _log_conversion_error(error: Exception, verbose: bool) -> None:
    """Report a failed weights conversion attempt.

    Args:
        error: The exception raised by the failed conversion attempt.
        verbose: If true, additionally print the full traceback to stderr.
    """
    if verbose:
        # Three-argument form on purpose: the single-argument
        # `traceback.print_exception(exc)` only exists on Python >= 3.10.
        traceback.print_exception(type(error), error, error.__traceback__)

    logger.error(error)


def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
) -> Optional[ModelDescr]:
    """Convert model weights to other formats and add them to the model description

    The model package is first saved to `output_path` and reloaded from there,
    so the given `model_descr` object is not edited. Converted weights files are
    written into `output_path`, and the updated description is saved there again.

    Conversions are attempted in this order:
        - 'pytorch_state_dict' -> 'torchscript' (scripting)
        - 'pytorch_state_dict' -> 'torchscript' (tracing, if still missing)
        - 'torchscript' -> 'onnx'
        - 'pytorch_state_dict' -> 'onnx' (if still missing)
    A failed attempt is logged and does not prevent the following attempts.

    Args:
        model_descr: Model description to add converted weights to.
            Must be a `bioimageio.spec.model.v0_5.ModelDescr`.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output

    Returns:
        - An updated model description (reloaded and tested via
          `load_description_and_test`) if any converted weights were added.
        - `None` if no conversion was possible.

    Raises:
        TypeError: If `model_descr` is not a v0_5 `ModelDescr`.
    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            # a valid model description, but of a different format version
            raise TypeError(
                f"Model format {model_descr.format_version} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    _model_descr = load_model_description(output_path, perform_io_checks=False)
    assert isinstance(_model_descr, ModelDescr)
    model_descr = _model_descr
    del _model_descr

    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # snapshot to detect later whether any conversion succeeded
    originally_missing = set(missing)

    # NOTE: each converter module is imported inside its `try` so that an import
    # failure (e.g. a missing optional dependency) is logged like any other
    # failed conversion instead of aborting the remaining attempts.

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        try:
            from .pytorch_to_torchscript import convert

            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # fallback: only reached with torchscript still missing, i.e. if scripting failed
    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        try:
            from .pytorch_to_torchscript import convert

            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        try:
            from .torchscript_to_onnx import convert

            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # fallback: only reached with onnx still missing
    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        try:
            from .pytorch_to_onnx import convert

            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            _log_conversion_error(e, verbose)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        logger.warning("failed to add any converted weights")
        return None
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(model_descr)
        assert isinstance(tested_model_descr, ModelDescr)
        return tested_model_descr