Coverage for bioimageio/core/io.py: 76%

138 statements  

coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

import collections.abc
import warnings
import zipfile
from io import TextIOWrapper
from pathlib import Path, PurePosixPath
from shutil import copyfileobj
from typing import (
    Any,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import h5py  # pyright: ignore[reportMissingTypeStubs]
import numpy as np
from imageio.v3 import imread, imwrite  # type: ignore
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter
from typing_extensions import assert_never

from bioimageio.spec._internal.io import interprete_file_source
from bioimageio.spec.common import (
    HttpUrl,
    PermissiveFileSource,
    RelativeFilePath,
    ZipPath,
)
from bioimageio.spec.utils import download, load_array, save_array

from .axis import AxisLike
from .common import PerMember
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import Tensor

DEFAULT_H5_DATASET_PATH = "data"


SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5")



def load_image(
    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image (or volume) as a numpy array.

    Args:
        source: image source
        is_volume: deprecated; has no effect and will be removed soon
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    if isinstance(source, ZipPath):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        src = parsed_source.absolute()
    else:
        src = parsed_source

    # FIXME: why is pyright complaining about giving the union to _split_dataset_path?
    if isinstance(src, Path):
        file_source, subpath = _split_dataset_path(src)
    elif isinstance(src, HttpUrl):
        file_source, subpath = _split_dataset_path(src)
    elif isinstance(src, ZipPath):
        file_source, subpath = _split_dataset_path(src)
    else:
        assert_never(src)

    path = download(file_source).path

    if path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}")

        return load_array(path)
    elif path.suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(path, "r") as f:
            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
                dataset_path
            )
            if not isinstance(h5_dataset, h5py.Dataset):
                raise ValueError(
                    f"{dataset_path} in {path} is not of type {h5py.Dataset}, "
                    + "but has type "
                    + str(
                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
            image: NDArray[Any]
            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
            assert isinstance(image, np.ndarray), type(
                image  # pyright: ignore[reportUnknownArgumentType]
            )
            return image  # pyright: ignore[reportUnknownVariableType]
    elif isinstance(path, ZipPath):
        return imread(
            path.read_bytes(), extension=path.suffix
        )  # pyright: ignore[reportUnknownVariableType]
    else:
        return imread(path)  # pyright: ignore[reportUnknownVariableType]
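

# Usage sketch (illustrative; the file names below are hypothetical):
#
#     >>> load_image("image.png")               # plain image file, read via imageio
#     >>> load_image("volume.h5")               # reads dataset "data" (DEFAULT_H5_DATASET_PATH)
#     >>> load_image("volume.h5/raw/channel0")  # reads the internal dataset "raw/channel0"

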

def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    # TODO: load axis meta data
    array = load_image(path)

    return Tensor.from_numpy(array, dims=axes)
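

# Usage sketch (illustrative; the file name and axis ids are assumptions):
#
#     >>> tensor = load_tensor("volume.npy", axes=["z", "y", "x"])
#
# Without **axes**, Tensor.from_numpy has to infer the axes from the array
# shape (axis metadata is not yet loaded from the file, see TODO above).

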

_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)


def _split_dataset_path(
    source: _SourceT,
) -> Tuple[_SourceT, Optional[PurePosixPath]]:
    """Split off a subpath (e.g. an internal h5 dataset path)
    that follows a file extension in a file path.

    Examples:
        >>> _split_dataset_path(Path("my_file.h5/dataset"))
        (...Path('my_file.h5'), PurePosixPath('dataset'))

        >>> _split_dataset_path(Path("my_plain_file"))
        (...Path('my_plain_file'), None)

    """
    if isinstance(source, RelativeFilePath):
        src = source.absolute()
    else:
        src = source

    del source

    def separate_pure_path(path: PurePosixPath):
        for p in path.parents:
            if p.suffix in SUFFIXES_WITH_DATAPATH:
                return p, PurePosixPath(path.relative_to(p))

        return path, None

    if isinstance(src, HttpUrl):
        file_path, data_path = separate_pure_path(PurePosixPath(src.path or ""))

        if data_path is None:
            return src, None

        return (
            # strip the dataset subpath from the full URL (not just its path
            # part, which would drop scheme and host)
            HttpUrl(str(src).replace(f"/{data_path}", "")),
            data_path,
        )

    if isinstance(src, ZipPath):
        file_path, data_path = separate_pure_path(PurePosixPath(str(src)))

        if data_path is None:
            return src, None

        return (
            ZipPath(str(file_path).replace(f"/{data_path}", "")),
            data_path,
        )

    file_path, data_path = separate_pure_path(PurePosixPath(src))
    return Path(file_path), data_path
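

# Usage sketch (illustrative; the URL is hypothetical): URL sources are split
# the same way, e.g. "https://example.com/my_file.h5/dataset" yields the file
# URL "https://example.com/my_file.h5" and the subpath PurePosixPath('dataset').

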

def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    # TODO: save axis meta data

    data: NDArray[Any] = tensor.data.to_numpy()
    file_path, subpath = _split_dataset_path(Path(path))
    if not file_path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    file_path.parent.mkdir(exist_ok=True, parents=True)
    if file_path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")

        save_array(file_path, data)
    elif file_path.suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(file_path, "a") as f:
            if dataset_path in f:
                del f[dataset_path]

            _ = f.create_dataset(dataset_path, data=data, chunks=True)
    else:
        # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]:
        #     tensor = tensor[{a: 0 for a in singleton_axes}]
        #     singleton_axes_msg = f"(without singleton axes {singleton_axes}) "
        # else:
        singleton_axes_msg = ""

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data)
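

# Usage sketch (illustrative; file names are hypothetical):
#
#     >>> save_tensor("out/pred.npy", tensor)      # numpy array file
#     >>> save_tensor("out/pred.h5/pred", tensor)  # dataset "pred" in an hdf5 file
#     >>> save_tensor("out/pred.tif", tensor)      # any other suffix is written via imageio

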

def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern,
    or save each sample member to its respective path if **path** is a mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}`, which is formatted with the
    **sample**'s id.
    """
    if not isinstance(path, collections.abc.Mapping):
        if len(sample.members) < 2 or any(
            m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
        ):
            path = {m: path for m in sample.members}
        else:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with"
                + f" multiple members {list(sample.members)}."
            )

    for m, p in path.items():
        t = sample.members[m]
        p_formatted = Path(
            str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
        )
        save_tensor(p_formatted, t)
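

# Usage sketch (illustrative; member ids and paths are hypothetical):
#
#     >>> save_sample("out/{sample_id}/{member_id}.npy", sample)
#     >>> save_sample({"raw": Path("out/raw.npy"), "mask": Path("out/mask.npy")}, sample)

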

class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    measure: DatasetMeasure
    value: MeasureValue


_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)


def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    serializable = [
        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(serializable))


def load_dataset_stat(path: Path):
    seq = _stat_adapter.validate_json(path.read_bytes())
    return {e.measure: e.value for e in seq}
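

# Usage sketch (illustrative; `stat` is a Mapping[DatasetMeasure, MeasureValue]
# and the file name is hypothetical): measures serialized with
# `save_dataset_stat` can be restored with `load_dataset_stat`:
#
#     >>> save_dataset_stat(stat, Path("dataset_stat.json"))
#     >>> stat_restored = load_dataset_stat(Path("dataset_stat.json"))

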

def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
    """Unzip a (downloaded) **source** to a file in **folder** if **source** is a
    zip archive. Always returns the path to the unzipped source (which may be
    **source** itself)."""
    local_weights_file = download(source).path
    if isinstance(local_weights_file, ZipPath):
        # source is inside a zip archive
        out_path = folder / local_weights_file.filename
        with local_weights_file.open("rb") as src, out_path.open("wb") as dst:
            assert not isinstance(src, TextIOWrapper)
            copyfileobj(src, dst)

        local_weights_file = out_path

    if zipfile.is_zipfile(local_weights_file):
        # source itself is a zipfile
        out_path = folder / local_weights_file.with_suffix(".unzipped").name
        with zipfile.ZipFile(local_weights_file, "r") as f:
            f.extractall(out_path)

        return out_path
    else:
        return local_weights_file
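

# Usage sketch (illustrative; URL and folder are hypothetical):
#
#     >>> weights_dir = ensure_unzipped(
#     ...     "https://example.com/weights.zip", Path("weights_folder")
#     ... )
#
# A zip archive "weights.zip" is extracted to "weights_folder/weights.unzipped";
# a non-zip source is returned as its (downloaded) local path.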