Coverage for bioimageio/spec/_get_conda_env.py: 12%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-02-05 13:53 +0000

1from typing import List, Literal, Optional, Union 

2 

3from typing_extensions import assert_never 

4 

5from ._internal.gh_utils import set_github_warning 

6from ._internal.io_utils import ZipPath, read_yaml 

7from .common import RelativeFilePath 

8from .conda_env import BioimageioCondaEnv, PipDeps 

9from .model import v0_4, v0_5 

10from .model.v0_5 import Version 

11from .utils import download 

12 

# All weights entry descriptions (spec formats 0.4 and 0.5) that `get_conda_env`
# accepts and can derive a recommended Conda environment for.
# NOTE(review): member order is left untouched on purpose; if this alias is ever
# used as a pydantic field type, reordering Union members may change how values
# are resolved.
SupportedWeightsEntry = Union[
    # spec format 0.4 weights entries
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    # spec format 0.5 weights entries
    v0_5.KerasHdf5WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]

25 

26 

def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """get the recommended Conda environment for a given weights entry description

    Args:
        entry: weights entry description (spec format 0.4 or 0.5).
        env_name: name to give the returned environment. ``"DROP"`` clears the
            name; ``None`` keeps the name of the derived environment unchanged.

    Returns:
        The recommended Conda environment for ``entry``.
    """
    onnx_types = (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)
    torch_types = (
        v0_4.PytorchStateDictWeightsDescr,
        v0_5.PytorchStateDictWeightsDescr,
        v0_4.TorchscriptWeightsDescr,
        v0_5.TorchscriptWeightsDescr,
    )
    tf_bundle_types = (
        v0_4.TensorflowSavedModelBundleWeightsDescr,
        v0_5.TensorflowSavedModelBundleWeightsDescr,
    )
    keras_types = (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)

    if isinstance(entry, onnx_types):
        env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(entry, torch_types):
        # v0_5 torchscript entries are never inspected for `dependencies`
        # (presumably they have no such field); the isinstance test short-circuits
        if (
            not isinstance(entry, v0_5.TorchscriptWeightsDescr)
            and entry.dependencies is not None
        ):
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(entry, tf_bundle_types):
        if entry.dependencies is not None:
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, keras_types):
        env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    if env_name is None:
        return env

    # "DROP" is a sentinel meaning "remove the environment name"
    env.name = None if env_name == "DROP" else env_name
    return env

77 

78 

# torchvision/torchaudio pins matching each pytorch release, according to
# https://pytorch.org/get-started/previous-versions/
# Keys are pytorch versions exactly as produced by `str(pytorch_version)`.
_PYTORCH_COMPANION_PINS = {
    "1.5.1": ("torchvision==0.6.1",),
    "1.6.0": ("torchvision==0.7.0",),
    "1.7.0": ("torchvision==0.8.0", "torchaudio==0.7.0"),
    "1.7.1": ("torchvision==0.8.2", "torchaudio==0.7.1"),
    "1.8.0": ("torchvision==0.9.0", "torchaudio==0.8.0"),
    "1.8.1": ("torchvision==0.9.1", "torchaudio==0.8.1"),
    "1.9.0": ("torchvision==0.10.0", "torchaudio==0.9.0"),
    "1.9.1": ("torchvision==0.10.1", "torchaudio==0.9.1"),
    "1.10.0": ("torchvision==0.11.0", "torchaudio==0.10.0"),
    "1.10.1": ("torchvision==0.11.2", "torchaudio==0.10.1"),
    "1.11.0": ("torchvision==0.12.0", "torchaudio==0.11.0"),
    "1.12.0": ("torchvision==0.13.0", "torchaudio==0.12.0"),
    "1.12.1": ("torchvision==0.13.1", "torchaudio==0.12.1"),
    "1.13.0": ("torchvision==0.14.0", "torchaudio==0.13.0"),
    "1.13.1": ("torchvision==0.14.1", "torchaudio==0.13.1"),
    "2.0.0": ("torchvision==0.15.0", "torchaudio==2.0.0"),
    "2.0.1": ("torchvision==0.15.2", "torchaudio==2.0.2"),
    "2.1.0": ("torchvision==0.16.0", "torchaudio==2.1.0"),
    "2.1.1": ("torchvision==0.16.1", "torchaudio==2.1.1"),
    "2.1.2": ("torchvision==0.16.2", "torchaudio==2.1.2"),
    "2.2.0": ("torchvision==0.17.0", "torchaudio==2.2.0"),
    "2.2.1": ("torchvision==0.17.1", "torchaudio==2.2.1"),
    "2.2.2": ("torchvision==0.17.2", "torchaudio==2.2.2"),
    "2.3.0": ("torchvision==0.18.0", "torchaudio==2.3.0"),
    "2.3.1": ("torchvision==0.18.1", "torchaudio==2.3.1"),
    "2.4.0": ("torchvision==0.19.0", "torchaudio==2.4.0"),
    "2.4.1": ("torchvision==0.19.1", "torchaudio==2.4.1"),
    "2.5.0": ("torchvision==0.20.0", "torchaudio==2.5.0"),
    "2.5.1": ("torchvision==0.20.1", "torchaudio==2.5.1"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Return the default Conda environment for pytorch-based weights.

    Args:
        pytorch_version: pytorch version to pin; ``None`` means 1.10.1.

    Returns:
        An environment on the ``pytorch`` and ``conda-forge`` channels. It pins
        pytorch, and pins torchvision/torchaudio when the version is in
        ``_PYTORCH_COMPANION_PINS``; otherwise they are left unpinned and a
        GitHub warning is emitted. Version-dependent mkl, setuptools and numpy
        constraints are added as well.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    v = str(pytorch_version)
    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]

    companion_pins = _PYTORCH_COMPANION_PINS.get(v)
    if companion_pins is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={v}",
        )
        deps.extend(["torchvision", "torchaudio"])
    else:
        deps.extend(companion_pins)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    # TODO: check if this is the correct cutoff where the fix is not longer needed
    if pytorch_version < Version("2.1.0"):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    # make sure we have a compatible numpy version
    # see https://github.com/pytorch/vision/issues/8460
    # TODO: verify the cutoff; the original note asked whether pytorch 2.4 (not 2.3)
    # is the first numpy 2.0 compatible release
    if pytorch_version < Version("2.3"):
        deps.append("numpy <2")
    else:
        deps.append("numpy >=2,<3")

    return BioimageioCondaEnv(
        channels=["pytorch", "conda-forge", "nodefaults"],
        dependencies=deps,
    )

178 

179 

def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Return the default Conda environment for ONNX weights.

    Args:
        opset_version: opset version declared by the weights entry. It is
            intentionally ignored and kept only for interface compatibility;
            onnxruntime is left unpinned.

    Returns:
        An environment whose only dependency is ``onnxruntime``.
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    return BioimageioCondaEnv(dependencies=["onnxruntime"])

187 

188 

def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Return the default Conda environment for tensorflow/keras weights.

    Args:
        tensorflow_version: tensorflow version to pin; ``None`` means 1.15.

    Returns:
        For tensorflow 1.x, a python 3.7 environment that installs tensorflow
        (at least 1.13), ``bioimageio.core`` and ``protobuf <4.0`` via pip.
        Otherwise, an environment with ``bioimageio.core`` and the pinned
        tensorflow as plain Conda dependencies.
    """
    tf_version = Version("1.15") if tensorflow_version is None else tensorflow_version

    if tf_version.major != 1:
        return BioimageioCondaEnv(
            dependencies=["bioimageio.core", f"tensorflow =={tf_version}"],
        )

    # tensorflow 1 is not available on conda, so we need to inject this as a pip
    # dependency; tf <1.13 is not available anymore, so older versions are bumped
    pinned_version = max(tf_version, Version("1.13"))
    pip_requirements = [
        # get bioimageio.core (and its dependencies) via pip as well to avoid conda/pip mix
        "bioimageio.core",
        f"tensorflow =={pinned_version}",
        # protobuf pin: tf 1 does not pin an upper limit for protobuf, but fails to
        # load models saved with protobuf 3 when installing protobuf 4.
        "protobuf <4.0",
    ]
    return BioimageioCondaEnv(
        dependencies=[
            "pip",
            "python=3.7.*",  # tf 1.15 not available for py>=3.8
            PipDeps(pip=pip_requirements),
        ],
    )

216 

217 

def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, v0_5.EnvironmentFileDescr],
) -> BioimageioCondaEnv:
    """Build a Conda environment from a weights entry's dependency file.

    Args:
        deps: either v0_4 ``Dependencies`` (a pip requirements file, or a
            conda/mamba environment file) or a v0_5 ``EnvironmentFileDescr``
            (a conda environment file).

    Returns:
        For pip requirements, an environment with a single pip section that
        holds the non-blank, non-comment requirement lines, plus
        ``bioimageio.core`` if it is not listed verbatim. Otherwise, the
        environment parsed from the YAML file.

    Raises:
        ValueError: if a v0_4 dependency manager other than pip, conda or
            mamba is given.
    """
    if isinstance(deps, v0_4.Dependencies):
        # validate the manager before touching the file (same order as before)
        if deps.manager not in ("pip", "conda", "mamba"):
            raise ValueError(f"Dependency manager {deps.manager} not supported")

        # resolve the file the same way for pip and conda/mamba: relative paths
        # are made absolute, files inside a zip package are read in place
        deps_source = (
            deps.file.absolute()
            if isinstance(deps.file, RelativeFilePath)
            else deps.file
        )
        if isinstance(deps_source, ZipPath):
            local = deps_source
        else:
            local = download(deps_source).path

        if deps.manager == "pip":
            return _get_env_from_pip_requirements(local.read_text(encoding="utf-8"))

        return BioimageioCondaEnv.model_validate(read_yaml(local))
    elif isinstance(deps, v0_5.EnvironmentFileDescr):
        local = download(deps.source).path
        return BioimageioCondaEnv.model_validate(read_yaml(local))
    else:
        assert_never(deps)


def _get_env_from_pip_requirements(requirements: str) -> BioimageioCondaEnv:
    """Wrap the lines of a pip requirements file in a Conda env's pip section."""
    stripped = (line.strip() for line in requirements.splitlines())
    # drop blank lines (e.g. from a trailing newline) and full-line comments,
    # which would otherwise end up as bogus pip entries
    pip_deps = [d for d in stripped if d and not d.startswith("#")]
    if "bioimageio.core" not in pip_deps:
        pip_deps.append("bioimageio.core")

    return BioimageioCondaEnv(
        dependencies=[PipDeps(pip=pip_deps)],
    )