Coverage for src / bioimageio / core / backends / keras_backend.py: 80%

54 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 09:46 +0000

1import os 

2import shutil 

3from pathlib import Path 

4from tempfile import TemporaryDirectory 

5from typing import Any, Optional, Sequence, Union 

6 

7from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs] 

8 legacy_h5_format, 

9) 

10from loguru import logger 

11from numpy.typing import NDArray 

12 

13from bioimageio.spec.model import v0_4, v0_5 

14from bioimageio.spec.model.v0_5 import Version 

15 

16from .._settings import settings 

17from ..digest_spec import get_axes_infos 

18from ..utils._type_guards import is_list, is_tuple 

19from ._model_adapter import ModelAdapter 

20 

# Select the keras backend configured in bioimageio.core settings by exporting
# it through the KERAS_BACKEND environment variable.
# NOTE(review): Keras documents that the backend must be configured *before*
# keras is imported. The module-level
# `from keras.src.legacy.saving import legacy_h5_format` import at the top of
# this file already imports keras, so this assignment probably comes too late
# to change the backend of the current process — confirm, and consider moving
# that import below this line.
os.environ["KERAS_BACKEND"] = settings.keras_backend


# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
# If importing tensorflow (or parsing its version) fails for any reason, fall
# back to the standalone `keras` package and leave `tf_version` as None; the
# adapter then only logs that tensorflow versions could not be checked.
try:
    import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
    from tensorflow import (  # pyright: ignore[reportMissingTypeStubs]
        keras,  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
    )

    # Installed tensorflow version; compared against the model's declared
    # `tensorflow_version` when a KerasModelAdapter is constructed.
    tf_version = Version(tf.__version__)
except Exception:
    # Deliberately broad: any failure above (not only ImportError) selects the
    # standalone keras package.
    import keras  # pyright: ignore[reportMissingTypeStubs]

    tf_version = None

37 

38 

class KerasModelAdapter(ModelAdapter):
    """Model adapter for bioimageio models that ship ``keras_hdf5`` weights.

    The keras network is loaded once in ``__init__``; inference is delegated to
    ``keras.Model.predict``. Device selection and unloading are not supported.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the keras network referenced by ``model_description``.

        Args:
            model_description: bioimageio model description; its
                ``weights.keras_hdf5`` entry must be set.
            devices: not supported yet; if given, a warning is logged and the
                value is ignored.

        Raises:
            ValueError: if ``model_description`` has no ``keras_hdf5`` weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.keras_hdf5 is None:
            raise ValueError("model has not keras_hdf5 weights specified")
        model_tf_version = model_description.weights.keras_hdf5.tensorflow_version

        # Version mismatches are only reported, never fatal.
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_reader = model_description.weights.keras_hdf5.get_reader()
        # Compare against dotted suffixes: ".hdf5" (not "hdf5"), otherwise
        # .hdf5 weight files would never take the legacy HDF5 loading path.
        if weight_reader.suffix in (".h5", ".hdf5"):
            import h5py  # pyright: ignore[reportMissingTypeStubs]

            # Close the HDF5 handle once the model is built; the loaded
            # network holds its weights in memory and no longer needs the file.
            with h5py.File(weight_reader, mode="r") as h5_file:
                self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
        else:
            # `keras.models.load_model` expects a filesystem path, so copy the
            # weights into a temporary file that keeps the original file name
            # (and therefore its suffix).
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                self._network = keras.models.load_model(temp_path)

        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the network and return its outputs as a list or tuple.

        Args:
            input_arrays: one array per model input, in model input order.

        Returns:
            The network output if ``predict`` already returned a list/tuple,
            otherwise the single output wrapped in a one-element list.
        """
        # `Model.predict(x, batch_size=None, ...)` takes *all* inputs through
        # its first parameter; unpacking them would bind a second input to
        # `batch_size`. A single input is passed as-is (as before), several
        # inputs are passed as a list, as keras expects for multi-input models.
        if len(input_arrays) == 1:
            network_input = input_arrays[0]
        else:
            network_input = list(input_arrays)

        network_output = self._network.predict(network_input)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Log that unloading is unsupported; the keras network stays loaded."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )