Coverage for src / bioimageio / spec / _get_conda_env.py: 71%
139 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-17 16:08 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-17 16:08 +0000
1from typing import List, Literal, Optional, Union
3from typing_extensions import assert_never
5from ._internal.gh_utils import set_github_warning
6from ._internal.io import FileDescr, get_reader
7from ._internal.io_utils import read_yaml
8from .conda_env import BioimageioCondaEnv, PipDeps
9from .model import v0_4, v0_5
10from .model.v0_5 import Version
# Weights entry descriptions (grouped by model format version) for which a
# conda environment can be recommended.
_SupportedV04WeightsEntry = Union[
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
]
_SupportedV05WeightsEntry = Union[
    v0_5.KerasHdf5WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]
# nested unions are flattened, so this is the union of all ten descriptions
SupportedWeightsEntry = Union[_SupportedV04WeightsEntry, _SupportedV05WeightsEntry]
def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """get the recommended Conda environment for a given weights entry description

    Args:
        entry: weights entry description to pick an environment for.
        env_name: name to set on the returned environment.
            `"DROP"` sets the name to `None`;
            `None` (default) leaves the environment's name untouched.

    Returns:
        The environment described by the entry's dependencies if it specifies
        any (and they are considered for that weights format),
        otherwise a default environment for the entry's weights format.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        conda_env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(entry, v0_5.TorchscriptWeightsDescr):
        # no dependencies are consulted for v0.5 torchscript entries
        conda_env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            conda_env = _get_env_from_deps(entry.dependencies)
        else:
            conda_env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            conda_env = _get_env_from_deps(entry.dependencies)
        else:
            conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    if env_name is not None:
        # "DROP" is a sentinel requesting an unnamed environment
        conda_env.name = None if env_name == "DROP" else env_name

    return conda_env
# Known (torchvision, torchaudio) versions per pytorch release according to
# https://pytorch.org/get-started/previous-versions/
# `None` means that package is not added for the given pytorch release.
_TORCH_COMPANION_VERSIONS = {
    "1.5.1": ("0.6.1", None),
    "1.6.0": ("0.7.0", None),
    "1.7.0": ("0.8.0", "0.7.0"),
    "1.7.1": ("0.8.2", "0.7.1"),
    "1.8.0": ("0.9.0", "0.8.0"),
    "1.8.1": ("0.9.1", "0.8.1"),
    "1.9.0": ("0.10.0", "0.9.0"),
    "1.9.1": ("0.10.1", "0.9.1"),
    "1.10.0": ("0.11.0", "0.10.0"),
    "1.10.1": ("0.11.2", "0.10.1"),
    "1.11.0": ("0.12.0", "0.11.0"),
    "1.12.0": ("0.13.0", "0.12.0"),
    "1.12.1": ("0.13.1", "0.12.1"),
    "1.13.0": ("0.14.0", "0.13.0"),
    "1.13.1": ("0.14.1", "0.13.1"),
    "2.0.0": ("0.15.0", "2.0.0"),
    "2.0.1": ("0.15.2", "2.0.2"),
    "2.1.0": ("0.16.0", "2.1.0"),
    "2.1.1": ("0.16.1", "2.1.1"),
    "2.1.2": ("0.16.2", "2.1.2"),
    "2.2.0": ("0.17.0", "2.2.0"),
    "2.2.1": ("0.17.1", "2.2.1"),
    "2.2.2": ("0.17.2", "2.2.2"),
    "2.3.0": ("0.18.0", "2.3.0"),
    "2.3.1": ("0.18.1", "2.3.1"),
    "2.4.0": ("0.19.0", "2.4.0"),
    "2.4.1": ("0.19.1", "2.4.1"),
    "2.5.0": ("0.20.0", "2.5.0"),
    "2.5.1": ("0.20.1", "2.5.1"),
    "2.6.0": ("0.21.0", "2.6.0"),
    "2.7.0": ("0.22.0", "2.7.0"),
    "2.7.1": ("0.22.1", "2.7.1"),
    "2.8.0": ("0.23.0", "2.8.0"),
    "2.9.0": ("0.24.0", "2.9.0"),
    "2.9.1": ("0.24.1", "2.9.1"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Build the default conda environment for a given pytorch version.

    Args:
        pytorch_version: pytorch version to pin; `None` falls back to 1.10.1.

    Returns:
        An environment with pinned pytorch, matching torchvision/torchaudio
        pins where known (unpinned otherwise, with a GitHub warning), plus
        compatibility pins required by older pytorch releases.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    # pad the base version to the three-component form `major.minor.patch`
    torch_pin = pytorch_version.base_version
    torch_pin += ".0" * (2 - torch_pin.count("."))

    deps: List[Union[str, PipDeps]] = [f"pytorch=={torch_pin}"]

    companions = _TORCH_COMPANION_VERSIONS.get(torch_pin)
    if companions is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={torch_pin}",
        )
        deps.extend(["torchvision", "torchaudio"])
    else:
        torchvision_pin, torchaudio_pin = companions
        deps.append(f"torchvision=={torchvision_pin}")
        if torchaudio_pin is not None:
            deps.append(f"torchaudio=={torchaudio_pin}")

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    # TODO: check if this is the correct cutoff where the fix is no longer needed
    mkl_fix_cutoff = Version("2.1.0")
    if pytorch_version < mkl_fix_cutoff:
        deps.append("mkl ==2024.0.0")

    # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
    # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
    if pytorch_version < Version("2.2"):
        deps.append("setuptools <70.0.0")

    # see https://github.com/pytorch/pytorch/issues/107302
    if pytorch_version < Version("2.3"):
        deps.append("numpy <2")

    return BioimageioCondaEnv(
        channels=["conda-forge", "nodefaults"],
        dependencies=deps,
    )
def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Return the default conda environment for running ONNX weights.

    Args:
        opset_version: ONNX opset version of the weights entry. Accepted for
            interface compatibility only; it does not influence the result.

    Returns:
        An environment whose only explicitly listed dependency is
        `onnxruntime` (unpinned).
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    del opset_version  # intentionally unused
    return BioimageioCondaEnv(dependencies=["onnxruntime"])
def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Return the default conda environment pinning a tensorflow version.

    Args:
        tensorflow_version: requested tensorflow version. If it is `None` or
            its major version is below 2, tensorflow 2.17 is pinned instead.
    """
    if tensorflow_version is not None and tensorflow_version.major >= 2:
        pinned_version = tensorflow_version
    else:
        pinned_version = Version("2.17")

    tensorflow_dep = f"tensorflow =={pinned_version}"
    return BioimageioCondaEnv(dependencies=[tensorflow_dep])
def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, FileDescr],
) -> BioimageioCondaEnv:
    """Create a conda environment description from a dependencies specification.

    Args:
        deps: either a v0.4 `Dependencies` object (dependency manager plus
            file) or a file description of a conda environment YAML file.

    Returns:
        For pip requirements: an environment with a single pip section listing
        the non-blank requirement lines (plus `bioimageio.core>=0.9.4` if no
        line equals `bioimageio.core`).
        For conda/mamba or a plain file description: the environment
        validated from the YAML file content.

    Raises:
        ValueError: if the v0.4 dependency manager is not pip, conda or mamba.
    """
    if isinstance(deps, v0_4.Dependencies):
        deps_reader = get_reader(deps.file)
        if deps.manager == "pip":
            stripped_lines = (
                line.strip() for line in deps_reader.read_text().split("\n")
            )
            # drop blank lines (e.g. from a trailing newline) so that no empty
            # requirement strings end up in the pip section
            pip_deps = [line for line in stripped_lines if line]

            # NOTE(review): exact match only; a line such as
            # `bioimageio.core==x.y` does not count and the requirement below
            # is still appended -- confirm this is intended.
            if "bioimageio.core" not in pip_deps:
                pip_deps.append("bioimageio.core>=0.9.4")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager in ("conda", "mamba"):
            return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
        else:
            raise ValueError(f"Dependency manager {deps.manager} not supported")
    elif isinstance(deps, FileDescr):
        deps_reader = deps.get_reader()
        return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
    else:
        assert_never(deps)