Coverage for src/bioimageio/core/io.py: 80%

121 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import collections.abc 

2import json 

3import warnings 

4import zipfile 

5from contextlib import nullcontext 

6from io import BytesIO 

7from pathlib import Path 

8from shutil import copyfileobj 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 Dict, 

13 List, 

14 Mapping, 

15 Optional, 

16 Sequence, 

17 TypeVar, 

18 Union, 

19) 

20 

21import xarray as xr 

22from imageio.v3 import imread, imwrite # type: ignore 

23from loguru import logger 

24from numpy.typing import NDArray 

25from pydantic import BaseModel, RootModel 

26from typing_extensions import TypeAlias 

27from typing_extensions import TypeAliasType as _TypeAliasType 

28 

29from bioimageio.spec._internal.io import get_reader, interprete_file_source 

30from bioimageio.spec._internal.type_guards import is_ndarray 

31from bioimageio.spec.common import ( 

32 BytesReader, 

33 FileDescr, 

34 FileSource, 

35 HttpUrl, 

36 PermissiveFileSource, 

37 RelativeFilePath, 

38 ZipPath, 

39) 

40from bioimageio.spec.utils import download, load_array, save_array 

41 

42from .axis import AxisId, AxisLike 

43from .common import PerMember 

44from .sample import Sample 

45from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure, Stat 

46from .tensor import Tensor 

47 

if not TYPE_CHECKING:
    # runtime definition:
    # for pydantic validation we need to use `TypeAliasType`,
    # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
    # however this results in a partially unknown type with the current pyright 1.1.388
    JsonValue: TypeAlias = _TypeAliasType(
        "JsonValue",
        Union[bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]],
    )

else:
    # static type checkers get a plain recursive alias instead
    JsonValue: TypeAlias = Union[
        bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]
    ]  # note: order relevant for deserializing

61 

62 

def load_image(
    source: Union[PermissiveFileSource, ZipPath], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """load a single image as numpy array

    A source with an `.npy` suffix (compared case-insensitively, matching
    `save_tensor`) is read with `load_array`; any other source is fetched with
    `download` and decoded by `imageio.v3.imread`, passing the source suffix
    as the `extension` hint.

    Args:
        source: image source
        is_volume: deprecated; a warning is issued if it is not `None`,
            the value itself is not used.

    Returns:
        the loaded image as numpy array
    """
    if is_volume is not None:
        # stacklevel=2 attributes the warning to the calling code
        warnings.warn(
            "**is_volume** is deprecated and will be removed soon.", stacklevel=2
        )

    if isinstance(source, (FileDescr, ZipPath)):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        parsed_source = parsed_source.absolute()

    suffix = parsed_source.suffix
    if suffix.lower() == ".npy":
        image = load_array(parsed_source)
    else:
        reader = download(parsed_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=suffix
        )

    # narrows the (untyped) imread result for the type checker
    assert is_ndarray(image)
    return image

93 

94 

def load_tensor(
    source: Union[PermissiveFileSource, ZipPath],
    /,
    axes: Optional[Sequence[AxisLike]] = None,
) -> Tensor:
    """Load an image from **source** and wrap it in a `Tensor`.

    Args:
        source: image source (positional-only), read with `load_image`
        axes: forwarded as `dims` to `Tensor.from_numpy`
    """
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(source), dims=axes)

104 

105 

# type variable constrained to the file source kinds `Path`, `HttpUrl` and `ZipPath`
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

# alias for a file name suffix (presumably including the leading dot, e.g. ".npy")
Suffix = str

109 

110 

def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**; the (lower-cased) suffix selects the format.

    - `.npy`: written with `save_array`.
    - `.h5`/`.hdf`/`.hdf5`: not implemented.
    - anything else: written with `imageio.v3.imwrite`. For `.tif`/`.tiff` a
      singleton `batch` axis is dropped; for `.png`/`.jpg`/`.jpeg` every
      singleton axis is dropped and float32/float64 data is min-max rescaled
      to uint8.

    Parent directories of **path** are created if missing.

    Args:
        path: output file path; must have a suffix
        tensor: the tensor to save

    Raises:
        ValueError: if **path** has no suffix
        NotImplementedError: for `.h5`, `.hdf` and `.hdf5` suffixes
    """
    # TODO: save axis meta data

    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        save_array(path, tensor.to_numpy())
        return

    if extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")

    # Select the axes that may be squeezed per extension. A single dict keyed
    # by axis would let the per-dim png/jpg entry overwrite the tiff rule for
    # "batch", so that a singleton batch axis is never removed for tiff files.
    if extension in (".tif", ".tiff"):
        # remove singleton batch dim for tiff files
        squeezable_axes = [AxisId("batch")]
    elif extension in (".png", ".jpg", ".jpeg"):
        # remove any singleton axis for png and jpg files
        squeezable_axes = list(tensor.dims)
    else:
        squeezable_axes = []

    removed_singleton_axes: List[AxisId] = []
    for axis in squeezable_axes:
        if tensor.tagged_shape.get(axis) == 1:
            tensor = tensor[{axis: 0}]
            removed_singleton_axes.append(axis)

    if removed_singleton_axes:
        singleton_axes_msg = f"(with removed singleton axes {list(map(str, removed_singleton_axes))}) "
    else:
        singleton_axes_msg = ""

    logger.info(
        "writing tensor {} {}to {}",
        dict(tensor.tagged_shape),
        singleton_axes_msg,
        path,
    )
    if extension in (".png", ".jpg", ".jpeg") and tensor.dtype in (
        "float32",
        "float64",
    ):
        logger.warning(
            "converting tensor of dtype {} to uint8 for saving as {}",
            tensor.dtype,
            extension,
        )
        # min-max rescale to [0, 255]; the 1e-8 floor avoids dividing by zero
        # for constant tensors
        tensor = (
            (tensor - (t_min := tensor.data.min()))
            / xr.ufuncs.maximum(tensor.data.max() - t_min, 1e-8)
            * 255
        ).astype("uint8")

    imwrite(path, tensor, extension=extension)

167 

168 

def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Write the members of **sample** to disk.

    **path** is either a mapping from member id to a path (pattern), or a
    single path pattern used for all members. A single pattern must contain
    `{member_id}` (or `{input_id}` or `{output_id}`) if the **sample** has more
    than one member.

    Every path (pattern) may additionally contain `{sample_id}`, which is
    filled in with the id of the **sample**.

    Raises:
        ValueError: if a single path pattern without member placeholder is
            given for a sample with multiple members.
    """
    if isinstance(path, collections.abc.Mapping):
        path_per_member = path
    else:
        placeholders = ("{member_id}", "{input_id}", "{output_id}")
        if len(sample.members) >= 2 and not any(
            placeholder in str(path) for placeholder in placeholders
        ):
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

        # use the same pattern for every member
        path_per_member = {member_id: path for member_id in sample.members}

    for member_id, pattern in path_per_member.items():
        member_tensor = sample.members[member_id]
        filled_in = str(pattern).format(
            sample_id=sample.id,
            member_id=member_id,
            input_id=member_id,
            output_id=member_id,
        )
        save_tensor(Path(filled_in), member_tensor)

196 

197 

class _StatEntry(BaseModel, frozen=True, arbitrary_types_allowed=True):
    """Serializable pairing of one measure with its value"""

    # the dataset or sample measure this entry belongs to
    measure: Union[DatasetMeasure, SampleMeasure]
    # the value stored for `measure`
    value: MeasureValue

203 

204 

class _StatList(RootModel[List[_StatEntry]]):
    """Serializable list of stat entries (one entry per measure of a stat mapping)"""

209 

210 

def serialize_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
) -> List[JsonValue]:
    """Convert a stat mapping to a JSON compatible list.

    Each (measure, value) item of **stat** becomes one entry of the returned
    list (dumped by pydantic in "json" mode).
    """
    entries = [
        _StatEntry(measure=measure, value=value) for measure, value in stat.items()
    ]
    return _StatList(entries).model_dump(mode="json")

217 

218 

def save_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
    output: Union[Path, BytesIO],
) -> None:
    """Write sample and dataset statistics as UTF-8 encoded JSON to **output**.

    Args:
        stat: the statistics to save
        output: a file path (opened for binary writing) or an in-memory
            `BytesIO` buffer (written to, but not closed).
    """
    # a `Path` is opened (and closed) here; a buffer is used as is
    sink = output.open("wb") if isinstance(output, Path) else nullcontext(output)

    with sink as stream:
        encoded = json.dumps(serialize_stat(stat), indent=2).encode("utf-8")
        _ = stream.write(encoded)

232 

233 

def load_stat(source: Union[Path, str, Sequence[JsonValue]]) -> Stat:
    """Load sample and dataset statistics from JSON

    Args:
        source: a path to a JSON file (read as UTF-8), a JSON string,
            or an already parsed sequence of JSON values.

    Returns:
        a mapping from each measure to its value
    """
    if isinstance(source, Path):
        source = source.read_text(encoding="utf-8")

    validated = (
        _StatList.model_validate_json(source)
        if isinstance(source, str)
        else _StatList.model_validate(source)
    )

    stat = {}
    for entry in validated.root:
        stat[entry.measure] = entry.value

    return stat

245 

246 

def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path) -> None:
    """DEPRECATED alias for `save_stat()`: use `save_stat()` instead.

    Args:
        stat: dataset statistics to save
        path: path of the JSON file to write
    """
    # stacklevel=2 attributes the warning to the calling code
    warnings.warn(
        "`save_dataset_stat()` is deprecated, use `save_stat()` instead.",
        stacklevel=2,
    )
    # re-keyed into a new dict; presumably to match the wider key type
    # accepted by `save_stat` -- TODO confirm
    save_stat({k: v for k, v in stat.items()}, path)

251 

252 

def load_dataset_stat(path: Path) -> Stat:
    """DEPRECATED alias for `load_stat()`: use `load_stat()` instead.

    Args:
        path: path of the JSON file to read

    Returns:
        the loaded statistics
    """
    # stacklevel=2 attributes the warning to the calling code
    warnings.warn(
        "`load_dataset_stat()` is deprecated, use `load_stat()` instead.",
        stacklevel=2,
    )
    return load_stat(path)

257 

258 

def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
) -> Path:
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive
    otherwise copy **source** to a file in **folder**.

    Args:
        source: the file source, or an already opened `BytesReader`
        folder: folder to extract/copy into (created if missing)

    Returns:
        For a zip archive: the folder `<file name>.unzipped` inside **folder**
        holding the extracted content; otherwise the path of the copied file.
    """
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    # `zipfile.is_zipfile` seeks within the reader while probing for the zip
    # end record and does not restore the read position afterwards; remember
    # it so that a plain copy does not start near the end of the stream.
    start_position = weights_reader.tell()

    if zipfile.is_zipfile(weights_reader):
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        _ = weights_reader.seek(start_position)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path

286 

287 

def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: read the `suffix` attribute of **source** directly instead."""
    suffix = source.suffix
    return suffix