Coverage for bioimageio/core/weight_converter/torch/_torchscript.py: 16%
# type: ignore  # TODO: type
from pathlib import Path
from typing import List, Sequence, Union

import numpy as np
from numpy.testing import assert_array_almost_equal
from typing_extensions import Any, assert_never

from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import Version

from ._utils import load_torch_model

try:
    import torch
except ImportError:
    torch = None


# FIXME: remove Any
def _check_predictions(
    model: Any,
    scripted_model: Any,
    model_spec: "v0_4.ModelDescr | v0_5.ModelDescr",
    input_data: Sequence["torch.Tensor"],
):
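    """Ensure the scripted model reproduces the original model's outputs for the
    test input and for a few smaller, still valid input sizes."""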
    assert torch is not None

    def _check(input_: Sequence[torch.Tensor]) -> None:
        # outputs of the original model (a lone tensor is wrapped in a list
        # for uniform handling)
        expected_tensors = model(*input_)
        if isinstance(expected_tensors, torch.Tensor):
            expected_tensors = [expected_tensors]
        expected_outputs: List[np.ndarray[Any, Any]] = [
            out.numpy() for out in expected_tensors
        ]

        # outputs of the scripted/traced model
        output_tensors = scripted_model(*input_)
        if isinstance(output_tensors, torch.Tensor):
            output_tensors = [output_tensors]
        outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors]

        try:
            for exp, out in zip(expected_outputs, outputs):
                assert_array_almost_equal(exp, out, decimal=4)
        except AssertionError as e:
            raise ValueError(
                f"Results before and after weights conversion do not agree:\n {str(e)}"
            ) from e

    _check(input_data)

    if len(model_spec.inputs) > 1:
        return  # FIXME: why don't we check multiple inputs?

    input_descr = model_spec.inputs[0]
    if isinstance(input_descr, v0_4.InputTensorDescr):
        if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape):
            return
        min_shape = input_descr.shape.min
        step = input_descr.shape.step
    else:
        min_shape: List[int] = []
        step: List[int] = []
        for axis in input_descr.axes:
            if isinstance(axis.size, v0_5.ParameterizedSize):
                min_shape.append(axis.size.min)
                step.append(axis.size.step)
            elif isinstance(axis.size, int):
                min_shape.append(axis.size)
                step.append(0)
            elif axis.size is None:
                raise NotImplementedError(
                    f"Can't verify inputs that don't specify their shape fully: {axis}"
                )
            elif isinstance(axis.size, v0_5.SizeReference):
                raise NotImplementedError(f"Can't handle axes like '{axis}' yet")
            else:
                assert_never(axis.size)
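
    # Cropping `step_factor * (st // 2)` from both ends of an axis shrinks it by
    # roughly one parameterized step per factor, staying close to the valid size
    # grid; axes with step 0 (fixed size) are left untouched via slice(None).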
    half_step = [st // 2 for st in step]
    max_steps = 4

    # check that input and output agree for decreasing input sizes
    for step_factor in range(1, max_steps + 1):
        slice_ = tuple(
            slice(None) if st == 0 else slice(step_factor * st, -step_factor * st)
            for st in half_step
        )
        this_input = [inp[slice_] for inp in input_data]
        this_shape = this_input[0].shape
        if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)):
            raise ValueError(
                f"Mismatched shapes: {this_shape}. Expected at least {min_shape}"
            )
        _check(this_input)


def convert_weights_to_torchscript(
    model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    output_path: Path,
    use_tracing: bool = True,
) -> v0_5.TorchscriptWeightsDescr:
    """Convert model weights from format 'pytorch_state_dict' to 'torchscript'.

    Args:
        model_descr: the model description whose 'pytorch_state_dict' weights shall be converted
        output_path: where to save the torchscript weights
        use_tracing: whether to use tracing or scripting to export the torchscript format
    """
    state_dict_weights_descr = model_descr.weights.pytorch_state_dict
    if state_dict_weights_descr is None:
        raise ValueError(
            "The provided model does not have weights in the pytorch state dict format"
        )

    assert torch is not None  # the conversion requires torch to be installed

    input_data = model_descr.get_input_test_arrays()

    with torch.no_grad():
        input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data]

        model = load_torch_model(state_dict_weights_descr)

        # FIXME: remove Any
        if use_tracing:
            # tracing records the operations executed for the example inputs
            scripted_model: Any = torch.jit.trace(model, input_data)
        else:
            # scripting compiles the module itself, preserving control flow
            scripted_model: Any = torch.jit.script(model)

        _check_predictions(
            model=model,
            scripted_model=scripted_model,
            model_spec=model_descr,
            input_data=input_data,
        )

    # save the torchscript model
    scripted_model.save(
        str(output_path)
    )  # does not support Path, so need to cast to str

    return v0_5.TorchscriptWeightsDescr(
        source=output_path,
        pytorch_version=Version(torch.__version__),
        parent="pytorch_state_dict",
    )
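

if __name__ == "__main__":
    # Minimal usage sketch: load a model description via bioimageio.spec
    # (`load_description` is its public loader) and convert its weights.
    # The RDF path below is a hypothetical example.
    from bioimageio.spec import load_description

    descr = load_description("path/to/rdf.yaml")  # hypothetical path
    assert isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr))
    _ = convert_weights_to_torchscript(descr, Path("weights-torchscript.pt"))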