Coverage for src / bioimageio / core / backends / keras_backend.py: 69%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1import os 

2import shutil 

3from pathlib import Path 

4from tempfile import TemporaryDirectory 

5from typing import Any, Optional, Sequence, Union 

6 

7from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs] 

8 legacy_h5_format, 

9) 

10from loguru import logger 

11from numpy.typing import NDArray 

12 

13from bioimageio.core.utils._compare import warn_about_version 

14from bioimageio.spec._internal.version_type import Version 

15from bioimageio.spec.model import v0_4, v0_5 

16 

17from .._settings import settings 

18from ..digest_spec import get_axes_infos 

19from ..utils._type_guards import is_list, is_tuple 

20from ._model_adapter import ModelAdapter 

21 

# Select the Keras backend configured in bioimageio.core settings.
# NOTE(review): keras has already been imported above (via
# ``keras.src.legacy.saving``). Keras 3 reads KERAS_BACKEND when it is first
# imported, so this assignment may not change the backend actually in use.
# Confirm this, and consider setting it before any keras import.
os.environ["KERAS_BACKEND"] = settings.keras_backend


# By default, use the Keras bundled with TensorFlow (``tensorflow.keras``).
# If importing TensorFlow (or parsing its version) raises any exception, fall
# back to the standalone ``keras`` package and record ``tf_version`` as None.
# TODO: check if we should prefer the standalone keras package
try:
    import tensorflow as tf
    from tensorflow import keras

    # Kept inside the try on purpose: an unparsable version also triggers the
    # standalone-keras fallback. Compared against the model's recorded
    # TensorFlow version via ``warn_about_version`` in KerasModelAdapter.
    tf_version = Version(tf.__version__)  # pyright: ignore[reportUnknownArgumentType]
except Exception:
    import keras  # pyright: ignore[reportMissingTypeStubs]

    tf_version = None

36 

37 

class KerasModelAdapter(ModelAdapter):
    """Model adapter that runs bioimageio models with Keras weights.

    Supports ``keras_v3`` weights (for model descriptions other than v0.4) and
    legacy ``keras_hdf5`` weights. ``keras_v3`` weights take precedence when a
    description provides both.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the Keras network described by *model_description*.

        Args:
            model_description: bioimageio model description that provides
                ``keras_v3`` or ``keras_hdf5`` weights.
            devices: Not supported for Keras. If given, a warning is logged
                and the value is ignored.

        Raises:
            ValueError: If the model description has no Keras weights.
        """
        super().__init__(model_description=model_description)

        if (
            not isinstance(model_description, v0_4.ModelDescr)
            and model_description.weights.keras_v3 is not None
        ):
            weight_reader = model_description.weights.keras_v3.get_reader()
            backend, backend_version = model_description.weights.keras_v3.backend
        elif model_description.weights.keras_hdf5 is not None:
            # HDF5 weights are treated as TensorFlow-Keras weights; their
            # recorded version is compared against the installed TensorFlow.
            backend = "legacy_tensorflow"
            backend_version = model_description.weights.keras_hdf5.tensorflow_version
            weight_reader = model_description.weights.keras_hdf5.get_reader()
        else:
            raise ValueError("model has no Keras weights")

        if backend != "legacy_tensorflow" and backend != settings.keras_backend:
            logger.warning(
                "Model specifies Keras backend '{}', but environment variable KERAS_BACKEND is set to '{}'."
                + " Attempting to load model with KERAS_BACKEND='{}' (this may fail if the model is not compatible with this backend).",
                backend,
                settings.keras_backend,
                settings.keras_backend,
            )

        # Only compare library versions when the model's backend matches the
        # backend that will actually be used.
        if (backend == "legacy_tensorflow") or (
            backend == settings.keras_backend == "tensorflow"
        ):
            warn_about_version("tensorflow", backend_version, tf_version)
        elif backend == settings.keras_backend == "torch":
            import torch

            torch_version = Version(torch.__version__)
            warn_about_version("torch", backend_version, torch_version)
        elif backend == settings.keras_backend == "jax":
            import jax

            jax_version = Version(jax.__version__)
            warn_about_version("jax", backend_version, jax_version)

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        # Both common HDF5 suffixes go through the legacy HDF5 loader. The
        # original tuple was missing the dot in ".hdf5", so that suffix never
        # matched.
        if weight_reader.suffix in (".h5", ".hdf5"):
            import h5py  # pyright: ignore[reportMissingTypeStubs]

            # Close the HDF5 handle once loading finishes instead of leaking it.
            # Keras' own path-based loader likewise closes the file right after
            # load_model_from_hdf5 returns.
            with h5py.File(weight_reader, mode="r") as h5_file:
                self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
        else:
            # keras.models.load_model expects a filesystem path, so spill the
            # weights to a temporary file that keeps the original file name
            # (and therefore its suffix).
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                self._network = keras.models.load_model(temp_path)

        # Axis ids of each output tensor, in model_description.outputs order.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run inference on *input_arrays* and return the outputs as a sequence.

        A single input is passed to ``predict`` directly. Several inputs are
        passed together as one list, because ``Model.predict``'s second
        positional parameter is ``batch_size``. Unpacking inputs as positional
        arguments would silently misuse them.
        """
        if len(input_arrays) == 1:
            network_input = input_arrays[0]
        else:
            network_input = list(input_arrays)

        network_output = self._network.predict(network_input)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            # Single-output models return one array; wrap it for a uniform API.
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Log a warning: unloading Keras models is not supported yet."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )