Coverage for bioimageio/spec/_get_conda_env.py: 12%
130 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-05 13:53 +0000
1from typing import List, Literal, Optional, Union
3from typing_extensions import assert_never
5from ._internal.gh_utils import set_github_warning
6from ._internal.io_utils import ZipPath, read_yaml
7from .common import RelativeFilePath
8from .conda_env import BioimageioCondaEnv, PipDeps
9from .model import v0_4, v0_5
10from .model.v0_5 import Version
11from .utils import download
# Weights entry descriptions (bioimageio model spec v0.4 and v0.5) accepted by
# `get_conda_env`, which dispatches on the concrete type of the entry.
SupportedWeightsEntry = Union[
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    v0_5.KerasHdf5WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]
def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """Derive the recommended Conda environment for a weights entry description.

    Explicitly declared ``dependencies`` are used for PyTorch state dict,
    v0.4 TorchScript and TensorFlow SavedModel entries. In all other cases a
    default environment for the entry's framework (and framework version,
    if given) is built.

    Args:
        entry: weights entry description (model spec v0.4 or v0.5).
        env_name: ``None`` leaves the environment's name untouched,
            ``"DROP"`` resets it to ``None``, any other string replaces it.

    Returns:
        The Conda environment description.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
            v0_5.TorchscriptWeightsDescr,
        ),
    ):
        # v0.5 TorchScript entries always get the default environment
        # (presumably they carry no `dependencies` field -- confirm in v0_5).
        if (
            not isinstance(entry, v0_5.TorchscriptWeightsDescr)
            and entry.dependencies is not None
        ):
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    if env_name is not None:
        # the sentinel "DROP" clears the name instead of setting it
        env.name = None if env_name == "DROP" else env_name

    return env
# torchvision/torchaudio releases matching a given pytorch release, according to
# https://pytorch.org/get-started/previous-versions/
# Keys are matched against the exact string form of the requested pytorch version.
_PYTORCH_COMPANION_PINS = {
    "1.5.1": ("torchvision==0.6.1",),
    "1.6.0": ("torchvision==0.7.0",),
    "1.7.0": ("torchvision==0.8.0", "torchaudio==0.7.0"),
    "1.7.1": ("torchvision==0.8.2", "torchaudio==0.7.1"),
    "1.8.0": ("torchvision==0.9.0", "torchaudio==0.8.0"),
    "1.8.1": ("torchvision==0.9.1", "torchaudio==0.8.1"),
    "1.9.0": ("torchvision==0.10.0", "torchaudio==0.9.0"),
    "1.9.1": ("torchvision==0.10.1", "torchaudio==0.9.1"),
    "1.10.0": ("torchvision==0.11.0", "torchaudio==0.10.0"),
    "1.10.1": ("torchvision==0.11.2", "torchaudio==0.10.1"),
    "1.11.0": ("torchvision==0.12.0", "torchaudio==0.11.0"),
    "1.12.0": ("torchvision==0.13.0", "torchaudio==0.12.0"),
    "1.12.1": ("torchvision==0.13.1", "torchaudio==0.12.1"),
    "1.13.0": ("torchvision==0.14.0", "torchaudio==0.13.0"),
    "1.13.1": ("torchvision==0.14.1", "torchaudio==0.13.1"),
    "2.0.0": ("torchvision==0.15.0", "torchaudio==2.0.0"),
    "2.0.1": ("torchvision==0.15.2", "torchaudio==2.0.2"),
    "2.1.0": ("torchvision==0.16.0", "torchaudio==2.1.0"),
    "2.1.1": ("torchvision==0.16.1", "torchaudio==2.1.1"),
    "2.1.2": ("torchvision==0.16.2", "torchaudio==2.1.2"),
    "2.2.0": ("torchvision==0.17.0", "torchaudio==2.2.0"),
    "2.2.1": ("torchvision==0.17.1", "torchaudio==2.2.1"),
    "2.2.2": ("torchvision==0.17.2", "torchaudio==2.2.2"),
    "2.3.0": ("torchvision==0.18.0", "torchaudio==2.3.0"),
    "2.3.1": ("torchvision==0.18.1", "torchaudio==2.3.1"),
    "2.4.0": ("torchvision==0.19.0", "torchaudio==2.4.0"),
    "2.4.1": ("torchvision==0.19.1", "torchaudio==2.4.1"),
    "2.5.0": ("torchvision==0.20.0", "torchaudio==2.5.0"),
    "2.5.1": ("torchvision==0.20.1", "torchaudio==2.5.1"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Build the default Conda environment for PyTorch (state dict / TorchScript) weights.

    Args:
        pytorch_version: pytorch version to pin; defaults to 1.10.1.

    Returns:
        An environment on the ``pytorch``, ``conda-forge`` and ``nodefaults``
        channels. It contains the pinned pytorch, the matching torchvision and
        torchaudio (left unpinned, with a GitHub warning, for versions missing
        from the pin table), and version-dependent compatibility pins for mkl,
        setuptools and numpy.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    v = str(pytorch_version)
    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]
    companions = _PYTORCH_COMPANION_PINS.get(v)
    if companions is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={v}",
        )
        deps += ["torchvision", "torchaudio"]
    else:
        deps.extend(companions)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    # TODO: check if this is the correct cutoff where the fix is no longer needed
    if pytorch_version < Version("2.1.0"):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    # make sure we have a compatible numpy version
    # see https://github.com/pytorch/vision/issues/8460
    # TODO: verify the 2.3 cutoff for numpy 2 compatibility
    deps.append("numpy <2" if pytorch_version < Version("2.3") else "numpy >=2,<3")

    return BioimageioCondaEnv(
        channels=["pytorch", "conda-forge", "nodefaults"],
        dependencies=deps,
    )
def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Build the default Conda environment for ONNX weights.

    Args:
        opset_version: ONNX opset version of the weights. It does not influence
            the environment (see note below) and is accepted only so that all
            default-env builders take their framework's version information.

    Returns:
        An environment with ``onnxruntime`` as its only dependency.
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    return BioimageioCondaEnv(dependencies=["onnxruntime"])
def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Build the default Conda environment for TensorFlow / Keras weights.

    Args:
        tensorflow_version: tensorflow version to pin; ``None`` means 1.15.

    Returns:
        For tensorflow 2+: ``bioimageio.core`` and the pinned tensorflow as
        Conda dependencies. For tensorflow 1: pip and Python 3.7 from Conda,
        with ``bioimageio.core``, tensorflow (raised to at least 1.13) and
        ``protobuf <4.0`` installed via pip.
    """
    tf_version = Version("1.15") if tensorflow_version is None else tensorflow_version

    if tf_version.major != 1:
        return BioimageioCondaEnv(
            dependencies=["bioimageio.core", f"tensorflow =={tf_version}"],
        )

    # tensorflow 1 is not available on conda, so it is injected as a pip
    # dependency; releases older than 1.13 are not available anymore.
    tf_version = max(tf_version, Version("1.13"))
    pip_requirements = [
        # bioimageio.core (and its dependencies) also come from pip to avoid a conda/pip mix
        "bioimageio.core",
        f"tensorflow =={tf_version}",
        # tf 1 sets no upper bound on protobuf, but fails to load models saved
        # with protobuf 3 when protobuf 4 is installed
        "protobuf <4.0",
    ]
    return BioimageioCondaEnv(
        dependencies=[
            "pip",
            "python=3.7.*",  # tf 1.15 is not available for py>=3.8
            PipDeps(pip=pip_requirements),
        ],
    )
def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, v0_5.EnvironmentFileDescr],
) -> BioimageioCondaEnv:
    """Build a Conda environment from an explicitly declared dependency file.

    Args:
        deps: v0.4 ``Dependencies`` (manager ``pip``, ``conda`` or ``mamba``)
            or a v0.5 ``EnvironmentFileDescr``.

    Returns:
        For a pip requirements file: an environment whose only dependency is a
        pip section with the file's requirements (blank lines and comments
        removed), plus ``bioimageio.core`` if the file does not already
        require it. Otherwise: the referenced YAML file validated as
        ``BioimageioCondaEnv``.

    Raises:
        ValueError: if a v0.4 dependency manager other than pip, conda or
            mamba is given.
    """
    if isinstance(deps, v0_4.Dependencies):
        if deps.manager == "pip":
            pip_deps_str = download(deps.file).path.read_text(encoding="utf-8")
            assert isinstance(pip_deps_str, str)
            pip_deps = _parse_pip_requirements(pip_deps_str)
            required_names = {_normalized_requirement_name(d) for d in pip_deps}
            # match any spelling/pin of bioimageio.core (e.g. "bioimageio-core>=0.6")
            if "bioimageio-core" not in required_names:
                pip_deps.append("bioimageio.core")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager not in ("conda", "mamba"):
            raise ValueError(f"Dependency manager {deps.manager} not supported")
        else:
            deps_source = (
                deps.file.absolute()
                if isinstance(deps.file, RelativeFilePath)
                else deps.file
            )
            # files inside a zipped package are read in place; others are downloaded
            if isinstance(deps_source, ZipPath):
                local = deps_source
            else:
                local = download(deps_source).path

            return BioimageioCondaEnv.model_validate(read_yaml(local))
    elif isinstance(deps, v0_5.EnvironmentFileDescr):
        local = download(deps.source).path
        return BioimageioCondaEnv.model_validate(read_yaml(local))
    else:
        assert_never(deps)


def _parse_pip_requirements(text: str) -> List[str]:
    """Split pip requirements file content into requirement lines.

    Blank lines and comments are dropped: as in pip's requirements file format,
    a ``#`` at the start of a line or preceded by whitespace starts a comment.
    """
    import re

    requirements: List[str] = []
    for raw_line in text.splitlines():
        line = re.sub(r"(^|\s)#.*$", "", raw_line).strip()
        if line:
            requirements.append(line)

    return requirements


def _normalized_requirement_name(requirement: str) -> Optional[str]:
    """Return the PEP 503-normalized project name a requirement line starts with.

    Returns ``None`` for lines that do not start with a project name
    (e.g. pip options such as ``-r other.txt`` or local paths).
    """
    import re

    match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", requirement)
    if match is None:
        return None

    return re.sub(r"[-_.]+", "-", match.group(0)).lower()