Coverage for src/bioimageio/core/backends/keras_backend.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import os 

2import shutil 

3from pathlib import Path 

4from tempfile import TemporaryDirectory 

5from typing import Any, Optional, Sequence, Tuple 

6 

7from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs] 

8 legacy_h5_format, 

9) 

10from loguru import logger 

11from numpy.typing import NDArray 

12 

13from bioimageio.spec._internal.version_type import Version 

14from bioimageio.spec.model import v0_4 

15 

16from .._model_adapter import LocalModelAdapter 

17from .._settings import settings 

18from ..utils._compare import warn_about_version 

19from ..utils._type_guards import is_list, is_tuple 

20 

# Select the Keras backend configured in the bioimageio settings by exporting
# it through the KERAS_BACKEND environment variable.
# NOTE(review): Keras reads KERAS_BACKEND when the `keras` package is first
# imported, and `keras.src.legacy.saving` is already imported at the top of
# this module, so this assignment may come too late to affect the backend
# Keras actually uses in this process — confirm and consider moving it before
# any keras import.
os.environ["KERAS_BACKEND"] = settings.keras_backend


# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
try:
    import tensorflow as tf
    from tensorflow import keras

    # Installed tensorflow version; KerasModelAdapter passes it to
    # `warn_about_version` to compare against the version the weights declare.
    tf_version = Version(tf.__version__)  # pyright: ignore[reportUnknownArgumentType]
except Exception:
    # Importing tensorflow (or parsing its version) raised: fall back to the
    # standalone `keras` package. No tensorflow version is known in this case.
    import keras  # pyright: ignore[reportMissingTypeStubs]

    tf_version = None

35 

36 

class KerasModelAdapter(LocalModelAdapter[None, Any]):
    """Model adapter that runs bioimageio models with Keras weights.

    Supports ``keras_v3`` weights (not available in v0.4 model descriptions)
    and ``keras_hdf5`` weights. Device management is not implemented: the only
    "device" used is ``None``.
    """

    def _parse_devices(self, devices: Optional[Sequence[str]]) -> Tuple[None]:
        """Return the single placeholder device ``(None,)``.

        Keras device management is not implemented, so any requested
        ``devices`` are ignored (with a warning).
        """
        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )
        return (None,)

    def _init_model_on_device(self, device: None) -> Any:
        """Load the Keras model from the model description's weights.

        ``keras_v3`` weights are preferred (for non-v0.4 descriptions); otherwise
        ``keras_hdf5`` weights are used and treated as the legacy tensorflow
        backend. Version mismatches between the weights' declared backend
        version and the installed library are reported via
        ``warn_about_version``.

        Args:
            device: Unused placeholder device (see ``_parse_devices``).

        Returns:
            The loaded Keras model.

        Raises:
            ValueError: If the model description has no Keras weights.
        """
        if (
            not isinstance(self._model_descr, v0_4.ModelDescr)
            and self._model_descr.weights.keras_v3 is not None
        ):
            weight_reader = self._model_descr.weights.keras_v3.get_reader()
            backend, backend_version = self._model_descr.weights.keras_v3.backend
        elif self._model_descr.weights.keras_hdf5 is not None:
            backend = "legacy_tensorflow"
            backend_version = self._model_descr.weights.keras_hdf5.tensorflow_version
            weight_reader = self._model_descr.weights.keras_hdf5.get_reader()
        else:
            raise ValueError("model has no Keras weights")

        if backend != "legacy_tensorflow" and backend != settings.keras_backend:
            logger.warning(
                "Model specifies Keras backend '{}', but environment variable KERAS_BACKEND is set to '{}'."
                + " Attempting to load model with KERAS_BACKEND='{}' (this may fail if the model is not compatible with this backend).",
                backend,
                settings.keras_backend,
                settings.keras_backend,
            )

        # Only compare versions when the weights' backend is the one in use.
        if (backend == "legacy_tensorflow") or (
            backend == settings.keras_backend == "tensorflow"
        ):
            warn_about_version("tensorflow", backend_version, tf_version)
        elif backend == settings.keras_backend == "torch":
            import torch

            torch_version = Version(torch.__version__)
            warn_about_version("torch", backend_version, torch_version)
        elif backend == settings.keras_backend == "jax":
            import jax

            jax_version = Version(jax.__version__)
            warn_about_version("jax", backend_version, jax_version)

        # Suffixes include the leading dot (".h5"/".hdf5"); the previous
        # check compared against "hdf5" and could never match that form.
        if weight_reader.suffix in (".h5", ".hdf5"):
            import h5py  # pyright: ignore[reportMissingTypeStubs]

            # load_model_from_hdf5 does not close an h5py.File it did not open
            # itself, so close it here once the model (and its weights) are
            # loaded.
            with h5py.File(weight_reader, mode="r") as h5_file:
                return legacy_h5_format.load_model_from_hdf5(h5_file)  # pyright: ignore[reportUnknownVariableType]
        else:
            # keras.models.load_model needs a real file path, so materialize
            # the weights under their original file name in a temp directory.
            with TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir) / weight_reader.original_file_name
                with temp_path.open("wb") as f:
                    shutil.copyfileobj(weight_reader, f)

                return keras.models.load_model(temp_path)  # pyright: ignore[reportUnknownVariableType]

    def _forward_impl(
        self,
        device: None,
        model: Any,
        input_arrays: Sequence[Optional[NDArray[Any]]],
    ):
        """Run ``model.predict`` on the inputs and return a sequence of outputs.

        A single input array is passed to ``predict`` as-is. Multiple inputs are
        passed as one list, because ``Model.predict`` takes all inputs in its
        first argument (its later positional parameters are ``batch_size``,
        ``verbose``, ``steps``, ...). A non-sequence result is wrapped in a
        one-element list.
        """
        if len(input_arrays) == 1:
            network_input: Any = input_arrays[0]
        else:
            network_input = list(input_arrays)

        network_output = model.predict(network_input)
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]

    def _cleanup_pre_model_deletion(self, device: None, model: Any) -> None:
        """No Keras-specific cleanup is needed before the model is deleted."""
        return

    def _cleanup_post_model_deletion(self, device: None) -> None:
        """No Keras-specific cleanup is needed after the model is deleted."""
        return