Coverage for bioimageio/spec/_get_conda_env.py: 79%
128 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 09:20 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 09:20 +0000
1from typing import List, Literal, Optional, Union
3from typing_extensions import assert_never
5from ._internal.gh_utils import set_github_warning
6from ._internal.io import FileDescr, get_reader
7from ._internal.io_utils import read_yaml
8from .conda_env import BioimageioCondaEnv, PipDeps
9from .model import v0_4, v0_5
10from .model.v0_5 import Version
# Union of the weights entry descriptions (model spec v0.4 and v0.5) accepted
# by `get_conda_env`; each member is handled by one branch of its dispatch.
SupportedWeightsEntry = Union[
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    v0_5.KerasHdf5WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]
def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """get the recommended Conda environment for a given weights entry description

    Args:
        entry: weights entry description to derive the environment for.
            Entries that carry their own dependency file (pytorch state dict,
            v0.4 torchscript, tensorflow saved model bundle) use that file;
            all other entries get a framework default environment.
        env_name: if ``"DROP"``, the environment name is cleared (set to None);
            if another string, it becomes the environment name; if None,
            the name is left as produced.

    Returns:
        The selected conda environment description.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        env = _get_default_onnx_env(opset_version=entry.opset_version)
    elif isinstance(
        entry,
        (
            v0_4.PytorchStateDictWeightsDescr,
            v0_5.PytorchStateDictWeightsDescr,
            v0_4.TorchscriptWeightsDescr,
            v0_5.TorchscriptWeightsDescr,
        ),
    ):
        # v0.5 torchscript entries have no `dependencies`; the isinstance check
        # short-circuits before that attribute is accessed.
        if (
            not isinstance(entry, v0_5.TorchscriptWeightsDescr)
            and entry.dependencies is not None
        ):
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_pytorch_env(pytorch_version=entry.pytorch_version)
    elif isinstance(
        entry,
        (
            v0_4.TensorflowSavedModelBundleWeightsDescr,
            v0_5.TensorflowSavedModelBundleWeightsDescr,
        ),
    ):
        if entry.dependencies is not None:
            env = _get_env_from_deps(entry.dependencies)
        else:
            env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    elif isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
    else:
        assert_never(entry)

    if env_name is not None:
        env.name = None if env_name == "DROP" else env_name

    return env
# torchvision/torchaudio releases to pin alongside a given pytorch release,
# keyed by the normalized "major.minor.patch" pytorch version, following
# https://pytorch.org/get-started/previous-versions/
# (for pytorch 1.5.1 and 1.6.0 only torchvision is pinned)
_PYTORCH_COMPANION_PINS = {
    "1.5.1": ("torchvision==0.6.1",),
    "1.6.0": ("torchvision==0.7.0",),
    "1.7.0": ("torchvision==0.8.0", "torchaudio==0.7.0"),
    # pytorch 1.7.1 was released together with torchaudio 0.7.2 (see link above)
    "1.7.1": ("torchvision==0.8.2", "torchaudio==0.7.2"),
    "1.8.0": ("torchvision==0.9.0", "torchaudio==0.8.0"),
    "1.8.1": ("torchvision==0.9.1", "torchaudio==0.8.1"),
    "1.9.0": ("torchvision==0.10.0", "torchaudio==0.9.0"),
    "1.9.1": ("torchvision==0.10.1", "torchaudio==0.9.1"),
    "1.10.0": ("torchvision==0.11.0", "torchaudio==0.10.0"),
    "1.10.1": ("torchvision==0.11.2", "torchaudio==0.10.1"),
    "1.11.0": ("torchvision==0.12.0", "torchaudio==0.11.0"),
    "1.12.0": ("torchvision==0.13.0", "torchaudio==0.12.0"),
    "1.12.1": ("torchvision==0.13.1", "torchaudio==0.12.1"),
    "1.13.0": ("torchvision==0.14.0", "torchaudio==0.13.0"),
    "1.13.1": ("torchvision==0.14.1", "torchaudio==0.13.1"),
    "2.0.0": ("torchvision==0.15.0", "torchaudio==2.0.0"),
    "2.0.1": ("torchvision==0.15.2", "torchaudio==2.0.2"),
    "2.1.0": ("torchvision==0.16.0", "torchaudio==2.1.0"),
    "2.1.1": ("torchvision==0.16.1", "torchaudio==2.1.1"),
    "2.1.2": ("torchvision==0.16.2", "torchaudio==2.1.2"),
    "2.2.0": ("torchvision==0.17.0", "torchaudio==2.2.0"),
    "2.2.1": ("torchvision==0.17.1", "torchaudio==2.2.1"),
    "2.2.2": ("torchvision==0.17.2", "torchaudio==2.2.2"),
    "2.3.0": ("torchvision==0.18.0", "torchaudio==2.3.0"),
    "2.3.1": ("torchvision==0.18.1", "torchaudio==2.3.1"),
    "2.4.0": ("torchvision==0.19.0", "torchaudio==2.4.0"),
    "2.4.1": ("torchvision==0.19.1", "torchaudio==2.4.1"),
    "2.5.0": ("torchvision==0.20.0", "torchaudio==2.5.0"),
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Return the default conda environment for pytorch-based weights.

    Args:
        pytorch_version: pytorch version to pin; defaults to 1.10.1.

    Returns:
        An environment that pins ``pytorch``, the matching ``torchvision`` and
        ``torchaudio`` (left unpinned, with a GitHub warning, for versions not
        listed in ``_PYTORCH_COMPANION_PINS``), plus version-dependent
        workaround pins for mkl, setuptools and numpy.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    channels = ["conda-forge", "nodefaults"]
    if pytorch_version < Version("2.5.2"):
        channels.insert(0, "pytorch")

    # normalize to "major.minor.patch", e.g. "2" -> "2.0.0", "2.1" -> "2.1.0",
    # so the version can be looked up in _PYTORCH_COMPANION_PINS
    v = pytorch_version.base_version
    missing_parts = 2 - v.count(".")
    if missing_parts > 0:
        v += ".0" * missing_parts

    deps: List[Union[str, PipDeps]] = [f"pytorch=={v}"]
    companion_pins = _PYTORCH_COMPANION_PINS.get(v)
    if companion_pins is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision and torchaudio unpinned for pytorch=={v}",
        )
        companion_pins = ("torchvision", "torchaudio")

    deps.extend(companion_pins)

    # avoid `undefined symbol: iJIT_NotifyEvent` from `torch/lib/libtorch_cpu.so`
    # see https://github.com/pytorch/pytorch/issues/123097
    if pytorch_version < Version(
        "2.1.0"  # TODO: check if this is the correct cutoff where the fix is no longer needed
    ):
        deps.append("mkl ==2024.0.0")

    if pytorch_version < Version("2.2"):
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources'
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        deps.append("setuptools <70.0.0")

    if pytorch_version < Version(
        "2.3"
    ):  # TODO: verify that future pytorch 2.4 is numpy 2.0 compatible
        # make sure we have a compatible numpy version
        # see https://github.com/pytorch/vision/issues/8460
        deps.append("numpy <2")
    else:
        deps.append("numpy >=2,<3")

    return BioimageioCondaEnv(channels=channels, dependencies=deps)
def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Return the default conda environment for ONNX weights.

    Args:
        opset_version: ONNX opset version of the weights. It is accepted for
            interface compatibility but intentionally unused (see note below).

    Returns:
        An environment whose only dependency is an unpinned ``onnxruntime``.
    """
    # note: we should not need to worry about the opset version,
    # see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
    return BioimageioCondaEnv(dependencies=["onnxruntime"])
def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Return the default conda environment for tensorflow/keras weights.

    Args:
        tensorflow_version: requested tensorflow version. A missing version or
            a 1.x version is replaced by 2.17.

    Returns:
        An environment with ``bioimageio.core`` and a pinned ``tensorflow``.
    """
    if tensorflow_version is not None and tensorflow_version.major >= 2:
        pinned_version = tensorflow_version
    else:
        pinned_version = Version("2.17")

    tensorflow_pin = f"tensorflow =={pinned_version}"
    return BioimageioCondaEnv(dependencies=["bioimageio.core", tensorflow_pin])
def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, FileDescr],
) -> BioimageioCondaEnv:
    """Build a conda environment from a weights entry's dependency file.

    Args:
        deps: either a v0.4 ``Dependencies`` description, referring to a pip
            requirements file or a conda/mamba environment file, or a
            ``FileDescr`` of a conda environment (YAML) file.

    Returns:
        The resulting conda environment. Pip requirements are wrapped in a
        ``pip`` section, and ``bioimageio.core`` is appended unless it is
        already listed verbatim.

    Raises:
        ValueError: if ``deps.manager`` is not ``pip``, ``conda`` or ``mamba``.
    """
    if isinstance(deps, v0_4.Dependencies):
        deps_reader = get_reader(deps.file)
        if deps.manager == "pip":
            # Skip blank lines (e.g. produced by a trailing newline) and
            # full-line comments; neither is a valid requirement specifier.
            stripped_lines = (
                line.strip() for line in deps_reader.read_text().splitlines()
            )
            pip_deps = [d for d in stripped_lines if d and not d.startswith("#")]
            if "bioimageio.core" not in pip_deps:
                pip_deps.append("bioimageio.core")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager in ("conda", "mamba"):
            return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
        else:
            raise ValueError(f"Dependency manager {deps.manager} not supported")

    elif isinstance(deps, FileDescr):
        deps_reader = deps.get_reader()
        return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
    else:
        assert_never(deps)