Skip to content

weight_converters ¤

Modules:

Name Description
keras_to_tensorflow
pytorch_to_onnx
pytorch_to_torchscript
torchscript_to_onnx

Functions:

Name Description
add_weights

Convert model weights to other formats and add them to the model description

add_weights ¤

add_weights(model_descr: ModelDescr, *, output_path: DirectoryPath, source_format: Optional[WeightsFormat] = None, target_format: Optional[WeightsFormat] = None, verbose: bool = False, allow_tracing: bool = True) -> Union[ModelDescr, InvalidDescr]

Convert model weights to other formats and add them to the model description

Parameters:

Name Type Description Default

output_path ¤

DirectoryPath

Path to save updated model package to.

required

source_format ¤

Optional[WeightsFormat]

convert from a specific weights format. Default: choose automatically from any available.

None

target_format ¤

Optional[WeightsFormat]

convert to a specific weights format. Default: attempt to convert to any missing format.

None

verbose ¤

bool

log more (error) output

False

allow_tracing ¤

bool

allow conversion to torchscript by tracing if scripting fails.

True

Returns:

Type Description
Union[ModelDescr, InvalidDescr]

A (potentially invalid) model copy stored at output_path with added weights if any conversion was possible.

Source code in src/bioimageio/core/weight_converters/_add_weights.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    Args:
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
                       Default: choose automatically from any available.
        target_format: convert to a specific weights format.
                       Default: attempt to convert to any missing format.
        verbose: log more (error) output
        allow_tracing: allow conversion to torchscript by tracing if scripting fails.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added weights if any conversion was possible.

    """
    if not isinstance(model_descr, ModelDescr):
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            raise TypeError(
                f"Model format {model_descr.format} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        logger.warning("failed to add any converted weights")
        return model_descr
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(
            model_descr, format_version="latest", expected_type="model"
        )
        if not isinstance(tested_model_descr, ModelDescr):
            logger.error(
                f"The updated model description at {output_path} did not pass testing."
            )

        return tested_model_descr