Coverage for src/bioimageio/core/backends/keras_backend.py: 80%
54 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 09:46 +0000
1import os
2import shutil
3from pathlib import Path
4from tempfile import TemporaryDirectory
5from typing import Any, Optional, Sequence, Union
7from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
8 legacy_h5_format,
9)
10from loguru import logger
11from numpy.typing import NDArray
13from bioimageio.spec.model import v0_4, v0_5
14from bioimageio.spec.model.v0_5 import Version
16from .._settings import settings
17from ..digest_spec import get_axes_infos
18from ..utils._type_guards import is_list, is_tuple
19from ._model_adapter import ModelAdapter
# Select the Keras backend from the bioimageio.core settings before tensorflow /
# keras are imported below.
# NOTE(review): `keras.src.legacy.saving` is already imported at the top of this
# module, i.e. *before* this assignment. Keras documents that the backend must be
# configured before keras is imported and cannot be changed afterwards, so this
# setting may have no effect -- consider moving that import below this line.
os.environ["KERAS_BACKEND"] = settings.keras_backend

# By default, we use the keras integrated with tensorflow (`tf.keras`) and fall
# back to standalone keras if tensorflow cannot be imported.
# TODO: check if we should prefer keras
try:
    import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
    from tensorflow import (  # pyright: ignore[reportMissingTypeStubs]
        keras,  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
    )

    # Installed tensorflow version, compared against the model's declared
    # version in `KerasModelAdapter`. Parsing stays inside the `try` so an
    # unparsable version string also triggers the fallback below.
    tf_version = Version(tf.__version__)
except Exception as e:
    # Deliberately broad: any failure while importing tensorflow (not only
    # ImportError) falls back to standalone keras. Record why, instead of
    # discarding the reason silently.
    logger.debug("Falling back to standalone keras; tensorflow unavailable: {}", e)
    import keras  # pyright: ignore[reportMissingTypeStubs]

    # No tensorflow -> version checks against the model are skipped.
    tf_version = None
class KerasModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimage.io model from its `keras_hdf5` weights.

    The network is loaded once, on construction. Device selection and unloading
    are not implemented: passing `devices` or calling `unload` only logs a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the Keras network referenced by `model_description`.

        Args:
            model_description: bioimage.io model description; its
                `weights.keras_hdf5` entry must be set.
            devices: currently ignored (a warning is logged when given).

        Raises:
            ValueError: if the model description has no `keras_hdf5` weights.
        """
        super().__init__(model_description=model_description)
        keras_weights = model_description.weights.keras_hdf5
        if keras_weights is None:
            raise ValueError("model has no keras_hdf5 weights specified")

        # Compare the tensorflow version the model declares with the installed
        # one (`tf_version` is None when standalone keras is used). Mismatches
        # only produce warnings; loading is still attempted.
        model_tf_version = keras_weights.tensorflow_version
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_reader = keras_weights.get_reader()
        # Suffixes are compared including their leading dot.
        if weight_reader.suffix in (".h5", ".hdf5"):
            import h5py  # pyright: ignore[reportMissingTypeStubs]

            # h5py reads directly from the weight reader; the `with` block closes
            # the HDF5 handle once the model has been built from it.
            with h5py.File(weight_reader, mode="r") as h5_file:
                self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
        else:
            # `keras.models.load_model` is called with a filesystem path, so copy
            # the weights into a temporary file that keeps the original file
            # name (and thus its suffix).
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                self._network = keras.models.load_model(temp_path)

        # Axis ids of each output tensor, in `model_description.outputs` order.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the network on `input_arrays`.

        Returns:
            The network outputs as a list/tuple (a single output is wrapped in a
            one-element list).
        """
        # `Model.predict(x, batch_size=None, ...)` takes every model input through
        # its first argument; unpacking several arrays positionally would pass the
        # second one as `batch_size`. Pass a single input directly and multiple
        # inputs as a list (one entry per model input).
        if len(input_arrays) == 1:
            network_input = input_arrays[0]
        else:
            network_input = list(input_arrays)

        network_output = self._network.predict(network_input)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Not supported for Keras: only logs a warning; the network stays loaded."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )