Coverage for bioimageio/spec/_get_conda_env.py: 79%

128 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 09:20 +0000

1from typing import List, Literal, Optional, Union 

2 

3from typing_extensions import assert_never 

4 

5from ._internal.gh_utils import set_github_warning 

6from ._internal.io import FileDescr, get_reader 

7from ._internal.io_utils import read_yaml 

8from .conda_env import BioimageioCondaEnv, PipDeps 

9from .model import v0_4, v0_5 

10from .model.v0_5 import Version 

11 

# Weights entry descriptions (spec formats v0_4 and v0_5) accepted by
# `get_conda_env`; each member is handled by one of its isinstance branches.
SupportedWeightsEntry = Union[
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    v0_5.KerasHdf5WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]

24 

25 

def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """Get the recommended Conda environment for a given weights entry description.

    Args:
        entry: The weights entry to derive the environment for.
        env_name: Controls the ``name`` of the returned environment:
            ``None`` keeps the name the environment was created with,
            ``"DROP"`` sets the name to ``None``, and any other string
            is used as the new name.

    Returns:
        The environment read from the entry's dependency file if one is
        given (for entries that can carry one), otherwise a default
        environment for the entry's framework.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        conda_env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
            v0_5.TorchscriptWeightsDescr,
        ),
    ):
        # v0_5 torchscript entries are never checked for a dependency file
        # (presumably they cannot declare one) and always get the default env.
        torch_deps = (
            None
            if isinstance(entry, v0_5.TorchscriptWeightsDescr)
            else entry.dependencies
        )
        conda_env = (
            _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
            if torch_deps is None
            else _get_env_from_deps(torch_deps)
        )
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        tf_deps = entry.dependencies
        conda_env = (
            _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
            if tf_deps is None
            else _get_env_from_deps(tf_deps)
        )
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    # "DROP" clears the name; any other non-None value replaces it.
    if env_name is not None:
        conda_env.name = None if env_name == "DROP" else env_name

    return conda_env

76 

77 

# torchvision/torchaudio pins to install alongside a given `pytorch` version,
# following the conda instructions at
# https://pytorch.org/get-started/previous-versions/
# Keys are full MAJOR.MINOR.PATCH base versions.
_TORCH_COMPANION_PINS = {
    "1.5.1": ("torchvision==0.6.1",),
    "1.6.0": ("torchvision==0.7.0",),
    "1.7.0": ("torchvision==0.8.0", "torchaudio==0.7.0"),
    # torch 1.7.1 pairs with torchaudio 0.7.2 (torchaudio has no 0.7.1 release)
    "1.7.1": ("torchvision==0.8.2", "torchaudio==0.7.2"),
    "1.8.0": ("torchvision==0.9.0", "torchaudio==0.8.0"),
    "1.8.1": ("torchvision==0.9.1", "torchaudio==0.8.1"),
    "1.9.0": ("torchvision==0.10.0", "torchaudio==0.9.0"),
    "1.9.1": ("torchvision==0.10.1", "torchaudio==0.9.1"),
    "1.10.0": ("torchvision==0.11.0", "torchaudio==0.10.0"),
    "1.10.1": ("torchvision==0.11.2", "torchaudio==0.10.1"),
    "1.11.0": ("torchvision==0.12.0", "torchaudio==0.11.0"),
    "1.12.0": ("torchvision==0.13.0", "torchaudio==0.12.0"),
    "1.12.1": ("torchvision==0.13.1", "torchaudio==0.12.1"),
    "1.13.0": ("torchvision==0.14.0", "torchaudio==0.13.0"),
    "1.13.1": ("torchvision==0.14.1", "torchaudio==0.13.1"),
    "2.0.0": ("torchvision==0.15.0", "torchaudio==2.0.0"),
    "2.0.1": ("torchvision==0.15.2", "torchaudio==2.0.2"),
    "2.1.0": ("torchvision==0.16.0", "torchaudio==2.1.0"),
    "2.1.1": ("torchvision==0.16.1", "torchaudio==2.1.1"),
    "2.1.2": ("torchvision==0.16.2", "torchaudio==2.1.2"),
    "2.2.0": ("torchvision==0.17.0", "torchaudio==2.2.0"),
    "2.2.1": ("torchvision==0.17.1", "torchaudio==2.2.1"),
    "2.2.2": ("torchvision==0.17.2", "torchaudio==2.2.2"),
    "2.3.0": ("torchvision==0.18.0", "torchaudio==2.3.0"),
    "2.3.1": ("torchvision==0.18.1", "torchaudio==2.3.1"),
    "2.4.0": ("torchvision==0.19.0", "torchaudio==2.4.0"),
    "2.4.1": ("torchvision==0.19.1", "torchaudio==2.4.1"),
    "2.5.0": ("torchvision==0.20.0", "torchaudio==2.5.0"),
    "2.5.1": ("torchvision==0.20.1", "torchaudio==2.5.1"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Return a default conda environment for PyTorch-based weights.

    Args:
        pytorch_version: The PyTorch version to pin; defaults to 1.10.1.

    Returns:
        An environment pinning ``pytorch`` to the base version (padded to
        MAJOR.MINOR.PATCH), with matching torchvision/torchaudio pins when the
        version is listed in ``_TORCH_COMPANION_PINS`` (otherwise they are left
        unpinned and a GitHub warning is emitted), plus compatibility pins for
        mkl, setuptools and numpy depending on the version.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    channels = ["conda-forge", "nodefaults"]
    if pytorch_version < Version("2.5.2"):
        # versions up to 2.5.1 additionally get (and prioritize) the `pytorch` channel
        channels.insert(0, "pytorch")

    # pad to MAJOR.MINOR.PATCH so e.g. "2" or "2.1" match the table keys
    v = pytorch_version.base_version
    v += ".0" * max(0, 2 - v.count("."))

    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]
    companion_pins = _TORCH_COMPANION_PINS.get(v)
    if companion_pins is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={v}",
        )
        deps += ["torchvision", "torchaudio"]
    else:
        deps.extend(companion_pins)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    if pytorch_version < Version(
        "2.1.0"  # TODO: check if this is the correct cutoff where the fix is not longer needed
    ):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    if pytorch_version < Version(
        "2.3"
    ):  # TODO: verify that future pytorch 2.4 is numpy 2.0 compatible
        # make sure we have a compatible numpy version
        # see https://github.com/pytorch/vision/issues/8460
        deps.append("numpy <2")
    else:
        deps.append("numpy >=2,<3")

    return BioimageioCondaEnv(channels=channels, dependencies=deps)

183 

184 

def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Return the default conda environment for ONNX weights.

    Args:
        opset_version: The ONNX opset version of the weights. It is accepted
            for interface symmetry but does not influence the result.

    Returns:
        An environment with ``onnxruntime`` as its dependency.
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    return BioimageioCondaEnv(dependencies=["onnxruntime"])

192 

193 

def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Return a default conda environment for TensorFlow-based weights.

    Args:
        tensorflow_version: Requested TensorFlow version. ``None`` or any
            1.x version is replaced by TensorFlow 2.17.

    Returns:
        An environment depending on ``bioimageio.core`` and an exact
        ``tensorflow`` pin.
    """
    pinned_tf = (
        tensorflow_version
        if tensorflow_version is not None and tensorflow_version.major >= 2
        else Version("2.17")
    )
    tf_spec = f"tensorflow =={pinned_tf}"
    return BioimageioCondaEnv(dependencies=["bioimageio.core", tf_spec])

201 

202 

def _parse_pip_requirements(text: str) -> List[str]:
    """Return the requirement lines of a pip requirements file's text.

    Blank lines and comments are dropped. Following the pip requirements-file
    format, a line starting with ``#`` is a comment, and whitespace followed by
    ``#`` starts an inline comment.
    """
    import re

    requirements: List[str] = []
    for line in text.splitlines():
        requirement = re.split(r"\s+#", line, maxsplit=1)[0].strip()
        if requirement and not requirement.startswith("#"):
            requirements.append(requirement)

    return requirements


def _normalized_requirement_name(requirement: str) -> str:
    """Return the PEP 503-normalized project name a requirement line starts with.

    Returns an empty string for lines that do not start with a project name,
    e.g. pip options such as ``-e ...`` or ``--index-url ...``.
    """
    import re

    match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", requirement)
    if match is None:
        return ""

    return re.sub(r"[-_.]+", "-", match.group()).lower()


def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, FileDescr],
) -> BioimageioCondaEnv:
    """Build a conda environment from a weights entry's dependency file.

    Args:
        deps: Either a v0_4 ``Dependencies`` description (a pip requirements
            file or a conda/mamba environment file), or a ``FileDescr``
            pointing to a conda environment YAML file.

    Returns:
        For pip requirements: an environment with those requirements as pip
        dependencies, with ``bioimageio.core`` added unless a requirement for it
        (under any PEP 503-equivalent spelling) is already listed.
        Otherwise: the environment parsed from the YAML file.

    Raises:
        ValueError: If a v0_4 dependency manager other than pip, conda or
            mamba is given.
    """
    if isinstance(deps, v0_4.Dependencies):
        deps_reader = get_reader(deps.file)
        if deps.manager == "pip":
            pip_deps = _parse_pip_requirements(deps_reader.read_text())
            # a pinned `bioimageio.core>=...` (or `bioimageio-core`) counts as present
            if all(
                _normalized_requirement_name(d) != "bioimageio-core" for d in pip_deps
            ):
                pip_deps.append("bioimageio.core")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager in ("conda", "mamba"):
            return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
        else:
            raise ValueError(f"Dependency manager {deps.manager} not supported")

    elif isinstance(deps, FileDescr):
        deps_reader = deps.get_reader()
        return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
    else:
        assert_never(deps)