Coverage for bioimageio/core/backends/keras_backend.py: 81%
43 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1import os
2from typing import Any, Optional, Sequence, Union
4from loguru import logger
5from numpy.typing import NDArray
7from bioimageio.spec._internal.io import download
8from bioimageio.spec._internal.type_guards import is_list, is_tuple
9from bioimageio.spec.model import v0_4, v0_5
10from bioimageio.spec.model.v0_5 import Version
12from .._settings import settings
13from ..digest_spec import get_axes_infos
14from ._model_adapter import ModelAdapter
# keras selects its backend from KERAS_BACKEND when it is imported, so this
# must be set before `keras` / `tensorflow.keras` is imported below.
os.environ["KERAS_BACKEND"] = settings.keras_backend

# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
tf_version: Optional[Version]  # None when tensorflow is unavailable
try:
    import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
    from tensorflow import (  # pyright: ignore[reportMissingTypeStubs]
        keras,  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
    )

    tf_version = Version(tf.__version__)
except Exception as e:
    # Best effort: any failure to import tensorflow (or to parse its version)
    # falls back to standalone keras. Record why, so the fallback is traceable
    # instead of being silently swallowed.
    logger.debug(
        "Falling back to standalone keras; tensorflow could not be used: {}", e
    )
    import keras  # pyright: ignore[reportMissingTypeStubs]

    tf_version = None
class KerasModelAdapter(ModelAdapter):
    """Model adapter for bioimageio models with `keras_hdf5` weights.

    On construction the weights are downloaded and loaded with
    `keras.models.load_model`; inference is delegated to the loaded
    model's `predict` method.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the keras model described by `model_description`.

        Args:
            model_description: bioimageio model description; it must provide
                `weights.keras_hdf5`.
            devices: not supported for keras; if given, a warning is logged
                and the value is ignored.

        Raises:
            ValueError: if the model description has no `keras_hdf5` weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.keras_hdf5 is None:
            raise ValueError("model has no keras_hdf5 weights specified")
        model_tf_version = model_description.weights.keras_hdf5.tensorflow_version

        # Version mismatches only produce warnings; loading is still attempted.
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_path = download(model_description.weights.keras_hdf5.source).path

        self._network = keras.models.load_model(weight_path)
        # Axis ids of each output tensor, in the order of `model_description.outputs`.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the keras model on `input_arrays` and return its outputs as a sequence.

        `keras.Model.predict(x, batch_size=None, verbose=..., ...)` receives
        *all* model inputs through its single `x` argument. Pass a bare array
        for a single-input model and a list of arrays for a multi-input model.
        Unpacking the inputs positionally would bind the 2nd/3rd input to
        `batch_size`/`verbose`.
        """
        if len(input_arrays) == 1:
            model_input = input_arrays[0]
        else:
            model_input = list(input_arrays)

        network_output = self._network.predict(model_input)  # type: ignore
        # Single-output models return one array; normalize to a sequence.
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Log a warning only: unloading is not implemented for keras."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )