Coverage for src / bioimageio / spec / _get_conda_env.py: 55%
83 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-31 13:09 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-31 13:09 +0000
1from typing import List, Literal, Optional, Tuple, Union
3from typing_extensions import assert_never
5from ._internal.gh_utils import set_github_warning
6from ._internal.io import FileDescr, get_reader
7from ._internal.io_utils import read_yaml
8from .conda_env import BioimageioCondaEnv, PipDeps
9from .model import v0_4, v0_5
10from .model.v0_5 import Version
# All weights entry descriptions (model spec v0_4 and v0_5) that
# `get_conda_env` knows how to derive a Conda environment for; each member is
# handled by one of its `isinstance` branches.
SupportedWeightsEntry = Union[
    v0_4.KerasHdf5WeightsDescr,
    v0_4.OnnxWeightsDescr,
    v0_4.PytorchStateDictWeightsDescr,
    v0_4.TensorflowSavedModelBundleWeightsDescr,
    v0_4.TorchscriptWeightsDescr,
    v0_5.KerasHdf5WeightsDescr,
    v0_5.KerasV3WeightsDescr,
    v0_5.OnnxWeightsDescr,
    v0_5.PytorchStateDictWeightsDescr,
    v0_5.TensorflowSavedModelBundleWeightsDescr,
    v0_5.TorchscriptWeightsDescr,
]
def get_conda_env(
    *,
    entry: SupportedWeightsEntry,
    env_name: Optional[Union[Literal["DROP"], str]] = None,
) -> BioimageioCondaEnv:
    """Return the recommended Conda environment for a weights entry description.

    Args:
        entry: Weights entry description (model spec v0_4 or v0_5).
        env_name: Controls the `name` of the returned environment:
            ``None`` keeps whatever name the selected environment has,
            ``"DROP"`` sets it to ``None``, any other string replaces it.

    Returns:
        The Conda environment selected for `entry`, with its name adjusted
        according to `env_name`.
    """
    env = _conda_env_for_entry(entry)
    if env_name is not None:
        env.name = None if env_name == "DROP" else env_name

    return env


def _conda_env_for_entry(entry: SupportedWeightsEntry) -> BioimageioCondaEnv:
    """Select the Conda environment for `entry` based on its weights format.

    - ONNX: default onnxruntime environment.
    - PyTorch state dict / TorchScript: the entry's dependency file if it has
      one, otherwise the default pytorch environment for `pytorch_version`
      (v0_5 TorchScript entries always use the default).
    - TensorFlow SavedModel bundle: the entry's dependency file if it has
      one, otherwise the default tensorflow environment.
    - Keras HDF5: default tensorflow environment for `tensorflow_version`.
    - Keras v3: environment for the declared backend plus keras 3.
    """
    if isinstance(entry, (v0_4.OnnxWeightsDescr, v0_5.OnnxWeightsDescr)):
        return _get_default_onnx_env(opset_version=entry.opset_version)

    torch_formats = (
        v0_4.PytorchStateDictWeightsDescr,
        v0_5.PytorchStateDictWeightsDescr,
        v0_4.TorchscriptWeightsDescr,
        v0_5.TorchscriptWeightsDescr,
    )
    if isinstance(entry, torch_formats):
        # v0_5 TorchScript entries carry no dependency file to read from
        torch_deps = (
            None
            if isinstance(entry, v0_5.TorchscriptWeightsDescr)
            else entry.dependencies
        )
        if torch_deps is None:
            return _get_default_pytorch_env(pytorch_version=entry.pytorch_version)

        return _get_env_from_deps(torch_deps)

    tf_bundle_formats = (
        v0_4.TensorflowSavedModelBundleWeightsDescr,
        v0_5.TensorflowSavedModelBundleWeightsDescr,
    )
    if isinstance(entry, tf_bundle_formats):
        if entry.dependencies is None:
            return _get_default_tf_env(tensorflow_version=entry.tensorflow_version)

        return _get_env_from_deps(entry.dependencies)

    if isinstance(entry, (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr)):
        return _get_default_tf_env(tensorflow_version=entry.tensorflow_version)

    if isinstance(entry, v0_5.KerasV3WeightsDescr):
        return _get_default_keras3_env(entry.backend)

    assert_never(entry)
def _get_default_keras3_env(
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version],
) -> BioimageioCondaEnv:
    """Return the default Conda environment for Keras 3 weights.

    The environment of the backend framework (tensorflow, torch or jax, pinned
    to the given version) is extended with a ``keras >=3.0, <4`` requirement.

    Args:
        backend: ``(framework name, framework version)`` pair.
    """
    framework, framework_version = backend
    if framework == "jax":
        env = BioimageioCondaEnv(dependencies=[f"jax=={framework_version}"])
    elif framework == "torch":
        env = _get_default_pytorch_env(pytorch_version=framework_version)
    elif framework == "tensorflow":
        env = _get_default_tf_env(tensorflow_version=framework_version)
    else:
        assert_never(framework)

    env.dependencies.append("keras >=3.0, <4")
    return env
# torchvision pin matching each pytorch release (keys are "major.minor.patch"),
# following https://pytorch.org/get-started/previous-versions/
_TORCHVISION_FOR_PYTORCH = {
    "1.5.1": "torchvision==0.6.1",
    "1.6.0": "torchvision==0.7.0",
    "1.7.0": "torchvision==0.8.0",
    "1.7.1": "torchvision==0.8.2",
    "1.8.0": "torchvision==0.9.0",
    "1.8.1": "torchvision==0.9.1",
    "1.9.0": "torchvision==0.10.0",
    "1.9.1": "torchvision==0.10.1",
    "1.10.0": "torchvision==0.11.0",
    "1.10.1": "torchvision==0.11.2",
    "1.11.0": "torchvision==0.12.0",
    "1.12.0": "torchvision==0.13.0",
    "1.12.1": "torchvision==0.13.1",
    "1.13.0": "torchvision==0.14.0",
    "1.13.1": "torchvision==0.14.1",
    "2.0.0": "torchvision==0.15.0",
    "2.0.1": "torchvision==0.15.2",
    "2.1.0": "torchvision==0.16.0",
    "2.1.1": "torchvision==0.16.1",
    "2.1.2": "torchvision==0.16.2",
    "2.2.0": "torchvision==0.17.0",
    "2.2.1": "torchvision==0.17.1",
    "2.2.2": "torchvision==0.17.2",
    "2.3.0": "torchvision==0.18.0",
    "2.3.1": "torchvision==0.18.1",
    "2.4.0": "torchvision==0.19.0",
    "2.4.1": "torchvision==0.19.1",
    "2.5.0": "torchvision==0.20.0",
    "2.5.1": "torchvision==0.20.1",
    "2.6.0": "torchvision==0.21.0",
    "2.7.0": "torchvision==0.22.0",
    "2.7.1": "torchvision==0.22.1",
    "2.8.0": "torchvision==0.23.0",
    "2.9.0": "torchvision==0.24.0",
    "2.9.1": "torchvision==0.24.1",
}


def _get_default_pytorch_env(
    *,
    pytorch_version: Optional[Version] = None,
) -> BioimageioCondaEnv:
    """Return the default conda-forge environment for a pytorch version.

    Pins ``pytorch`` (release padded to ``major.minor.patch``) and its matching
    ``torchvision``. If no torchvision pin is known for that release, a GitHub
    warning is emitted and torchvision is left unpinned. Older pytorch
    releases additionally get workaround pins for mkl, setuptools and numpy.

    Args:
        pytorch_version: pytorch version to pin; defaults to 1.10.1.
    """
    if pytorch_version is None:
        pytorch_version = Version("1.10.1")

    # pad "2" -> "2.0.0" and "2.1" -> "2.1.0" so it matches the pin table keys
    release = pytorch_version.base_version
    release += ".0" * max(0, 2 - release.count("."))

    torchvision_pin = _TORCHVISION_FOR_PYTORCH.get(release)
    if torchvision_pin is None:
        set_github_warning(
            "UPDATE NEEDED",
            f"Leaving torchvision unpinned for pytorch=={release}",
        )
        torchvision_pin = "torchvision"

    deps: List[Union[str, PipDeps]] = [f"pytorch=={release}", torchvision_pin]

    # (exclusive upper pytorch version, extra pin), applied in this order
    workaround_pins = (
        # avoid `undefined symbol: iJIT_NotifyEvent` from
        # `torch/lib/libtorch_cpu.so`,
        # see https://github.com/pytorch/pytorch/issues/123097
        # TODO: check if this is the correct cutoff where the fix is no longer
        # needed
        ("2.1.0", "mkl ==2024.0.0"),
        # avoid ImportError: cannot import name 'packaging' from 'pkg_resources',
        # see https://github.com/pypa/setuptools/issues/4376#issuecomment-2126162839
        ("2.2", "setuptools <70.0.0"),
        # see https://github.com/pytorch/pytorch/issues/107302
        ("2.3", "numpy <2"),
    )
    for cutoff, pin in workaround_pins:
        if pytorch_version < Version(cutoff):
            deps.append(pin)

    return BioimageioCondaEnv(
        channels=["conda-forge", "nodefaults"],
        dependencies=deps,
    )
def _get_default_onnx_env(*, opset_version: Optional[int]) -> BioimageioCondaEnv:
    """Return the default Conda environment for ONNX weights.

    Args:
        opset_version: Opset of the ONNX weights. Accepted for interface
            consistency with the other default environment builders, but not
            used: onnxruntime keeps backward compatibility with older opsets,
            see https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md

    Returns:
        An environment containing an unpinned ``onnxruntime``.
    """
    return BioimageioCondaEnv(dependencies=["onnxruntime"])
def _get_default_tf_env(tensorflow_version: Optional[Version]) -> BioimageioCondaEnv:
    """Return the default Conda environment pinning ``tensorflow``.

    Args:
        tensorflow_version: tensorflow version to pin. ``None`` or any version
            with a major version below 2 is replaced by 2.17.
    """
    if tensorflow_version is not None and tensorflow_version.major >= 2:
        pinned_version = tensorflow_version
    else:
        pinned_version = Version("2.17")

    return BioimageioCondaEnv(dependencies=[f"tensorflow =={pinned_version}"])
def _pip_requirement_name(requirement: str) -> str:
    """Return the PEP 503-normalized project name of a pip requirement line.

    The name ends at the first whitespace, version operator, marker, extras,
    or direct-reference delimiter, e.g. ``"bioimageio_core[x]>=0.9"`` ->
    ``"bioimageio-core"``. Options such as ``-r other.txt`` or bare URLs
    yield strings that do not match regular project names.
    """
    import re

    name = re.split(r"[\s<>=!~;\[@(,]", requirement, maxsplit=1)[0]
    return re.sub(r"[-_.]+", "-", name).lower()


def _get_env_from_deps(
    deps: Union[v0_4.Dependencies, FileDescr],
) -> BioimageioCondaEnv:
    """Build a Conda environment from a weights entry's dependency file.

    Args:
        deps: Either a v0_4 `Dependencies` (a file plus its package manager)
            or a `FileDescr` of a Conda environment YAML file.

    Returns:
        For pip requirement files: an environment with a single pip section
        holding the non-blank, non-comment requirement lines, plus
        ``bioimageio.core>=0.9.4`` if no bioimageio.core requirement is
        present yet. For conda/mamba files and `FileDescr`: the parsed
        environment YAML.

    Raises:
        ValueError: If a v0_4 dependency manager other than pip, conda or
            mamba is given.
    """
    if isinstance(deps, v0_4.Dependencies):
        deps_reader = get_reader(deps.file)
        if deps.manager == "pip":
            lines = [line.strip() for line in deps_reader.read_text().splitlines()]
            # blank lines and full-line comments are not requirements
            pip_deps = [line for line in lines if line and not line.startswith("#")]
            # compare normalized names so an existing (possibly pinned or
            # differently spelled) bioimageio.core requirement is respected
            # instead of adding a second, potentially conflicting one
            if not any(
                _pip_requirement_name(d) == "bioimageio-core" for d in pip_deps
            ):
                pip_deps.append("bioimageio.core>=0.9.4")

            return BioimageioCondaEnv(
                dependencies=[PipDeps(pip=pip_deps)],
            )
        elif deps.manager in ("conda", "mamba"):
            return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
        else:
            raise ValueError(f"Dependency manager {deps.manager} not supported")

    elif isinstance(deps, FileDescr):
        deps_reader = deps.get_reader()
        return BioimageioCondaEnv.model_validate(read_yaml(deps_reader))
    else:
        assert_never(deps)