Coverage for bioimageio/core/backends/keras_backend.py: 81%

43 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1import os 

2from typing import Any, Optional, Sequence, Union 

3 

4from loguru import logger 

5from numpy.typing import NDArray 

6 

7from bioimageio.spec._internal.io import download 

8from bioimageio.spec._internal.type_guards import is_list, is_tuple 

9from bioimageio.spec.model import v0_4, v0_5 

10from bioimageio.spec.model.v0_5 import Version 

11 

12from .._settings import settings 

13from ..digest_spec import get_axes_infos 

14from ._model_adapter import ModelAdapter 

15 

# Export the configured keras backend before keras is imported below.
os.environ["KERAS_BACKEND"] = settings.keras_backend

# By default the keras that ships with tensorflow is used; standalone keras
# is only the fallback.
# TODO: check if we should prefer keras

# `tf_version` stays None whenever the tensorflow import or the parsing of
# its version string fails; version checks are skipped in that case.
tf_version: Optional[Version] = None
try:
    import tensorflow as tf  # pyright: ignore[reportMissingTypeStubs]
    from tensorflow import (  # pyright: ignore[reportMissingTypeStubs]
        keras,  # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
    )

    tf_version = Version(tf.__version__)
except Exception:
    # Deliberately broad: any failure above falls back to standalone keras.
    import keras  # pyright: ignore[reportMissingTypeStubs]

31 

32 

class KerasModelAdapter(ModelAdapter):
    """Model adapter that runs a model from its ``keras_hdf5`` weights entry."""

    def __init__(
        self,
        *,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        devices: Optional[Sequence[str]] = None,
    ) -> None:
        """Load the keras network described by `model_description`.

        Args:
            model_description: Model description; it must provide
                ``weights.keras_hdf5``.
            devices: Currently ignored (a warning is logged if given), as
                device management is not implemented for keras.

        Raises:
            ValueError: If the model description has no ``keras_hdf5`` weights.
        """
        super().__init__(model_description=model_description)
        if model_description.weights.keras_hdf5 is None:
            raise ValueError("model has not keras_hdf5 weights specified")
        model_tf_version = model_description.weights.keras_hdf5.tensorflow_version

        # Version mismatches are only reported, never treated as errors.
        if tf_version is None or model_tf_version is None:
            logger.warning("Could not check tensorflow versions.")
        elif model_tf_version > tf_version:
            logger.warning(
                "The model specifies a newer tensorflow version than installed: {} > {}.",
                model_tf_version,
                tf_version,
            )
        elif (model_tf_version.major, model_tf_version.minor) != (
            tf_version.major,
            tf_version.minor,
        ):
            logger.warning(
                "Model tensorflow version {} does not match {}.",
                model_tf_version,
                tf_version,
            )

        # TODO keras device management
        if devices is not None:
            logger.warning(
                "Device management is not implemented for keras yet, ignoring the devices {}",
                devices,
            )

        weight_path = download(model_description.weights.keras_hdf5.source).path

        self._network = keras.models.load_model(weight_path)
        # One tuple of axis ids per model output.
        self._output_axes = [
            tuple(a.id for a in get_axes_infos(out))
            for out in model_description.outputs
        ]

    def _forward_impl(  # pyright: ignore[reportUnknownParameterType]
        self, input_arrays: Sequence[Optional[NDArray[Any]]]
    ):
        """Run the network on `input_arrays` and return its outputs as a sequence.

        Args:
            input_arrays: One array per model input.

        Returns:
            The prediction as returned by the network if it already is a list
            or tuple, otherwise a one-element list wrapping it.
        """
        inputs = list(input_arrays)
        # `predict` takes all model inputs as its first argument. Unpacking
        # several arrays positionally would pass the second array as
        # `batch_size`, so multiple inputs are handed over as one list.
        network_input = inputs[0] if len(inputs) == 1 else inputs
        network_output = self._network.predict(network_input)  # type: ignore
        if is_list(network_output) or is_tuple(network_output):
            return network_output
        else:
            return [network_output]  # pyright: ignore[reportUnknownVariableType]

    def unload(self) -> None:
        """Log a warning only; unloading is not implemented for keras."""
        logger.warning(
            "Device management is not implemented for keras yet, cannot unload model"
        )