Coverage for bioimageio/core/backends/keras_backend.py: 81%
54 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1import os
2import shutil
3from pathlib import Path
4from tempfile import TemporaryDirectory
5from typing import Any, Optional, Sequence, Union
7import h5py # pyright: ignore[reportMissingTypeStubs]
8from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
9 legacy_h5_format,
10)
11from loguru import logger
12from numpy.typing import NDArray
14from bioimageio.spec.model import v0_4, v0_5
15from bioimageio.spec.model.v0_5 import Version
17from .._settings import settings
18from ..digest_spec import get_axes_infos
19from ..utils._type_guards import is_list, is_tuple
20from ._model_adapter import ModelAdapter
# Select the keras backend configured in bioimageio.core settings.
# NOTE(review): `keras.src.legacy.saving` is imported at the top of this module,
# i.e. before this assignment runs; if keras already picked its backend during
# that import, setting KERAS_BACKEND here has no effect — confirm, and consider
# moving this assignment above every keras/tensorflow import.
os.environ["KERAS_BACKEND"] = settings.keras_backend

# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
try:
    import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
    from tensorflow import (  # pyright: ignore[reportMissingTypeStubs]
        keras,  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
    )

    tf_version = Version(tf.__version__)
except Exception as e:
    # Deliberately broad: tensorflow may be missing *or* fail to import for other
    # reasons; either way we fall back to standalone keras. Record the reason so a
    # broken (rather than absent) tensorflow installation is discoverable.
    logger.debug(
        "Falling back to standalone keras; importing tensorflow failed: {}", e
    )
    import keras  # pyright: ignore[reportMissingTypeStubs]

    # without tensorflow there is no installed version to compare model requirements against
    tf_version = None
class KerasModelAdapter(ModelAdapter):
    """Model adapter that runs a bioimage.io model from its ``keras_hdf5`` weights.

    The network is loaded once at construction time. Device management and
    unloading are not implemented for this backend; requesting them only logs
    a warning.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the keras network referenced by ``model_description``.

        Args:
            model_description: model description providing ``keras_hdf5`` weights.
            devices: requested devices; currently ignored (a warning is logged).

        Raises:
            ValueError: if ``model_description`` has no ``keras_hdf5`` weights.
        """
        super().__init__(model_description=model_description)
        weights = model_description.weights.keras_hdf5
        if weights is None:
            raise ValueError("model has no keras_hdf5 weights specified")

        # Only warn on version mismatches; loading is attempted regardless.
        model_tf_version = weights.tensorflow_version
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_reader = weights.get_reader()
        # Suffixes include the leading dot, so both entries need it
        # (previously "hdf5" never matched and .hdf5 files skipped this branch).
        if weight_reader.suffix in (".h5", ".hdf5"):
            # The legacy loader reads the whole model before returning, so the
            # HDF5 handle can be closed right afterwards instead of leaking.
            with h5py.File(weight_reader, mode="r") as h5_file:
                self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
        else:
            # keras.models.load_model is given a filesystem path: copy the weights
            # to a temporary file that keeps the original file name (presumably so
            # keras can dispatch on the file extension — confirm).
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                self._network = keras.models.load_model(temp_path)

        # One tuple of axis ids per model output. Not read in this file;
        # presumably consumed elsewhere (e.g. by the ModelAdapter base) — confirm.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the network on ``input_arrays`` and return its outputs as a sequence.

        keras' ``Model.predict`` expects all model inputs in its first argument
        (its second positional parameter is ``batch_size``), so several inputs are
        passed together as one list, while a single input is passed unchanged.
        A non-sequence prediction is wrapped in a one-element list.
        """
        if len(input_arrays) == 1:
            model_input = input_arrays[0]
        else:
            model_input = list(input_arrays)

        network_output = self._network.predict(model_input)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Not supported for keras: only logs a warning; the network stays loaded."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )