Coverage for bioimageio/core/io.py: 87%

86 statements  

coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

import collections.abc
import warnings
from pathlib import Path, PurePosixPath
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import h5py
import numpy as np
from imageio.v3 import imread, imwrite
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter

from bioimageio.spec.utils import load_array, save_array

from .axis import AxisLike
from .common import PerMember
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import Tensor

DEFAULT_H5_DATASET_PATH = "data"


def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]:
    """load a single image as numpy array

    Args:
        path: image path
        is_volume: deprecated
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    file_path, subpath = _split_dataset_path(Path(path))

    if file_path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}")
        return load_array(path)
    elif file_path.suffix in (".h5", ".hdf", ".hdf5"):
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(file_path, "r") as f:
            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
                dataset_path
            )
            if not isinstance(h5_dataset, h5py.Dataset):
                raise ValueError(
                    f"{path} is not of type {h5py.Dataset}, but has type "
                    + str(
                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
            image: NDArray[Any]
            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
            assert isinstance(image, np.ndarray), type(
                image  # pyright: ignore[reportUnknownArgumentType]
            )
            return image  # pyright: ignore[reportUnknownVariableType]
    else:
        return imread(path)  # pyright: ignore[reportUnknownVariableType]
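
A minimal usage sketch (the file names are hypothetical): load_image dispatches on the file suffix, and for HDF5 files an internal dataset path may follow the suffix, falling back to DEFAULT_H5_DATASET_PATH ("data") when none is given.

    from pathlib import Path
    from bioimageio.core.io import load_image

    img = load_image(Path("nuclei.png"))    # any format imageio can read
    arr = load_image(Path("nuclei.npy"))    # plain numpy array
    vol = load_image(Path("stack.h5/raw"))  # dataset "raw" inside stack.h5
    vol = load_image(Path("stack.h5"))      # falls back to the "data" dataset
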

def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor:
    # TODO: load axis meta data
    array = load_image(path)

    return Tensor.from_numpy(array, dims=axes)
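
Without `axes`, `Tensor.from_numpy` has to infer the axis layout from the array shape; passing explicit ids avoids the guesswork. A sketch, assuming plain strings such as "y" and "x" are accepted as `AxisLike` and the file name is hypothetical:

    from pathlib import Path
    from bioimageio.core.io import load_tensor

    t = load_tensor(Path("nuclei.npy"), axes=["y", "x"])
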

def _split_dataset_path(path: Path) -> Tuple[Path, Optional[PurePosixPath]]:
    """Split off a subpath (e.g. an internal h5 dataset path)
    from a file path following a file extension.

    Examples:
        >>> _split_dataset_path(Path("my_file.h5/dataset"))
        (PosixPath('my_file.h5'), PurePosixPath('dataset'))

        If no suffix is detected, the whole path is returned with a `None` subpath:
        >>> _split_dataset_path(Path("my_plain_file"))
        (PosixPath('my_plain_file'), None)

    """
    if path.suffix:
        return path, None

    for p in path.parents:
        if p.suffix:
            return p, PurePosixPath(path.relative_to(p))

    return path, None



def save_tensor(path: Path, tensor: Tensor) -> None:
    # TODO: save axis meta data

    data: NDArray[Any] = tensor.data.to_numpy()
    file_path, subpath = _split_dataset_path(Path(path))
    if not file_path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    file_path.parent.mkdir(exist_ok=True, parents=True)
    if file_path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")
        save_array(file_path, data)
    elif file_path.suffix in (".h5", ".hdf", ".hdf5"):
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(file_path, "a") as f:
            if dataset_path in f:
                del f[dataset_path]

            _ = f.create_dataset(dataset_path, data=data, chunks=True)
    else:
        # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]:
        #     tensor = tensor[{a: 0 for a in singleton_axes}]
        #     singleton_axes_msg = f"(without singleton axes {singleton_axes}) "
        # else:
        singleton_axes_msg = ""

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data)
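
A round-trip sketch: writing a tensor into an HDF5 dataset selected by a subpath and reading it back. The file name and dataset name are illustrative, and string axis ids are assumed to be valid `AxisLike`:

    import numpy as np
    from pathlib import Path
    from bioimageio.core.io import load_tensor, save_tensor
    from bioimageio.core.tensor import Tensor

    t = Tensor.from_numpy(np.zeros((256, 256), dtype="float32"), dims=["y", "x"])
    save_tensor(Path("measurements.h5/prediction"), t)  # dataset "prediction"
    t2 = load_tensor(Path("measurements.h5/prediction"), axes=["y", "x"])
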

def save_sample(path: Union[Path, str, PerMember[Path]], sample: Sample) -> None:
    """save a sample to path

    If `path` is a pathlib.Path or a string, it must contain `{member_id}` and
    may contain `{sample_id}`, which are resolved with the `sample` object.
    """

    if not isinstance(path, collections.abc.Mapping) and "{member_id}" not in str(
        path
    ):
        raise ValueError(f"missing `{{member_id}}` in path {path}")

    for m, t in sample.members.items():
        if isinstance(path, collections.abc.Mapping):
            p = path[m]
        else:
            p = Path(str(path).format(sample_id=sample.id, member_id=m))

        save_tensor(p, t)
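
A sketch of the templated form, assuming `sample` is an existing `Sample` whose members are named "input0" and "output0" (obtained elsewhere, e.g. from prediction code):

    from bioimageio.core.io import save_sample

    save_sample("out/{sample_id}/{member_id}.npy", sample)
    # -> out/<sample.id>/input0.npy and out/<sample.id>/output0.npy

Passing a mapping `{member_id: Path}` instead gives full control over each member's destination.
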

class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    measure: DatasetMeasure
    value: MeasureValue


_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)



def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    serializable = [
        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(serializable))



def load_dataset_stat(path: Path):
    seq = _stat_adapter.validate_json(path.read_bytes())
    return {e.measure: e.value for e in seq}
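
The mapping is serialized as a JSON list of {measure, value} entries because `DatasetMeasure` keys cannot serve as JSON object keys; `load_dataset_stat` rebuilds the dict. A round-trip sketch, assuming `DatasetMean` from `bioimageio.core.stat_measures` and `MemberId` from `bioimageio.core.common` (verify both against the installed version):

    from pathlib import Path
    from bioimageio.core.common import MemberId
    from bioimageio.core.io import load_dataset_stat, save_dataset_stat
    from bioimageio.core.stat_measures import DatasetMean

    measure = DatasetMean(member_id=MemberId("raw"))  # assumed constructor
    save_dataset_stat({measure: 0.5}, Path("stat.json"))
    restored = load_dataset_stat(Path("stat.json"))
    assert restored[measure] == 0.5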