Coverage for bioimageio/spec/model/_v0_3_converter.py: 5%

76 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-02 14:21 +0000

1# type: ignore 

2from typing import Any, Dict 

3 

4 

def convert_model_from_v0_3_to_0_4_0(data: Dict[str, Any]) -> None:
    """Convert a v0.3.x model RDF ``data`` to format version 0.4.0, in place.

    The conversion is stepwise: 0.3.0 -> 0.3.1 -> 0.3.2 -> 0.3.3, then
    0.3.3/0.3.4/0.3.5 -> 0.3.6, then 0.3.6 -> 0.4.0.

    ``data`` is left untouched if it has no ``format_version``. If it is not at
    0.3.6 after the 0.3.x steps (e.g. some other version), the function returns
    without applying the 0.4.0 step.

    The 0.4.0 step:
      - drops ``config.future`` if it is empty, then ``config`` if it is empty;
      - drops the top-level ``language`` and ``framework`` fields;
      - moves top-level ``source`` / ``sha256`` / ``kwargs`` into the
        ``weights.pytorch_state_dict`` entry as ``architecture`` /
        ``architecture_sha256`` / ``kwargs`` (they are discarded when there is
        no such weights entry);
      - renames the ``pytorch_script`` weights entry to ``torchscript``.

    Non-mapping ``config`` / ``weights`` values (e.g. ``null`` in YAML) are
    left as they are for later validation to report.
    """
    if "format_version" not in data:
        return

    if data["format_version"] == "0.3.0":
        # no breaking change, bump to 0.3.1
        data["format_version"] = "0.3.1"

    # the step helpers mutate `data` in place (and return it); the caller only
    # ever sees in-place changes, so `data` is deliberately not rebound
    if data["format_version"] == "0.3.1":
        _convert_model_v0_3_1_to_v0_3_2(data)

    if data["format_version"] == "0.3.2":
        _convert_model_v0_3_2_to_v0_3_3(data)

    if data["format_version"] in ("0.3.3", "0.3.4", "0.3.5"):
        data["format_version"] = "0.3.6"

    if data["format_version"] != "0.3.6":
        return

    config = data.get("config")
    if isinstance(config, dict):
        # drop 'future' once every version-specific entry has been consumed
        if config.get("future") == {}:
            del config["future"]

        # drop 'config' if it is now empty
        if not config:
            del data["config"]

    data.pop("language", None)
    data.pop("framework", None)

    # 0.4.0 attaches the model architecture to the pytorch_state_dict weights
    architecture = data.pop("source", None)
    architecture_sha256 = data.pop("sha256", None)
    kwargs = data.pop("kwargs", None)

    weights = data.get("weights")
    if isinstance(weights, dict):
        pytorch_state_dict_weights_entry = weights.get("pytorch_state_dict")
        if isinstance(pytorch_state_dict_weights_entry, dict):
            if architecture is not None:
                pytorch_state_dict_weights_entry["architecture"] = architecture

            if architecture_sha256 is not None:
                pytorch_state_dict_weights_entry["architecture_sha256"] = (
                    architecture_sha256
                )

            if kwargs is not None:
                pytorch_state_dict_weights_entry["kwargs"] = kwargs

        # the 'pytorch_script' weights format is called 'torchscript' in 0.4.0
        torchscript_weights_entry = weights.pop("pytorch_script", None)
        if torchscript_weights_entry is not None:
            weights["torchscript"] = torchscript_weights_entry

    data["format_version"] = "0.4.0"

60 

61 

62def _convert_model_v0_3_1_to_v0_3_2(data: Dict[str, Any]) -> Dict[str, Any]: 

63 data["type"] = "model" 

64 data["format_version"] = "0.3.2" 

65 future = data.get("config", {}).get("future", {}).pop("0.3.2", {}) 

66 

67 authors = data.get("authors") 

68 if isinstance(authors, list): 

69 data["authors"] = [{"name": name} for name in authors] 

70 authors_update = future.get("authors") 

71 if authors_update is not None: 

72 for a, u in zip(data["authors"], authors_update): 

73 a.update(u) 

74 

75 # packaged_by 

76 packaged_by = data.get("packaged_by") 

77 if packaged_by is not None: 

78 data["packaged_by"] = [{"name": name} for name in data["packaged_by"]] 

79 packaged_by_update = future.get("packaged_by") 

80 if packaged_by_update is not None: 

81 for a, u in zip(data["packaged_by"], packaged_by_update): 

82 a.update(u) 

83 

84 # authors of weights 

85 weights = data.get("weights") 

86 if isinstance(weights, dict): 

87 for weights_format, weights_entry in weights.items(): 

88 if "authors" not in weights_entry: 

89 continue 

90 

91 weights_entry["authors"] = [ 

92 {"name": name} for name in weights_entry["authors"] 

93 ] 

94 authors_update = ( 

95 future.get("weights", {}).get(weights_format, {}).get("authors") 

96 ) 

97 if authors_update is not None: 

98 for a, u in zip(weights_entry["authors"], authors_update): 

99 a.update(u) 

100 

101 # model version 

102 if "version" in future: 

103 data["version"] = future.pop("version") 

104 

105 return data 

106 

107 

108def _convert_model_v0_3_2_to_v0_3_3(data: Dict[str, Any]) -> Dict[str, Any]: 

109 data["format_version"] = "0.3.3" 

110 if "outputs" in data: 

111 for out in data["outputs"]: 

112 if "shape" in out: 

113 shape = out["shape"] 

114 if isinstance(shape, dict) and "reference_input" in shape: 

115 shape["reference_tensor"] = shape.pop("reference_input") 

116 

117 return data