Coverage for src / bioimageio / spec / _get_conda_env.py: 55%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-31 13:09 +0000

1from typing import List, Literal, Optional, Tuple, Union 

2 

3from typing_extensions import assert_never 

4 

5from ._internal.gh_utils import set_github_warning 

6from ._internal.io import FileDescr, get_reader 

7from ._internal.io_utils import read_yaml 

8from .conda_env import BioimageioCondaEnv, PipDeps 

9from .model import v0_4, v0_5 

10from .model.v0_5 import Version 

11 

# Weights entry descriptions (model spec v0.4 and v0.5) that `get_conda_env`
# can derive a conda environment for; `get_conda_env` dispatches on exactly
# these types and treats anything else as unreachable (`assert_never`).
SupportedWeightsEntry = Union[
    # model spec v0.4
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    # model spec v0.5
    v0_5.KerasHdf5WeightsDescr,
    v0_5.KerasV3WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]

25 

26 

def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """Get the recommended Conda environment for a given weights entry description.

    Args:
        entry: The weights entry description (model spec v0.4 or v0.5).
        env_name: Controls the ``name`` of the returned environment.
            ``None`` keeps the name the environment was created with,
            ``"DROP"`` sets it to ``None``, and any other string replaces it.

    Returns:
        For pytorch state dict, v0.4 torchscript and tensorflow SavedModel
        entries that declare ``dependencies``, the environment described by
        those dependencies; otherwise a default environment for the entry's
        framework (onnx, pytorch, tensorflow or Keras 3 backend).
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
            v0_5.TorchscriptWeightsDescr,
        ),
    ):
        # The v0.5 torchscript check comes first so `dependencies` is never
        # read from it (presumably that description has no such field).
        if not isinstance(entry, v0_5.TorchscriptWeightsDescr) and (
            entry.dependencies is not None
        ):
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        # Keras HDF5 weights always get the default tensorflow environment.
        env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, v0_5.KerasV3WeightsDescr):
        env = _get_default_keras3_env(entry.backend)
    else:
        assert_never(entry)

    if env_name is not None:
        env.name = None if env_name == "DROP" else env_name

    return env

79 

80 

def _get_default_keras3_env(
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version],
) -> BioimageioCondaEnv:
    """Build the default environment for Keras 3 weights.

    Args:
        backend: ``(framework, version)`` of the Keras 3 backend.

    Returns:
        The default environment of the backend framework (pytorch, tensorflow,
        or a plain ``jax==<version>`` environment) with
        ``keras >=3.0, <4`` appended to its dependencies.
    """
    framework = backend[0]
    framework_version = backend[1]

    if framework == "torch":
        env = _get_default_pytorch_env(pytorch_version=framework_version)
    elif framework == "tensorflow":
        env = _get_default_tf_env(tensorflow_version=framework_version)
    elif framework == "jax":
        env = BioimageioCondaEnv(dependencies=[f"jax=={framework_version}"])
    else:
        assert_never(framework)

    env.dependencies.append("keras >=3.0, <4")
    return env

99 

100 

# torchvision release to install alongside each pytorch release, keyed by the
# pytorch version padded to "major.minor.patch", according to
# https://pytorch.org/get-started/previous-versions/
_TORCHVISION_FOR_PYTORCH = {
    "1.5.1": "torchvision==0.6.1",
    "1.6.0": "torchvision==0.7.0",
    "1.7.0": "torchvision==0.8.0",
    "1.7.1": "torchvision==0.8.2",
    "1.8.0": "torchvision==0.9.0",
    "1.8.1": "torchvision==0.9.1",
    "1.9.0": "torchvision==0.10.0",
    "1.9.1": "torchvision==0.10.1",
    "1.10.0": "torchvision==0.11.0",
    "1.10.1": "torchvision==0.11.2",
    "1.10.2": "torchvision==0.11.3",
    "1.11.0": "torchvision==0.12.0",
    "1.12.0": "torchvision==0.13.0",
    "1.12.1": "torchvision==0.13.1",
    "1.13.0": "torchvision==0.14.0",
    "1.13.1": "torchvision==0.14.1",
    "2.0.0": "torchvision==0.15.0",
    "2.0.1": "torchvision==0.15.2",
    "2.1.0": "torchvision==0.16.0",
    "2.1.1": "torchvision==0.16.1",
    "2.1.2": "torchvision==0.16.2",
    "2.2.0": "torchvision==0.17.0",
    "2.2.1": "torchvision==0.17.1",
    "2.2.2": "torchvision==0.17.2",
    "2.3.0": "torchvision==0.18.0",
    "2.3.1": "torchvision==0.18.1",
    "2.4.0": "torchvision==0.19.0",
    "2.4.1": "torchvision==0.19.1",
    "2.5.0": "torchvision==0.20.0",
    "2.5.1": "torchvision==0.20.1",
    "2.6.0": "torchvision==0.21.0",
    "2.7.0": "torchvision==0.22.0",
    "2.7.1": "torchvision==0.22.1",
    "2.8.0": "torchvision==0.23.0",
    "2.9.0": "torchvision==0.24.0",
    "2.9.1": "torchvision==0.24.1",
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Build the default conda environment for pytorch based weights.

    Args:
        pytorch_version: pytorch version to pin; ``1.10.1`` is used if ``None``.

    Returns:
        An environment on the ``conda-forge``/``nodefaults`` channels pinning
        ``pytorch`` and a matching ``torchvision`` (unpinned, with an
        "UPDATE NEEDED" warning via ``set_github_warning``, if the pytorch
        version is not in ``_TORCHVISION_FOR_PYTORCH``), plus workaround pins
        for older pytorch releases.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    channels = ["conda-forge", "nodefaults"]

    # pad the base version to "major.minor.patch" to match the table keys
    # ("2" -> "2.0.0", "2.1" -> "2.1.0"; longer versions are left as they are)
    v = pytorch_version.base_version
    v += ".0" * (2 - v.count("."))

    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]

    torchvision = _TORCHVISION_FOR_PYTORCH.get(v)
    if torchvision is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision unpinned for pytorch=={v}",
        )
        torchvision = "torchvision"

    deps.append(torchvision)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    if (
        pytorch_version
        < Version(
            "2.1.0"  # TODO: check if this is the correct cutoff where the fix is not longer needed
        )
    ):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    if pytorch_version < Version("2.3"):
        # see https://github.com/pytorch/pytorch/issues/107302
        deps.append("numpy <2")

    return BioimageioCondaEnv(channels=channels, dependencies=deps)

185 

186 

def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Build the default conda environment for onnx weights.

    Args:
        opset_version: ONNX opset version of the weights. It is accepted for
            interface symmetry but intentionally not used (see note below).

    Returns:
        An environment whose only dependency is an unpinned ``onnxruntime``.
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    return BioimageioCondaEnv(dependencies=["onnxruntime"])

194 

195 

def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Build the default conda environment for tensorflow based weights.

    Args:
        tensorflow_version: Requested tensorflow version. ``None`` and any
            1.x version fall back to tensorflow 2.17.

    Returns:
        An environment with a single ``tensorflow ==<version>`` dependency.
    """
    if tensorflow_version is not None and tensorflow_version.major >= 2:
        pinned_version = tensorflow_version
    else:
        pinned_version = Version("2.17")

    return BioimageioCondaEnv(dependencies=[f"tensorflow =={pinned_version}"])

203 

204 

def _requirement_name(requirement: str) -> Optional[str]:
    """Return the normalized (PEP 503) project name a pip requirement line starts with.

    Returns ``None`` if the line does not start with a project name,
    e.g. pip options such as ``-r other.txt`` or ``--extra-index-url ...``.
    """
    import re

    match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", requirement)
    if match is None:
        return None

    # PEP 503: runs of "-", "_" and "." are equivalent, comparison is case-insensitive
    return re.sub(r"[-_.]+", "-", match.group(1)).lower()


def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, FileDescr],
) -> BioimageioCondaEnv:
    """Build a conda environment from a weights entry's declared dependencies.

    Args:
        deps: Either a v0.4 ``Dependencies`` description (``manager`` plus
            ``file``) or a v0.5 file description of a conda environment file.

    Returns:
        For pip requirements: an environment with a single pip section holding
        the requirement lines, with ``bioimageio.core>=0.9.4`` appended unless
        the file already requires ``bioimageio.core`` (in any spelling or
        version). For conda/mamba files and v0.5 file descriptions: the parsed
        YAML, validated as ``BioimageioCondaEnv``.

    Raises:
        ValueError: If a v0.4 dependency manager other than pip, conda or mamba
            is specified.
    """
    if isinstance(deps, v0_4.Dependencies):
        deps_reader = get_reader(deps.file)
        if deps.manager == "pip":
            pip_deps_str = deps_reader.read_text()
            pip_deps = [line.strip() for line in pip_deps_str.splitlines()]
            # blank lines and full-line comments carry no requirement
            pip_deps = [d for d in pip_deps if d and not d.startswith("#")]

            # compare project names, not whole lines: a pinned or differently
            # spelled requirement (e.g. "bioimageio.core==0.7.0") must not get
            # a second, possibly conflicting, bioimageio.core requirement
            if not any(_requirement_name(d) == "bioimageio-core" for d in pip_deps):
                pip_deps.append("bioimageio.core>=0.9.4")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager in ("conda", "mamba"):
            return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
        else:
            raise ValueError(f"Dependency manager {deps.manager} not supported")

    elif isinstance(deps, FileDescr):
        deps_reader = deps.get_reader()
        return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
    else:
        assert_never(deps)