Coverage for bioimageio/spec/_get_conda_env.py: 77%

133 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-02 14:21 +0000

1from typing import List, Literal, Optional, Union 

2 

3from typing_extensions import assert_never 

4 

5from ._internal.gh_utils import set_github_warning 

6from ._internal.io_utils import ZipPath, read_yaml 

7from .common import RelativeFilePath 

8from .conda_env import BioimageioCondaEnv, PipDeps 

9from .model import v0_4, v0_5 

10from .model.v0_5 import Version 

11from .utils import download 

12 

# Weights entry descriptions for which `get_conda_env` can derive an environment.
# The nested unions are flattened by `typing`, so this is a single flat Union of
# all ten description classes (in this order).
SupportedWeightsEntry = Union[
    # format 0.4 weights descriptions
    Union[
        v0_4.KerasHdf5WeightsDescr,
        v0_4.OnnxWeightsDescr,
        v0_4.PytorchStateDictWeightsDescr,
        v0_4.TensorflowSavedModelBundleWeightsDescr,
        v0_4.TorchscriptWeightsDescr,
    ],
    # format 0.5 weights descriptions
    Union[
        v0_5.KerasHdf5WeightsDescr,
        v0_5.OnnxWeightsDescr,
        v0_5.PytorchStateDictWeightsDescr,
        v0_5.TensorflowSavedModelBundleWeightsDescr,
        v0_5.TorchscriptWeightsDescr,
    ],
]

25 

26 

def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """get the recommended Conda environment for a given weights entry description

    Args:
        entry: A format 0.4 or 0.5 weights entry description. Its class selects
            the framework-specific default environment. Where the entry carries
            `dependencies`, those take precedence over the default.
        env_name: Controls the `name` of the returned environment:
            `None` leaves it as created, `"DROP"` sets it to `None`,
            any other string is assigned as the name.

    Returns:
        The selected Conda environment description.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
            v0_5.TorchscriptWeightsDescr,
        ),
    ):
        # The isinstance test comes first, so `.dependencies` is never read on a
        # v0_5 torchscript entry (presumably that class has no such field).
        if (
            not isinstance(entry, v0_5.TorchscriptWeightsDescr)
            and entry.dependencies is not None
        ):
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    if env_name is not None:
        # the sentinel "DROP" clears the name instead of setting it
        env.name = None if env_name == "DROP" else env_name

    return env

77 

78 

# torchvision/torchaudio releases to install alongside a given pytorch release
# (keys are fully padded "X.Y.Z" pytorch versions), according to
# https://pytorch.org/get-started/previous-versions/
_PYTORCH_COMPANION_PINS = {
    "1.4.0": ("torchvision==0.5.0",),
    "1.5.0": ("torchvision==0.6.0",),
    "1.5.1": ("torchvision==0.6.1",),
    "1.6.0": ("torchvision==0.7.0",),
    "1.7.0": ("torchvision==0.8.0", "torchaudio==0.7.0"),
    "1.7.1": ("torchvision==0.8.2", "torchaudio==0.7.1"),
    "1.8.0": ("torchvision==0.9.0", "torchaudio==0.8.0"),
    "1.8.1": ("torchvision==0.9.1", "torchaudio==0.8.1"),
    "1.9.0": ("torchvision==0.10.0", "torchaudio==0.9.0"),
    "1.9.1": ("torchvision==0.10.1", "torchaudio==0.9.1"),
    "1.10.0": ("torchvision==0.11.0", "torchaudio==0.10.0"),
    "1.10.1": ("torchvision==0.11.2", "torchaudio==0.10.1"),
    "1.11.0": ("torchvision==0.12.0", "torchaudio==0.11.0"),
    "1.12.0": ("torchvision==0.13.0", "torchaudio==0.12.0"),
    "1.12.1": ("torchvision==0.13.1", "torchaudio==0.12.1"),
    "1.13.0": ("torchvision==0.14.0", "torchaudio==0.13.0"),
    "1.13.1": ("torchvision==0.14.1", "torchaudio==0.13.1"),
    "2.0.0": ("torchvision==0.15.0", "torchaudio==2.0.0"),
    "2.0.1": ("torchvision==0.15.2", "torchaudio==2.0.2"),
    "2.1.0": ("torchvision==0.16.0", "torchaudio==2.1.0"),
    "2.1.1": ("torchvision==0.16.1", "torchaudio==2.1.1"),
    "2.1.2": ("torchvision==0.16.2", "torchaudio==2.1.2"),
    "2.2.0": ("torchvision==0.17.0", "torchaudio==2.2.0"),
    "2.2.1": ("torchvision==0.17.1", "torchaudio==2.2.1"),
    "2.2.2": ("torchvision==0.17.2", "torchaudio==2.2.2"),
    "2.3.0": ("torchvision==0.18.0", "torchaudio==2.3.0"),
    "2.3.1": ("torchvision==0.18.1", "torchaudio==2.3.1"),
    "2.4.0": ("torchvision==0.19.0", "torchaudio==2.4.0"),
    "2.4.1": ("torchvision==0.19.1", "torchaudio==2.4.1"),
    "2.5.0": ("torchvision==0.20.0", "torchaudio==2.5.0"),
    "2.5.1": ("torchvision==0.20.1", "torchaudio==2.5.1"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Build the default Conda environment for pytorch state dict/torchscript weights.

    Args:
        pytorch_version: pytorch release to pin; defaults to 1.10.1.

    Returns:
        An environment pinning `pytorch==X.Y.Z`. If the release is listed in
        `_PYTORCH_COMPANION_PINS`, matching `torchvision`/`torchaudio` pins are
        added. Otherwise both are added unpinned and an "UPDATE NEEDED" GitHub
        warning is emitted. Compatibility pins for `mkl`, `setuptools` and
        `numpy` are added depending on the release.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    channels = ["conda-forge", "nodefaults"]
    if pytorch_version < Version("2.5.2"):
        # releases up to 2.5.x are taken from the `pytorch` channel
        # (presumably later ones are no longer published there)
        channels.insert(0, "pytorch")

    # pad to three components ("2" -> "2.0.0", "2.1" -> "2.1.0") so the version
    # matches the keys of the pin table and is pinned exactly
    v = pytorch_version.base_version
    n_dots = v.count(".")
    if n_dots < 2:
        v += ".0" * (2 - n_dots)

    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]
    companions = _PYTORCH_COMPANION_PINS.get(v)
    if companions is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={v}",
        )
        companions = ("torchvision", "torchaudio")

    deps.extend(companions)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    if pytorch_version < Version(
        "2.1.0"  # TODO: check if this is the correct cutoff where the fix is no longer needed
    ):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    if pytorch_version < Version("2.3"):  # TODO: verify the numpy 2 cutoff
        # make sure we have a compatible numpy version
        # see https://github.com/pytorch/vision/issues/8460
        deps.append("numpy <2")
    else:
        deps.append("numpy >=2,<3")

    return BioimageioCondaEnv(channels=channels, dependencies=deps)

184 

185 

def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Build the default Conda environment for ONNX weights.

    Args:
        opset_version: opset version declared by the weights entry. It is
            accepted so callers can pass it, but it does not influence the
            environment. We should not need to worry about the opset version,
            see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md

    Returns:
        An environment with a single, unpinned `onnxruntime` dependency.
    """
    # `opset_version` is intentionally unused (see docstring); the former
    # defaulting to 15 was a dead assignment and has been removed.
    return BioimageioCondaEnv(dependencies=["onnxruntime"])

193 

194 

def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Build the default Conda environment for tensorflow / keras weights.

    Args:
        tensorflow_version: tensorflow release to pin. `None` and any release
            with major version < 2 are replaced by 2.17.

    Returns:
        An environment depending on `bioimageio.core` and `tensorflow ==<version>`.
    """
    # missing or tensorflow 1.x versions fall back to a 2.x release
    use_fallback = tensorflow_version is None or tensorflow_version.major < 2
    pinned = Version("2.17") if use_fallback else tensorflow_version

    tf_requirement = f"tensorflow =={pinned}"
    return BioimageioCondaEnv(dependencies=["bioimageio.core", tf_requirement])

202 

203 

def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, v0_5.EnvironmentFileDescr],
) -> BioimageioCondaEnv:
    """Build a Conda environment from a weights entry's own dependency file.

    Args:
        deps: Either format 0.4 `Dependencies`, which may point to a pip
            requirements file or a conda/mamba environment file, or a
            format 0.5 environment file description.

    Returns:
        For pip requirements, an environment whose only dependency is a pip
        section listing the file's requirements. `bioimageio.core` is appended
        unless a line reading exactly `bioimageio.core` is present. Otherwise,
        the environment is parsed from the referenced YAML file.

    Raises:
        ValueError: if a format 0.4 dependency manager is not pip, conda or mamba.
    """
    if isinstance(deps, v0_4.Dependencies):
        if deps.manager == "pip":
            import re

            pip_deps_str = download(deps.file).path.read_text(encoding="utf-8")
            assert isinstance(pip_deps_str, str)
            # Blank lines and comments are not requirements and must not end up
            # as pip entries. The pattern mirrors pip's own requirements-file
            # comment handling (`#` at line start or preceded by whitespace), so
            # URL fragments such as `...#egg=name` are preserved.
            comment_re = re.compile(r"(^|\s+)#.*$")
            stripped = (
                comment_re.sub("", line).strip() for line in pip_deps_str.splitlines()
            )
            pip_deps = [d for d in stripped if d]
            if "bioimageio.core" not in pip_deps:
                pip_deps.append("bioimageio.core")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager not in ("conda", "mamba"):
            raise ValueError(f"Dependency manager {deps.manager} not supported")
        else:
            deps_source = (
                deps.file.absolute()
                if isinstance(deps.file, RelativeFilePath)
                else deps.file
            )
            # files inside a zip package are read directly; anything else is
            # downloaded first
            if isinstance(deps_source, ZipPath):
                local = deps_source
            else:
                local = download(deps_source).path

            return BioimageioCondaEnv.model_validate(read_yaml(local))
    elif isinstance(deps, v0_5.EnvironmentFileDescr):
        local = download(deps.source).path
        return BioimageioCondaEnv.model_validate(read_yaml(local))
    else:
        assert_never(deps)