Coverage for src/bioimageio/core/backends/keras_backend.py: 0%
66 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import os
2import shutil
3from pathlib import Path
4from tempfile import TemporaryDirectory
5from typing import Any, Optional, Sequence, Tuple
7from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
8 legacy_h5_format,
9)
10from loguru import logger
11from numpy.typing import NDArray
13from bioimageio.spec._internal.version_type import Version
14from bioimageio.spec.model import v0_4
16from .._model_adapter import LocalModelAdapter
17from .._settings import settings
18from ..utils._compare import warn_about_version
19from ..utils._type_guards import is_list, is_tuple
# Select the Keras backend configured in bioimageio settings via the
# KERAS_BACKEND environment variable.
# NOTE(review): `keras.src.legacy.saving` is already imported at the top of
# this module, before this assignment runs. If keras resolves its backend when
# it is first imported (Keras 3 reads KERAS_BACKEND at import time), setting
# the variable here may have no effect; consider setting it before any keras
# import. TODO confirm.
os.environ["KERAS_BACKEND"] = settings.keras_backend


# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
# If importing tensorflow / tensorflow.keras or parsing the tensorflow version
# raises anything, fall back to standalone `keras`; `tf_version` is then None.
try:
    import tensorflow as tf
    from tensorflow import keras

    tf_version = Version(tf.__version__)  # pyright: ignore[reportUnknownArgumentType]
except Exception:
    import keras  # pyright: ignore[reportMissingTypeStubs]

    tf_version = None
class KerasModelAdapter(LocalModelAdapter[None, Any]):
    """Model adapter that runs bioimage.io models with Keras weights.

    ``keras_v3`` weights are used when available (model descriptions other
    than v0.4); otherwise ``keras_hdf5`` weights are loaded with the
    ``"legacy_tensorflow"`` backend. Device management is not implemented:
    a single ``None`` placeholder device is always used.
    """

    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
        """Return the single placeholder device, ignoring any requested devices.

        Args:
            devices: Requested devices. If not ``None``, a warning is logged
                and the value is ignored.

        Returns:
            ``(None,)``
        """
        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )
        return (None,)

    def _init_model_on_device(self, device: None) -> Any:
        """Load the Keras model described by ``self._model_descr``.

        Args:
            device: Unused placeholder device (see ``_parse_devices``).

        Returns:
            The loaded Keras model.

        Raises:
            ValueError: If the model description has no Keras weights.
        """
        # Pick the weights entry and the backend/version it was produced with.
        if (
            not isinstance(self._model_descr, v0_4.ModelDescr)
            and self._model_descr.weights.keras_v3 is not None
        ):
            weight_reader = self._model_descr.weights.keras_v3.get_reader()
            backend, backend_version = self._model_descr.weights.keras_v3.backend
        elif self._model_descr.weights.keras_hdf5 is not None:
            backend = "legacy_tensorflow"
            backend_version = self._model_descr.weights.keras_hdf5.tensorflow_version
            weight_reader = self._model_descr.weights.keras_hdf5.get_reader()
        else:
            raise ValueError("model has no Keras weights")

        if backend != "legacy_tensorflow" and backend != settings.keras_backend:
            logger.warning(
                "Model specifies Keras backend '{}', but environment variable KERAS_BACKEND is set to '{}'."
                + " Attempting to load model with KERAS_BACKEND='{}' (this may fail if the model is not compatible with this backend).",
                backend,
                settings.keras_backend,
                settings.keras_backend,
            )

        # Compare the installed framework version with the one recorded in
        # the weights; only done when the model's backend matches the active one.
        if (backend == "legacy_tensorflow") or (
            backend == settings.keras_backend == "tensorflow"
        ):
            warn_about_version("tensorflow", backend_version, tf_version)
        elif backend == settings.keras_backend == "torch":
            import torch

            torch_version = Version(torch.__version__)
            warn_about_version("torch", backend_version, torch_version)
        elif backend == settings.keras_backend == "jax":
            import jax

            jax_version = Version(jax.__version__)
            warn_about_version("jax", backend_version, jax_version)

        # Both HDF5 extensions need the leading dot to match `suffix`.
        if weight_reader.suffix in (".h5", ".hdf5"):
            import h5py  # pyright: ignore[reportMissingTypeStubs]

            # Close the HDF5 handle once the model has been built from it
            # (it was previously left open).
            with h5py.File(weight_reader, mode="r") as h5_file:
                return legacy_h5_format.load_model_from_hdf5(h5_file)  # pyright: ignore[reportUnknownVariableType]
        else:
            # `load_model` is given a filesystem path here, so materialize the
            # weights in a temporary file that keeps the original file name
            # (and therefore its suffix).
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                return keras.models.load_model(temp_path)  # pyright: ignore[reportUnknownVariableType]

    def _forward_impl(
        self,
        device: None,
        model: Any,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ):
        """Run ``model.predict`` on the inputs and return the output(s) as a sequence.

        Args:
            device: Unused placeholder device (see ``_parse_devices``).
            model: The Keras model returned by ``_init_model_on_device``.
            input_arrays: Input arrays (entries may be ``None``); presumably
                ordered like the model's inputs -- TODO confirm with caller.

        Returns:
            The network output(s) as a list or tuple; a single non-sequence
            output is wrapped in a one-element list.
        """
        # `predict`'s second positional parameter is `batch_size`, so multiple
        # inputs must be passed together as one list, not unpacked as separate
        # positional arguments. A single input is passed as-is, as before.
        if len(input_arrays) == 1:
            network_output = model.predict(input_arrays[0])
        else:
            network_output = model.predict(list(input_arrays))

        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]

    def _cleanup_pre_model_deletion(self, device: None, model: Any) -> None:
        """No-op: there is nothing to release before the model is deleted."""
        return

    def _cleanup_post_model_deletion(self, device: None) -> None:
        """No-op: there is nothing to release after the model is deleted."""
        return