Coverage for src / bioimageio / core / backends / keras_backend.py: 69%
65 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 22:06 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 22:06 +0000
1import os
2import shutil
3from pathlib import Path
4from tempfile import TemporaryDirectory
5from typing import Any, Optional, Sequence, Union
from loguru import logger
from numpy.typing import NDArray

from bioimageio.core.utils._compare import warn_about_version
from bioimageio.spec._internal.version_type import Version
from bioimageio.spec.model import v0_4, v0_5

from .._settings import settings
from ..digest_spec import get_axes_infos
from ..utils._type_guards import is_list, is_tuple
from ._model_adapter import ModelAdapter

# Keras 3 selects its backend from KERAS_BACKEND when keras is first imported,
# and the backend cannot be changed afterwards. The variable must therefore be
# set before ANY keras import below, including the legacy HDF5 loader.
# NOTE(review): confirm that none of the project imports above pull in keras
# transitively.
os.environ["KERAS_BACKEND"] = settings.keras_backend

from keras.src.legacy.saving import (  # pyright: ignore[reportMissingTypeStubs]  # noqa: E402
    legacy_h5_format,
)

# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
try:
    import tensorflow as tf
    from tensorflow import keras

    tf_version = Version(tf.__version__)  # pyright: ignore[reportUnknownArgumentType]
except Exception:
    # Deliberately broad: any failure to import or inspect tensorflow falls
    # back to the standalone keras package, with no tensorflow version known.
    import keras  # pyright: ignore[reportMissingTypeStubs]

    tf_version = None
class KerasModelAdapter(ModelAdapter):
    """Model adapter that runs bioimage.io models with Keras weights.

    Supports ``keras_v3`` weights (only available in v0_5 model descriptions)
    and ``keras_hdf5`` weights. When both are present, ``keras_v3`` is used.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the Keras network described by `model_description`.

        Args:
            model_description: Model description that provides the Keras
                weights and the output tensor descriptions.
            devices: Ignored. Device management is not implemented for Keras,
                so a warning is logged if devices are given.

        Raises:
            ValueError: If the description has neither ``keras_v3`` nor
                ``keras_hdf5`` weights.
        """
        super().__init__(model_description=model_description)

        if (
            not isinstance(model_description, v0_4.ModelDescr)
            and model_description.weights.keras_v3 is not None
        ):
            weight_reader = model_description.weights.keras_v3.get_reader()
            backend, backend_version = model_description.weights.keras_v3.backend
        elif model_description.weights.keras_hdf5 is not None:
            # HDF5 weights are treated as legacy TensorFlow-Keras weights and
            # are versioned by the TensorFlow version they were saved with.
            backend = "legacy_tensorflow"
            backend_version = model_description.weights.keras_hdf5.tensorflow_version
            weight_reader = model_description.weights.keras_hdf5.get_reader()
        else:
            raise ValueError("model has no Keras weights")

        if backend != "legacy_tensorflow" and backend != settings.keras_backend:
            logger.warning(
                "Model specifies Keras backend '{}', but environment variable KERAS_BACKEND is set to '{}'."
                + " Attempting to load model with KERAS_BACKEND='{}' (this may fail if the model is not compatible with this backend).",
                backend,
                settings.keras_backend,
                settings.keras_backend,
            )

        # Compare the library version recorded with the weights against the
        # installed one. Only done when the weights' backend is the active one.
        if (backend == "legacy_tensorflow") or (
            backend == settings.keras_backend == "tensorflow"
        ):
            warn_about_version("tensorflow", backend_version, tf_version)
        elif backend == settings.keras_backend == "torch":
            import torch

            torch_version = Version(torch.__version__)
            warn_about_version("torch", backend_version, torch_version)
        elif backend == settings.keras_backend == "jax":
            import jax

            jax_version = Version(jax.__version__)
            warn_about_version("jax", backend_version, jax_version)

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        # Suffixes carry a leading dot (cf. ".h5"), so ".hdf5" must too.
        if weight_reader.suffix in (".h5", ".hdf5"):
            import h5py  # pyright: ignore[reportMissingTypeStubs]

            # Use the file as a context manager so the HDF5 handle is closed
            # once the model has been loaded.
            with h5py.File(weight_reader, mode="r") as h5_file:
                self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
        else:
            # Write the weights to a temporary file under their original file
            # name (keeping the extension) and load the model from that path.
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                self._network = keras.models.load_model(temp_path)

        # Axis ids of every output tensor, in output order.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the network on `input_arrays` and return its outputs as a list.

        Keras' ``predict`` takes ALL model inputs through its first parameter
        ``x``: a single array, or a list of arrays for multi-input models. Its
        remaining positional parameters are ``batch_size``, ``verbose``, ...
        The inputs must not be unpacked into separate positional arguments.
        """
        inputs = list(input_arrays)
        x = inputs[0] if len(inputs) == 1 else inputs
        network_output = self._network.predict(x)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            # Single-output models return a bare array; normalize to a list.
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Log a warning; the Keras network is not released (not implemented)."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )