Coverage for bioimageio/core/model_adapters/_keras_model_adapter.py: 45%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import os 

2from typing import Any, List, Optional, Sequence, Union 

3 

4from loguru import logger 

5from numpy.typing import NDArray 

6 

7from bioimageio.spec._internal.io_utils import download 

8from bioimageio.spec.model import v0_4, v0_5 

9from bioimageio.spec.model.v0_5 import Version 

10 

11from .._settings import settings 

12from ..digest_spec import get_axes_infos 

13from ..tensor import Tensor 

14from ._model_adapter import ModelAdapter 

15 

# The keras backend must be selected through the environment before keras is imported.
os.environ["KERAS_BACKEND"] = settings.keras_backend

# Defaults describing the "nothing went wrong, but no tensorflow" state;
# they are overwritten below depending on which import path succeeds.
keras_error = None
tf_version = None

# by default, we use the keras integrated with tensorflow
try:
    import tensorflow as tf  # pyright: ignore[reportMissingImports]
    from tensorflow import (  # pyright: ignore[reportMissingImports]
        keras,  # pyright: ignore[reportUnknownVariableType]
    )

    tf_version = Version(tf.__version__)  # pyright: ignore[reportUnknownArgumentType]
except Exception:
    # Any failure above (missing tensorflow, unparsable version, ...) makes us
    # fall back to a standalone keras installation; `tf_version` stays None.
    try:
        import keras  # pyright: ignore[reportMissingImports]
    except Exception as e:
        # Remember why keras is unavailable; the error is reported lazily when
        # a `KerasModelAdapter` is actually instantiated.
        keras = None
        keras_error = str(e)

37 

38 

class KerasModelAdapter(ModelAdapter):
    """Model adapter running predictions with a keras network.

    The network is loaded with `keras.models.load_model` from the `keras_hdf5`
    weights entry of the given model description.
    """

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the keras network described by `model_description`.

        Args:
            model_description: Model description providing `weights.keras_hdf5`
                and the output descriptions.
            devices: Currently ignored (a warning is logged if it is not None).

        Raises:
            ImportError: If keras could not be imported at module import time.
            ValueError: If the model description has no `keras_hdf5` weights.
        """
        if keras is None:
            raise ImportError(f"failed to import keras: {keras_error}")

        super().__init__()
        keras_weights = model_description.weights.keras_hdf5
        if keras_weights is None:
            raise ValueError("model has not keras_hdf5 weights specified")

        model_tf_version = keras_weights.tensorflow_version

        # Version mismatches are only reported, never treated as errors.
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_path = download(keras_weights.source).path

        self._network = keras.models.load_model(weight_path)
        # One tuple of axis ids per described output; used to label the
        # arrays returned by the network in `forward`.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
        """Run the network on `input_tensors` and wrap its outputs as tensors.

        Args:
            *input_tensors: One entry per model input; `None` entries are
                forwarded to the network as `None`.

        Returns:
            One `Tensor` per network output, labelled with the axis ids taken
            from the model description.
        """
        # `t.data.data` is presumably the raw array held by the tensor -- TODO confirm
        arrays = [None if t is None else t.data.data for t in input_tensors]

        # `predict` expects all model inputs as its first argument: a single
        # array, or a list of arrays for multi-input models. Unpacking the
        # inputs as positional arguments would bind the second input to
        # `batch_size`.
        network_input = arrays[0] if len(arrays) == 1 else arrays

        _result: Union[Sequence[NDArray[Any]], NDArray[Any]]
        _result = self._network.predict(  # pyright: ignore[reportUnknownVariableType]
            network_input
        )
        if isinstance(_result, (tuple, list)):
            result: Sequence[NDArray[Any]] = _result
        else:
            # a single output array is normalized to a one-element list
            result = [_result]  # type: ignore

        assert len(result) == len(self._output_axes), (
            f"network returned {len(result)} outputs, "
            f"but {len(self._output_axes)} are described"
        )
        ret: List[Optional[Tensor]] = [
            Tensor(r, dims=axes) for r, axes in zip(result, self._output_axes)
        ]
        return ret

    def unload(self) -> None:
        """Log a warning; unloading is not implemented for keras models."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )