Coverage for bioimageio/core/backends/keras_backend.py: 81%

54 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 15:20 +0000

1import os 

2import shutil 

3from pathlib import Path 

4from tempfile import TemporaryDirectory 

5from typing import Any, Optional, Sequence, Union 

6 

7import h5py # pyright: ignore[reportMissingTypeStubs] 

8from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs] 

9 legacy_h5_format, 

10) 

11from loguru import logger 

12from numpy.typing import NDArray 

13 

14from bioimageio.spec.model import v0_4, v0_5 

15from bioimageio.spec.model.v0_5 import Version 

16 

17from .._settings import settings 

18from ..digest_spec import get_axes_infos 

19from ..utils._type_guards import is_list, is_tuple 

20from ._model_adapter import ModelAdapter 

21 

# Select the keras backend configured in the bioimageio.core settings by
# exporting it through the KERAS_BACKEND environment variable.
# NOTE(review): `keras.src.legacy.saving` is imported above, i.e. before this
# assignment runs. If keras only reads KERAS_BACKEND when it is first imported,
# this assignment may have no effect. TODO: confirm, and consider moving it above
# the keras import.
os.environ.update(KERAS_BACKEND=settings.keras_backend)

23 

24 

# By default, we use the keras integrated with tensorflow. We only fall back to
# the standalone keras package if tensorflow cannot be used.
# TODO: check if we should prefer keras
try:
    import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
    from tensorflow import (  # pyright: ignore[reportMissingTypeStubs]
        keras,  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
    )

    # Installed tensorflow version; the model adapter compares it against the
    # version recorded in the model's weights.
    tf_version = Version(tf.__version__)
except Exception as e:
    # The catch is deliberately broad: any failure (not only ImportError)
    # triggers the fallback. Log the reason so the switch to standalone keras
    # is not silent.
    logger.debug(
        "Falling back to standalone keras, because tensorflow could not be used: {!r}",
        e,
    )
    import keras  # pyright: ignore[reportMissingTypeStubs]

    # Without tensorflow there is no version to check against.
    tf_version = None

38 

39 

class KerasModelAdapter(ModelAdapter):
    """Model adapter for bioimage.io models that provide ``keras_hdf5`` weights.

    The keras network is loaded once in ``__init__``. Inference is delegated to
    the network's ``predict`` method.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the keras network from the model's ``keras_hdf5`` weights.

        Args:
            model_description: Model description that must provide
                ``keras_hdf5`` weights.
            devices: Not supported for keras. If given, a warning is logged and
                the value is ignored.

        Raises:
            ValueError: If ``model_description`` has no ``keras_hdf5`` weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.keras_hdf5 is None:
            raise ValueError("model has no keras_hdf5 weights specified")
        model_tf_version = model_description.weights.keras_hdf5.tensorflow_version

        # Compare the tensorflow version recorded in the weights with the
        # installed one. This is only possible when tensorflow was importable.
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_reader = model_description.weights.keras_hdf5.get_reader()
        # Both suffixes must include the leading dot; ".hdf5" previously lacked
        # it and so never matched.
        if weight_reader.suffix in (".h5", ".hdf5"):
            # Legacy HDF5 format: h5py reads directly from the file-like reader.
            # The context manager closes the HDF5 handle once the model and its
            # weights are loaded (it was previously left open).
            with h5py.File(weight_reader, mode="r") as h5_file:
                self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
        else:
            # Other formats are loaded by keras from a path, so copy the weights
            # to a temporary file under their original file name.
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                self._network = keras.models.load_model(temp_path)

        # Axis ids of each output tensor, in the order of model_description.outputs.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run ``predict`` on the inputs and always return a sequence of outputs.

        Args:
            input_arrays: One array per model input, in model input order.

        Returns:
            The network's list/tuple of outputs, or a single output wrapped in a
            one-element list.
        """
        # `Model.predict(x, batch_size, verbose, ...)` expects all inputs in its
        # first argument. Unpacking several arrays positionally would pass the
        # extra inputs as batch_size, verbose, etc. Multi-input models therefore
        # get a list of arrays. A single input is passed unchanged, as before.
        if len(input_arrays) == 1:
            network_input = input_arrays[0]
        else:
            network_input = list(input_arrays)

        network_output = self._network.predict(network_input)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Unloading is not supported for keras; this only logs a warning."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )