Coverage for bioimageio/spec/_get_conda_env.py: 77%
133 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
1from typing import List, Literal, Optional, Union
3from typing_extensions import assert_never
5from ._internal.gh_utils import set_github_warning
6from ._internal.io_utils import ZipPath, read_yaml
7from .common import RelativeFilePath
8from .conda_env import BioimageioCondaEnv, PipDeps
9from .model import v0_4, v0_5
10from .model.v0_5 import Version
11from .utils import download
# All weights entry descriptions (bioimageio spec versions 0.4 and 0.5) that
# `get_conda_env` knows how to derive a conda environment for.
SupportedWeightsEntry = Union[
    # spec v0.4 weights formats
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    # spec v0.5 weights formats
    v0_5.KerasHdf5WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]
def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """Get the recommended conda environment for a weights entry description.

    Args:
        entry: A supported (spec v0.4 or v0.5) weights entry description.
        env_name: Name to give the returned environment. `None` keeps the
            name as produced, the special value "DROP" clears the name
            (sets it to `None`), and any other string replaces it.

    Returns:
        The conda environment description for `entry`.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
            v0_5.TorchscriptWeightsDescr,
        ),
    ):
        # v0.5 torchscript entries have no `dependencies` to consult, so the
        # type check must come first to short-circuit the attribute access.
        if (
            not isinstance(entry, v0_5.TorchscriptWeightsDescr)
            and entry.dependencies is not None
        ):
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    if env_name is not None:
        # "DROP" is a sentinel requesting removal of the environment name
        env.name = None if env_name == "DROP" else env_name

    return env
# torchvision/torchaudio pins matching a pytorch release, following
# https://pytorch.org/get-started/previous-versions/
# Keys are pytorch versions padded to three components (e.g. "1.10.0").
_PYTORCH_COMPANION_PINS = {
    "1.4.0": ("torchvision==0.5.0",),
    "1.5.0": ("torchvision==0.6.0",),
    "1.5.1": ("torchvision==0.6.1",),
    "1.6.0": ("torchvision==0.7.0",),
    "1.7.0": ("torchvision==0.8.0", "torchaudio==0.7.0"),
    "1.7.1": ("torchvision==0.8.2", "torchaudio==0.7.1"),
    "1.8.0": ("torchvision==0.9.0", "torchaudio==0.8.0"),
    "1.8.1": ("torchvision==0.9.1", "torchaudio==0.8.1"),
    "1.9.0": ("torchvision==0.10.0", "torchaudio==0.9.0"),
    "1.9.1": ("torchvision==0.10.1", "torchaudio==0.9.1"),
    "1.10.0": ("torchvision==0.11.0", "torchaudio==0.10.0"),
    "1.10.1": ("torchvision==0.11.2", "torchaudio==0.10.1"),
    "1.11.0": ("torchvision==0.12.0", "torchaudio==0.11.0"),
    "1.12.0": ("torchvision==0.13.0", "torchaudio==0.12.0"),
    "1.12.1": ("torchvision==0.13.1", "torchaudio==0.12.1"),
    "1.13.0": ("torchvision==0.14.0", "torchaudio==0.13.0"),
    "1.13.1": ("torchvision==0.14.1", "torchaudio==0.13.1"),
    "2.0.0": ("torchvision==0.15.0", "torchaudio==2.0.0"),
    "2.0.1": ("torchvision==0.15.2", "torchaudio==2.0.2"),
    "2.1.0": ("torchvision==0.16.0", "torchaudio==2.1.0"),
    "2.1.1": ("torchvision==0.16.1", "torchaudio==2.1.1"),
    "2.1.2": ("torchvision==0.16.2", "torchaudio==2.1.2"),
    "2.2.0": ("torchvision==0.17.0", "torchaudio==2.2.0"),
    "2.2.1": ("torchvision==0.17.1", "torchaudio==2.2.1"),
    "2.2.2": ("torchvision==0.17.2", "torchaudio==2.2.2"),
    "2.3.0": ("torchvision==0.18.0", "torchaudio==2.3.0"),
    "2.3.1": ("torchvision==0.18.1", "torchaudio==2.3.1"),
    "2.4.0": ("torchvision==0.19.0", "torchaudio==2.4.0"),
    "2.4.1": ("torchvision==0.19.1", "torchaudio==2.4.1"),
    "2.5.0": ("torchvision==0.20.0", "torchaudio==2.5.0"),
    "2.5.1": ("torchvision==0.20.1", "torchaudio==2.5.1"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Build a default conda environment for running pytorch weights.

    Args:
        pytorch_version: pytorch release to pin; defaults to 1.10.1.

    Returns:
        An environment pinning `pytorch` and, where known, the matching
        torchvision/torchaudio releases. It also adds compatibility pins for
        mkl, setuptools and numpy, chosen by pytorch version.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    channels = ["conda-forge", "nodefaults"]
    if pytorch_version < Version("2.5.2"):
        # list the dedicated `pytorch` channel first for these releases
        channels.insert(0, "pytorch")

    # pad to three components ("2" -> "2.0.0", "1.10" -> "1.10.0") so the
    # version matches the keys of `_PYTORCH_COMPANION_PINS`
    v = pytorch_version.base_version
    if v.count(".") == 0:
        v += ".0.0"
    elif v.count(".") == 1:
        v += ".0"

    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]
    companions = _PYTORCH_COMPANION_PINS.get(v)
    if companions is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={v}",
        )
        deps += ["torchvision", "torchaudio"]
    else:
        deps.extend(companions)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    if pytorch_version < Version(
        "2.1.0"  # TODO: check if this is the correct cutoff where the fix is no longer needed
    ):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    if pytorch_version < Version(
        "2.3"
    ):  # TODO: verify that future pytorch 2.4 is numpy 2.0 compatible
        # make sure we have a compatible numpy version
        # see https://github.com/pytorch/vision/issues/8460
        deps.append("numpy <2")
    else:
        deps.append("numpy >=2,<3")

    return BioimageioCondaEnv(channels=channels, dependencies=deps)
def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Build a default conda environment for running onnx weights.

    Args:
        opset_version: The onnx opset version of the weights. It is
            deliberately ignored: an unpinned `onnxruntime` is returned for
            every opset version.

    Returns:
        An environment whose only dependency is `onnxruntime`.
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    del opset_version  # intentionally unused (see note above)
    return BioimageioCondaEnv(dependencies=["onnxruntime"])
def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Build a default conda environment for running tensorflow/keras weights.

    Args:
        tensorflow_version: The requested tensorflow version. `None` or any
            version with major version below 2 is replaced by 2.17.

    Returns:
        An environment with `bioimageio.core` and tensorflow pinned via `==`.
    """
    if tensorflow_version is not None and tensorflow_version.major >= 2:
        pinned = tensorflow_version
    else:
        # fallback for an unspecified or pre-2.0 tensorflow version
        pinned = Version("2.17")

    tf_requirement = f"tensorflow =={pinned}"
    return BioimageioCondaEnv(dependencies=["bioimageio.core", tf_requirement])
def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, v0_5.EnvironmentFileDescr],
) -> BioimageioCondaEnv:
    """Create a conda environment from a weights entry's dependency description.

    Args:
        deps: Either a v0.4 `Dependencies` or a v0.5 `EnvironmentFileDescr`.
            A v0.4 `Dependencies` refers to a pip requirements file or a
            conda/mamba environment file. A v0.5 `EnvironmentFileDescr`
            refers to an environment file.

    Returns:
        For pip requirements, an environment with a single pip section. It
        holds the requirements, plus `bioimageio.core` unless that exact
        line is present. Otherwise, the referenced environment file parsed
        and validated as `BioimageioCondaEnv`.

    Raises:
        ValueError: If a v0.4 dependency manager other than "pip", "conda"
            or "mamba" is specified.
    """
    if isinstance(deps, v0_4.Dependencies):
        if deps.manager == "pip":
            pip_deps_str = download(deps.file).path.read_text(encoding="utf-8")
            assert isinstance(pip_deps_str, str)
            # Blank lines (e.g. from a trailing newline) and whole-line
            # comments are not requirements; drop them instead of emitting
            # bogus entries. `strip` also removes the `\r` of CRLF files.
            pip_deps = [
                line
                for line in (raw.strip() for raw in pip_deps_str.splitlines())
                if line and not line.startswith("#")
            ]
            if "bioimageio.core" not in pip_deps:
                pip_deps.append("bioimageio.core")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager not in ("conda", "mamba"):
            raise ValueError(f"Dependency manager {deps.manager} not supported")
        else:
            # resolve relative paths before deciding whether to download
            deps_source = (
                deps.file.absolute()
                if isinstance(deps.file, RelativeFilePath)
                else deps.file
            )
            if isinstance(deps_source, ZipPath):
                # files inside a zip package are read in place
                local = deps_source
            else:
                local = download(deps_source).path

            return BioimageioCondaEnv.model_validate(read_yaml(local))
    elif isinstance(deps, v0_5.EnvironmentFileDescr):
        local = download(deps.source).path
        return BioimageioCondaEnv.model_validate(read_yaml(local))
    else:
        assert_never(deps)