Coverage for src / bioimageio / core / io.py: 84%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1import collections.abc 

2import warnings 

3import zipfile 

4from pathlib import Path 

5from shutil import copyfileobj 

6from typing import ( 

7 Any, 

8 Mapping, 

9 Optional, 

10 Sequence, 

11 TypeVar, 

12 Union, 

13) 

14 

15from imageio.v3 import imread, imwrite # type: ignore 

16from loguru import logger 

17from numpy.typing import NDArray 

18from pydantic import BaseModel, ConfigDict, TypeAdapter 

19 

20from bioimageio.spec._internal.io import get_reader, interprete_file_source 

21from bioimageio.spec._internal.type_guards import is_ndarray 

22from bioimageio.spec.common import ( 

23 BytesReader, 

24 FileDescr, 

25 FileSource, 

26 HttpUrl, 

27 PermissiveFileSource, 

28 RelativeFilePath, 

29 ZipPath, 

30) 

31from bioimageio.spec.utils import download, load_array, save_array 

32 

33from .axis import AxisId, AxisLike 

34from .common import PerMember 

35from .sample import Sample 

36from .stat_measures import DatasetMeasure, MeasureValue 

37from .tensor import Tensor 

38 

39 

def load_image(
    source: Union[PermissiveFileSource, ZipPath], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """load a single image as numpy array

    Args:
        source: image source
        is_volume: deprecated
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    # FileDescr/ZipPath are already concrete sources; anything else is parsed first
    src = (
        source
        if isinstance(source, (FileDescr, ZipPath))
        else interprete_file_source(source)
    )

    if isinstance(src, RelativeFilePath):
        src = src.absolute()

    if src.suffix != ".npy":
        reader = download(src)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=src.suffix
        )
    else:
        image = load_array(src)

    assert is_ndarray(image)
    return image

70 

71 

def load_tensor(
    source: Union[PermissiveFileSource, ZipPath],
    /,
    axes: Optional[Sequence[AxisLike]] = None,
) -> Tensor:
    """Load **source** as a `Tensor`, optionally tagging dimensions with **axes**."""
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(source), dims=axes)

81 

82 

# constrained type variable over the concrete file-source types this module handles
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

# readability alias: a file suffix such as ".npy" (plain str)
Suffix = str

87 

def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**; the file format is chosen by the path's suffix.

    ".npy" is written via `save_array`; other suffixes are written with
    `imageio.v3.imwrite`.

    Args:
        path: target file path; must carry a suffix.
        tensor: tensor to save.

    Raises:
        ValueError: if **path** has no suffix.
        NotImplementedError: for HDF5 targets (".h5", ".hdf", ".hdf5").
    """
    # TODO: save axis meta data

    data: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
        tensor.data.to_numpy()
    )
    assert is_ndarray(data)
    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        save_array(path, data)
    elif extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")
    else:
        if (
            extension in (".tif", ".tiff")
            and tensor.tagged_shape.get(ba := AxisId("batch")) == 1
        ):
            # remove singleton batch axis for saving
            tensor = tensor[{ba: 0}]
            # BUGFIX: re-extract the array — `data` was taken before slicing
            # and would otherwise still carry the singleton batch axis
            data = tensor.data.to_numpy()
            assert is_ndarray(data)
            singleton_axes_msg = "(without singleton batch axes) "
        else:
            singleton_axes_msg = ""

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data, extension=extension)

123 

124 

def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern
    or all sample members in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        has_placeholder = any(
            placeholder in str(path)
            for placeholder in ("{member_id}", "{input_id}", "{output_id}")
        )
        if len(sample.members) >= 2 and not has_placeholder:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )
        # one identical pattern per member; formatting below disambiguates
        path = {member: path for member in sample.members}

    for member_id, member_path in path.items():
        tensor = sample.members[member_id]
        target = Path(
            str(member_path).format(
                sample_id=sample.id,
                member_id=member_id,
                input_id=member_id,
                output_id=member_id,
            )
        )
        save_tensor(target, tensor)

152 

153 

class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    """One (measure, value) pair of serialized dataset statistics.

    `frozen` makes entries immutable; `arbitrary_types_allowed` is presumably
    needed because measure/value types are not pydantic-native — confirm
    against `DatasetMeasure`/`MeasureValue` definitions.
    """

    measure: DatasetMeasure
    value: MeasureValue

159 

160 

# (de)serializes a whole sequence of stats entries to/from JSON;
# arbitrary types allowed since measure/value types may not be pydantic-native
_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)

165 

166 

def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    """Write dataset statistics **stat** as JSON to **path**."""
    entries = [
        _SerializedDatasetStatsEntry(measure=measure, value=value)
        for measure, value in stat.items()
    ]
    # write_bytes returns the byte count; deliberately discarded
    _ = path.write_bytes(_stat_adapter.dump_json(entries))

172 

173 

def load_dataset_stat(path: Path):
    """Read dataset statistics previously written by `save_dataset_stat`."""
    entries = _stat_adapter.validate_json(path.read_bytes())
    return {entry.measure: entry.value for entry in entries}

177 

178 

def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
):
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive
    otherwise copy **source** to a file in **folder**.

    Args:
        source: file source, or an already opened reader.
        folder: target folder (created if missing).

    Returns:
        Path of the extraction folder (suffixed ".unzipped") for zip archives,
        otherwise the path of the copied file.
    """
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    if zipfile.is_zipfile(weights_reader):
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # BUGFIX: `zipfile.is_zipfile` reads from a file-like object and does
        # not restore its position; rewind before copying, otherwise the copy
        # may start mid-stream and be truncated.
        weights_reader.seek(0)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path

206 

207 

def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: use source.suffix instead."""
    # emit a runtime warning, matching the deprecation style of `load_image`;
    # a docstring-only deprecation never reaches callers
    warnings.warn("**get_suffix** is deprecated, use `source.suffix` instead.")
    return source.suffix