Coverage for bioimageio/spec/model/_v0_3_converter.py: 5%
76 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
1# type: ignore
2from typing import Any, Dict
def convert_model_from_v0_3_to_0_4_0(data: Dict[str, Any]) -> None:
    """Convert raw model RDF ``data`` from format 0.3.x to 0.4.0, in place.

    The conversion steps through the intermediate 0.3.x versions:

    * 0.3.0 -> 0.3.1: no breaking change, only the version is bumped.
    * 0.3.1 -> 0.3.2: delegated to ``_convert_model_v0_3_1_to_v0_3_2``.
    * 0.3.2 -> 0.3.3: delegated to ``_convert_model_v0_3_2_to_v0_3_3``.
    * 0.3.3 / 0.3.4 / 0.3.5 -> 0.3.6: only the version is bumped.
    * 0.3.6 -> 0.4.0: empties ``config.future`` / ``config`` are removed,
      ``language`` and ``framework`` are dropped, the root-level ``source``,
      ``sha256`` and ``kwargs`` are moved into the ``pytorch_state_dict``
      weights entry (as ``architecture``, ``architecture_sha256`` and
      ``kwargs``), and the ``pytorch_script`` weights entry is renamed to
      ``torchscript``.

    ``data`` is left untouched if it has no ``format_version`` or if its
    version is not one of the 0.3.x versions listed above. Fields with an
    unexpected type (e.g. ``config: null`` or ``weights: null``) are left
    as-is instead of raising here.

    Args:
        data: raw (unvalidated) model RDF content; mutated in place.
    """
    if "format_version" not in data:
        return

    if data["format_version"] == "0.3.0":
        # no breaking change, bump to 0.3.1
        data["format_version"] = "0.3.1"

    if data["format_version"] == "0.3.1":
        data = _convert_model_v0_3_1_to_v0_3_2(data)

    if data["format_version"] == "0.3.2":
        data = _convert_model_v0_3_2_to_v0_3_3(data)

    if data["format_version"] in ("0.3.3", "0.3.4", "0.3.5"):
        data["format_version"] = "0.3.6"

    if data["format_version"] != "0.3.6":
        # unknown / unsupported version: leave the data for validation to report
        return

    config = data.get("config")
    if isinstance(config, dict):
        # remove 'future' from config if no other than the used future entries exist
        if config.get("future") == {}:
            del config["future"]

        # remove 'config' if now empty
        if not config:
            del data["config"]

    data.pop("language", None)
    data.pop("framework", None)

    # these root-level fields are moved into the pytorch_state_dict weights
    # entry; they are dropped if there is no such entry
    architecture = data.pop("source", None)
    architecture_sha256 = data.pop("sha256", None)
    kwargs = data.pop("kwargs", None)

    weights = data.get("weights")
    if isinstance(weights, dict):
        state_dict_entry = weights.get("pytorch_state_dict")
        if isinstance(state_dict_entry, dict):
            if architecture is not None:
                state_dict_entry["architecture"] = architecture

            if architecture_sha256 is not None:
                state_dict_entry["architecture_sha256"] = architecture_sha256

            if kwargs is not None:
                state_dict_entry["kwargs"] = kwargs

        # weights format key 'pytorch_script' becomes 'torchscript'
        torchscript_entry = weights.pop("pytorch_script", None)
        if torchscript_entry is not None:
            weights["torchscript"] = torchscript_entry

    data["format_version"] = "0.4.0"
62def _convert_model_v0_3_1_to_v0_3_2(data: Dict[str, Any]) -> Dict[str, Any]:
63 data["type"] = "model"
64 data["format_version"] = "0.3.2"
65 future = data.get("config", {}).get("future", {}).pop("0.3.2", {})
67 authors = data.get("authors")
68 if isinstance(authors, list):
69 data["authors"] = [{"name": name} for name in authors]
70 authors_update = future.get("authors")
71 if authors_update is not None:
72 for a, u in zip(data["authors"], authors_update):
73 a.update(u)
75 # packaged_by
76 packaged_by = data.get("packaged_by")
77 if packaged_by is not None:
78 data["packaged_by"] = [{"name": name} for name in data["packaged_by"]]
79 packaged_by_update = future.get("packaged_by")
80 if packaged_by_update is not None:
81 for a, u in zip(data["packaged_by"], packaged_by_update):
82 a.update(u)
84 # authors of weights
85 weights = data.get("weights")
86 if isinstance(weights, dict):
87 for weights_format, weights_entry in weights.items():
88 if "authors" not in weights_entry:
89 continue
91 weights_entry["authors"] = [
92 {"name": name} for name in weights_entry["authors"]
93 ]
94 authors_update = (
95 future.get("weights", {}).get(weights_format, {}).get("authors")
96 )
97 if authors_update is not None:
98 for a, u in zip(weights_entry["authors"], authors_update):
99 a.update(u)
101 # model version
102 if "version" in future:
103 data["version"] = future.pop("version")
105 return data
108def _convert_model_v0_3_2_to_v0_3_3(data: Dict[str, Any]) -> Dict[str, Any]:
109 data["format_version"] = "0.3.3"
110 if "outputs" in data:
111 for out in data["outputs"]:
112 if "shape" in out:
113 shape = out["shape"]
114 if isinstance(shape, dict) and "reference_input" in shape:
115 shape["reference_tensor"] = shape.pop("reference_input")
117 return data