Coverage for src / bioimageio / core / io.py: 75%

119 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1import collections.abc 

2import json 

3import warnings 

4import zipfile 

5from contextlib import nullcontext 

6from io import BytesIO 

7from pathlib import Path 

8from shutil import copyfileobj 

9from typing import ( 

10 Any, 

11 Dict, 

12 List, 

13 Mapping, 

14 Optional, 

15 Sequence, 

16 TypeVar, 

17 Union, 

18) 

19 

20import xarray as xr 

21from imageio.v3 import imread, imwrite # type: ignore 

22from loguru import logger 

23from numpy.typing import NDArray 

24from pydantic import BaseModel, RootModel 

25 

26from bioimageio.spec._internal.io import get_reader, interprete_file_source 

27from bioimageio.spec._internal.type_guards import is_ndarray 

28from bioimageio.spec.common import ( 

29 BytesReader, 

30 FileDescr, 

31 FileSource, 

32 HttpUrl, 

33 PermissiveFileSource, 

34 RelativeFilePath, 

35 ZipPath, 

36) 

37from bioimageio.spec.utils import download, load_array, save_array 

38 

39from .axis import AxisId, AxisLike 

40from .common import PerMember 

41from .sample import Sample, Stat 

42from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure 

43from .tensor import Tensor 

44 

# Recursive alias for values that can be represented in JSON.  The nested
# references are written as strings (forward references) because the alias is
# still being defined at this point.
JsonValue = Union[
    bool,
    int,
    float,
    str,
    None,
    List["JsonValue"],
    Dict[str, "JsonValue"],
]

48 

49 

def load_image(
    source: Union[PermissiveFileSource, ZipPath], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """load a single image as numpy array

    Sources with a `.npy` suffix are read with `load_array`; any other source
    is fetched with `download` and decoded by `imageio.v3.imread`.  The suffix
    is compared case-insensitively (as `save_tensor` does when writing), so a
    file named `x.NPY` is read as a numpy file as well.

    Args:
        source: image source
        is_volume: deprecated

    Returns:
        the image data as numpy array
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    if isinstance(source, (FileDescr, ZipPath)):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        parsed_source = parsed_source.absolute()

    # Normalize the case once so that an upper-case suffix selects the same
    # format on reading that `save_tensor` selects on writing.
    extension = parsed_source.suffix.lower()
    if extension == ".npy":
        image = load_array(parsed_source)
    else:
        reader = download(parsed_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=extension
        )

    # type guard: narrows the (untyped) imread result for the type checker
    assert is_ndarray(image)
    return image

80 

81 

def load_tensor(
    source: Union[PermissiveFileSource, ZipPath],
    /,
    axes: Optional[Sequence[AxisLike]] = None,
) -> Tensor:
    """Read **source** with `load_image` and wrap the array in a `Tensor`.

    Args:
        source: image source (positional-only), forwarded to `load_image`
        axes: forwarded as `dims` to `Tensor.from_numpy`
    """
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(source), dims=axes)

91 

92 

# Type variable constrained to the concrete source types `Path`, `HttpUrl`
# and `ZipPath`.
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

# Alias used to annotate file name suffixes (e.g. the return value of
# `get_suffix`).
Suffix = str

96 

97 

def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Write **tensor** to **path** in the format implied by the path's suffix.

    `.npy` files are written with `save_array`; every other supported suffix
    is handed to `imageio.v3.imwrite`.  Before an image file is written,
    singleton axes are dropped depending on the (lower-cased) suffix:

    - `.tif`/`.tiff`: a singleton `batch` axis
    - `.png`/`.jpg`/`.jpeg`: every singleton axis

    float32/float64 tensors saved as `.png`/`.jpg`/`.jpeg` are min-max scaled
    to [0, 255] and converted to uint8 (a warning is logged).

    Args:
        path: destination file; missing parent directories are created
        tensor: the tensor to save

    Raises:
        ValueError: if **path** has no suffix
        NotImplementedError: for the suffixes `.h5`, `.hdf` and `.hdf5`
    """
    # TODO: save axis meta data

    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        save_array(path, tensor.to_numpy())
    elif extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")
    else:
        # Select the axes that may be dropped by extension.  Both rules must
        # not share one dict keyed by axis id: the png/jpg entry of an existing
        # `batch` axis would replace the tiff entry and the singleton batch
        # axis would never be removed for tiff files.
        if extension in (".tif", ".tiff"):
            singleton_candidates: List[AxisId] = [AxisId("batch")]
        elif extension in (".png", ".jpg", ".jpeg"):
            # snapshot of the dims, as `tensor` is rebound in the loop below
            singleton_candidates = list(tensor.dims)
        else:
            singleton_candidates = []

        removed_singleton_axes: List[AxisId] = []
        for rm_a in singleton_candidates:
            if tensor.tagged_shape.get(rm_a) == 1:
                tensor = tensor[{rm_a: 0}]
                removed_singleton_axes.append(rm_a)

        if removed_singleton_axes:
            singleton_axes_msg = f"(with removed singleton axes {list(map(str, removed_singleton_axes))}) "
        else:
            singleton_axes_msg = ""

        logger.info(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        if extension in (".png", ".jpg", ".jpeg") and tensor.dtype in (
            "float32",
            "float64",
        ):
            logger.warning(
                "converting tensor of dtype {} to uint8 for saving as {}",
                tensor.dtype,
                extension,
            )
            # min-max normalization; the lower bound of 1e-8 on the value range
            # avoids a division by zero for constant tensors
            tensor = (
                (tensor - (t_min := tensor.data.min()))
                / xr.ufuncs.maximum(tensor.data.max() - t_min, 1e-8)
                * 255
            ).astype("uint8")

        imwrite(path, tensor, extension=extension)

154 

155 

def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save every member of **sample** with `save_tensor`.

    **path** is either a mapping from member id to a path (pattern) or a single
    path pattern that is used for all members.  A single pattern must contain
    `{member_id}` (or `{input_id}` or `{output_id}`) if the **sample** has
    more than one member.

    (Each) path may contain `{sample_id}`, which is formatted with the id of
    the **sample**.

    Raises:
        ValueError: if a single path pattern without a member placeholder is
            given for a sample with multiple members
    """
    if isinstance(path, collections.abc.Mapping):
        member_paths = path
    else:
        placeholders = ("{member_id}", "{input_id}", "{output_id}")
        if len(sample.members) >= 2 and not any(
            placeholder in str(path) for placeholder in placeholders
        ):
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

        # use the same pattern for every member
        member_paths = {member_id: path for member_id in sample.members}

    for member_id, pattern in member_paths.items():
        member_tensor = sample.members[member_id]
        # all three member placeholders are filled with the member id
        destination = str(pattern).format(
            sample_id=sample.id,
            member_id=member_id,
            input_id=member_id,
            output_id=member_id,
        )
        save_tensor(Path(destination), member_tensor)

183 

184 

class _StatEntry(BaseModel, frozen=True, arbitrary_types_allowed=True):
    """Serializable stat entry"""

    # the measure this entry stores a value for
    measure: Union[DatasetMeasure, SampleMeasure]

    # the value belonging to `measure`
    value: MeasureValue

190 

191 

class _StatList(RootModel[List[_StatEntry]]):
    """Serializable stat mapping"""

    # NOTE: the root value is a list of `_StatEntry` items (one per measure),
    # not a mapping.  The docstring text is kept unchanged as it is the
    # runtime `__doc__` of the model.

196 

197 

def serialize_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
) -> List[JsonValue]:
    """Convert a stat mapping to a JSON compatible list.

    Every (measure, value) item of **stat** becomes one `_StatEntry`; the
    entries are dumped with `mode="json"`.  The result is a list (not a
    string); `save_stat` encodes it with `json.dumps`.
    """
    entries = [
        _StatEntry(measure=measure, value=value) for measure, value in stat.items()
    ]
    return _StatList(entries).model_dump(mode="json")

204 

205 

def save_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
    output: Union[Path, BytesIO],
) -> None:
    """Write sample and dataset statistics as UTF-8 encoded JSON.

    Args:
        stat: the statistics to save (serialized with `serialize_stat`)
        output: a file path (opened in binary write mode and closed afterwards)
            or a `BytesIO` object (written to, but left open)
    """
    # `nullcontext` lets the path and the in-memory case share one code path
    # without closing a stream that is owned by the caller
    sink = output.open("wb") if isinstance(output, Path) else nullcontext(output)

    with sink as stream:
        text = json.dumps(serialize_stat(stat), indent=2)
        _ = stream.write(text.encode("utf-8"))

219 

220 

def load_stat(source: Union[Path, str, Sequence[JsonValue]]) -> Stat:
    """Load sample and dataset statistics from JSON.

    Args:
        source: a path to a UTF-8 encoded JSON file, a JSON string, or an
            already decoded JSON sequence of stat entries

    Returns:
        mapping from each measure to its value
    """
    if isinstance(source, Path):
        # continue with the file content as JSON string
        source = source.read_text(encoding="utf-8")

    validated = (
        _StatList.model_validate_json(source)
        if isinstance(source, str)
        else _StatList.model_validate(source)
    )
    return {entry.measure: entry.value for entry in validated.root}

232 

233 

def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path) -> None:
    """DEPRECATED alias for `save_stat()`: use `save_stat()` instead.

    Args:
        stat: dataset statistics to save
        path: destination file, forwarded to `save_stat()`
    """
    # stacklevel=2 attributes the warning to the caller of this alias
    warnings.warn(
        "`save_dataset_stat()` is deprecated, use `save_stat()` instead.",
        stacklevel=2,
    )
    # pass a plain dict copy of the mapping on to `save_stat()`
    save_stat(dict(stat), path)

238 

239 

def load_dataset_stat(path: Path) -> Stat:
    """DEPRECATED alias for `load_stat()`: use `load_stat()` instead.

    Args:
        path: JSON file to load, forwarded to `load_stat()`
    """
    # stacklevel=2 attributes the warning to the caller of this alias
    warnings.warn(
        "`load_dataset_stat()` is deprecated, use `load_stat()` instead.",
        stacklevel=2,
    )
    return load_stat(path)

244 

245 

def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
):
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive
    otherwise copy **source** to a file in **folder**.

    Args:
        source: the file to unzip/copy; anything but a `BytesReader` is opened
            with `get_reader`
        folder: output folder (created if missing)

    Returns:
        For a zip archive the folder `<file name>.unzipped` the archive was
        extracted to, otherwise the path of the copied file.  `<file name>` is
        the reader's original file name or `file<suffix>` if that is not set.
    """
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    # `zipfile.is_zipfile` seeks within a file-like object and does not restore
    # its position.  Without rewinding, the plain copy below would start where
    # the zip check stopped reading and write a truncated (usually empty) file.
    start_position = weights_reader.tell()
    source_is_zip = zipfile.is_zipfile(weights_reader)
    _ = weights_reader.seek(start_position)

    if source_is_zip:
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path

273 

274 

def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: use source.suffix instead."""
    # kept for backward compatibility only; plain attribute access
    suffix: Suffix = source.suffix
    return suffix