Coverage for bioimageio/core/model_adapters/_keras_model_adapter.py: 45%
55 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1import os
2from typing import Any, List, Optional, Sequence, Union
4from loguru import logger
5from numpy.typing import NDArray
7from bioimageio.spec._internal.io_utils import download
8from bioimageio.spec.model import v0_4, v0_5
9from bioimageio.spec.model.v0_5 import Version
11from .._settings import settings
12from ..digest_spec import get_axes_infos
13from ..tensor import Tensor
14from ._model_adapter import ModelAdapter
os.environ["KERAS_BACKEND"] = settings.keras_backend

# Defaults; overwritten below depending on which import succeeds.
# `tf_version` must be bound on every path (it used to stay undefined when
# neither tensorflow.keras nor standalone keras could be imported).
keras_error: Optional[str] = None
tf_version: Optional[Version] = None

# by default, we use the keras integrated with tensorflow
try:
    import tensorflow as tf  # pyright: ignore[reportMissingImports]
    from tensorflow import (  # pyright: ignore[reportMissingImports]
        keras,  # pyright: ignore[reportUnknownVariableType]
    )

    tf_version = Version(tf.__version__)  # pyright: ignore[reportUnknownArgumentType]
except Exception:
    # Fall back to standalone keras; the tensorflow version is then unknown.
    tf_version = None
    try:
        import keras  # pyright: ignore[reportMissingImports]
    except Exception as e:
        # Record the reason so the adapter can report it when instantiated.
        keras = None
        keras_error = str(e)
class KerasModelAdapter(ModelAdapter):
    """Model adapter that runs a model's ``keras_hdf5`` weights with Keras."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the ``keras_hdf5`` weights of `model_description`.

        Args:
            model_description: model description providing ``keras_hdf5``
                weights and the output axes.
            devices: currently ignored; a warning is logged if given.

        Raises:
            ImportError: if keras could not be imported.
            ValueError: if the model description has no ``keras_hdf5`` weights.
        """
        if keras is None:
            raise ImportError(f"failed to import keras: {keras_error}")

        super().__init__()
        weights = model_description.weights.keras_hdf5
        if weights is None:
            raise ValueError("model has no keras_hdf5 weights specified")

        # Only warn on version mismatches; loading is attempted regardless.
        model_tf_version = weights.tensorflow_version
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_path = download(weights.source).path

        self._network = keras.models.load_model(weight_path)
        # Axis ids per output, used to label the arrays returned by `predict`.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run inference on the given input tensors.

        Returns:
            One tensor per model output, labelled with the output axis ids
            from the model description.

        Raises:
            ValueError: if the number of arrays returned by the keras model
                differs from the number of outputs in the model description.
        """
        arrays = [None if t is None else t.data.data for t in input_tensors]
        # `Model.predict` receives all inputs through its single `x` argument
        # (a list for multi-input models). Unpacking them as positional
        # arguments would pass the second input as `batch_size`.
        x = arrays[0] if len(arrays) == 1 else arrays

        _result: Union[Sequence[NDArray[Any]], NDArray[Any]]
        _result = self._network.predict(x)  # pyright: ignore[reportUnknownVariableType]
        # Single-output models return a bare array; normalise to a sequence.
        if isinstance(_result, (tuple, list)):
            result: Sequence[NDArray[Any]] = _result
        else:
            result = [_result]  # type: ignore

        if len(result) != len(self._output_axes):
            raise ValueError(
                f"expected {len(self._output_axes)} output(s) from the keras "
                f"model, but got {len(result)}"
            )

        ret: List[Optional[Tensor]] = [
            Tensor(r, dims=axes) for r, axes in zip(result, self._output_axes)
        ]
        return ret

    def unload(self) -> None:
        """Log a warning; unloading keras models is not implemented."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )