Coverage for bioimageio/core/io.py: 72%

151 statements  


import collections.abc
import warnings
import zipfile
from pathlib import Path, PurePosixPath
from shutil import copyfileobj
from typing import (
    Any,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import h5py  # pyright: ignore[reportMissingTypeStubs]
from imageio.v3 import imread, imwrite  # type: ignore
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter
from typing_extensions import assert_never

from bioimageio.spec._internal.io import get_reader, interprete_file_source
from bioimageio.spec._internal.type_guards import is_ndarray
from bioimageio.spec.common import (
    BytesReader,
    FileSource,
    HttpUrl,
    PermissiveFileSource,
    RelativeFilePath,
    ZipPath,
)
from bioimageio.spec.utils import download, load_array, save_array

from .axis import AxisLike
from .common import PerMember
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import Tensor

DEFAULT_H5_DATASET_PATH = "data"

SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5")

def load_image(
    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image as a numpy array.

    Args:
        source: image source
        is_volume: deprecated
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    if isinstance(source, ZipPath):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        src = parsed_source.absolute()
    else:
        src = parsed_source

    # branch on the concrete type so the generic `_split_dataset_path`
    # is called with a single, narrowed source type
    if isinstance(src, Path):
        file_source, suffix, subpath = _split_dataset_path(src)
    elif isinstance(src, HttpUrl):
        file_source, suffix, subpath = _split_dataset_path(src)
    elif isinstance(src, ZipPath):
        file_source, suffix, subpath = _split_dataset_path(src)
    else:
        assert_never(src)

    if suffix == ".npy":
        if subpath is not None:
            logger.warning(
                "Unexpected subpath {} for .npy source {}", subpath, file_source
            )

        image = load_array(file_source)
    elif suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        reader = download(file_source)

        with h5py.File(reader, "r") as f:
            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
                dataset_path
            )
            if not isinstance(h5_dataset, h5py.Dataset):
                raise ValueError(
                    f"{file_source} did not load as {h5py.Dataset}, but has type "
                    + str(
                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
            image: NDArray[Any]
            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
    else:
        reader = download(file_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=suffix
        )

    assert is_ndarray(image)
    return image
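
# A usage sketch for `load_image` (hypothetical file names; each source is
# assumed to exist locally or at the given URL):
#
#     arr = load_image("tensor.npy")                      # plain numpy array file
#     vol = load_image("volume.h5/raw")                   # HDF5 file, internal dataset "raw"
#     vol = load_image("volume.h5")                       # falls back to dataset "data"
#     img = load_image("https://example.com/image.png")   # any imageio-readable format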


def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    # TODO: load axis meta data
    array = load_image(path)

    return Tensor.from_numpy(array, dims=axes)
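
# A minimal sketch of loading a tensor with explicit axis semantics
# (hypothetical file name; axis ids follow the bioimageio convention):
#
#     t = load_tensor("image.npy", axes=["y", "x"])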


_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

Suffix = str


def _split_dataset_path(
    source: _SourceT,
) -> Tuple[_SourceT, Suffix, Optional[PurePosixPath]]:
    """Split off a subpath (e.g. an internal h5 dataset path)
    that follows the file extension of a file path.

    Examples:
        >>> _split_dataset_path(Path("my_file.h5/dataset"))
        (...Path('my_file.h5'), '.h5', PurePosixPath('dataset'))

        >>> _split_dataset_path(Path("my_plain_file"))
        (...Path('my_plain_file'), '', None)

    """
    if isinstance(source, RelativeFilePath):
        src = source.absolute()
    else:
        src = source

    del source  # only use the normalized `src` below

    def separate_pure_path(path: PurePosixPath):
        for p in path.parents:
            if p.suffix in SUFFIXES_WITH_DATAPATH:
                return p, p.suffix, PurePosixPath(path.relative_to(p))

        return path, path.suffix, None

    if isinstance(src, HttpUrl):
        file_path, suffix, data_path = separate_pure_path(PurePosixPath(src.path or ""))

        if data_path is None:
            return src, suffix, None

        # strip the dataset subpath from the full URL (not just its path
        # component, which would drop scheme and host)
        return (
            HttpUrl(str(src).replace(f"/{data_path}", "")),
            suffix,
            data_path,
        )

    if isinstance(src, ZipPath):
        file_path, suffix, data_path = separate_pure_path(PurePosixPath(str(src)))

        if data_path is None:
            return src, suffix, None

        return (
            ZipPath(str(src).replace(f"/{data_path}", "")),
            suffix,
            data_path,
        )

    file_path, suffix, data_path = separate_pure_path(PurePosixPath(src))
    return Path(file_path), suffix, data_path
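
# The same splitting applies to URLs and zip paths, e.g. (hypothetical URL):
#
#     _split_dataset_path(HttpUrl("https://example.com/data.h5/raw"))
#     # -> (HttpUrl("https://example.com/data.h5"), ".h5", PurePosixPath("raw"))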


def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    # TODO: save axis meta data

    data: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
        tensor.data.to_numpy()
    )
    assert is_ndarray(data)
    file_path, suffix, subpath = _split_dataset_path(Path(path))
    if not suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    file_path.parent.mkdir(exist_ok=True, parents=True)
    if suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")
        save_array(file_path, data)
    elif suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(file_path, "a") as f:
            if dataset_path in f:
                # overwrite an existing dataset at the same internal path
                del f[dataset_path]

            _ = f.create_dataset(dataset_path, data=data, chunks=True)
    else:
        # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]:
        #     tensor = tensor[{a: 0 for a in singleton_axes}]
        #     singleton_axes_msg = f"(without singleton axes {singleton_axes}) "
        # else:
        singleton_axes_msg = ""

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data)
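
# A round-trip sketch (hypothetical paths; parent directories are created as
# needed, and the imageio route assumes an image-compatible shape and dtype):
#
#     t = load_tensor("input.npy", axes=["y", "x"])
#     save_tensor("out/result.h5/prediction", t)  # HDF5 dataset "prediction"
#     save_tensor("out/result.png", t)            # written via imageio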


def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern
    or all sample members in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        if len(sample.members) < 2 or any(
            m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
        ):
            path = {m: path for m in sample.members}
        else:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

    for m, p in path.items():
        t = sample.members[m]
        p_formatted = Path(
            str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
        )
        save_tensor(p_formatted, t)
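
# A usage sketch (assumes `sample` is an existing `Sample` with member ids
# such as "output0"; paths are hypothetical):
#
#     save_sample("out/{sample_id}_{member_id}.npy", sample)
#     save_sample({"output0": Path("out/prediction.h5")}, sample)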


class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    measure: DatasetMeasure
    value: MeasureValue


_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)


def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    serializable = [
        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(serializable))


def load_dataset_stat(path: Path):
    seq = _stat_adapter.validate_json(path.read_bytes())
    return {e.measure: e.value for e in seq}
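
# Round-trip sketch (assumes `stats` is a `Mapping[DatasetMeasure, MeasureValue]`,
# e.g. previously computed dataset statistics):
#
#     save_dataset_stat(stats, Path("stats.json"))
#     stats_again = load_dataset_stat(Path("stats.json"))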


def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
):
    """Unzip a (downloaded) **source** to a file in **folder** if **source** is a
    zip archive; otherwise copy **source** to a file in **folder**."""
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    if zipfile.is_zipfile(weights_reader):
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile: extract it into the `.unzipped` directory
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path
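
# Usage sketch (hypothetical archive; for a zip, the return value is the
# extraction directory "<name>.unzipped", otherwise the copied file):
#
#     extracted = ensure_unzipped("weights.zip", Path("unpacked"))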


def get_suffix(source: Union[ZipPath, FileSource]) -> str:
    if isinstance(source, Path):
        return source.suffix
    elif isinstance(source, ZipPath):
        return source.suffix
    elif isinstance(source, RelativeFilePath):
        return source.path.suffix
    elif isinstance(source, HttpUrl):
        if source.path is None:
            return ""
        else:
            return PurePosixPath(source.path).suffix
    else:
        assert_never(source)
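
# Usage sketch (hypothetical sources):
#
#     assert get_suffix(Path("weights.pt")) == ".pt"
#     assert get_suffix(HttpUrl("https://example.com/model.onnx")) == ".onnx"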