Coverage for src / bioimageio / core / io.py: 84%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 09:46 +0000

1import collections.abc 

2import warnings 

3import zipfile 

4from pathlib import Path 

5from shutil import copyfileobj 

6from typing import ( 

7 Any, 

8 Mapping, 

9 Optional, 

10 Sequence, 

11 TypeVar, 

12 Union, 

13) 

14 

15from imageio.v3 import imread, imwrite # type: ignore 

16from loguru import logger 

17from numpy.typing import NDArray 

18from pydantic import BaseModel, ConfigDict, TypeAdapter 

19 

20from bioimageio.spec._internal.io import get_reader, interprete_file_source 

21from bioimageio.spec._internal.type_guards import is_ndarray 

22from bioimageio.spec.common import ( 

23 BytesReader, 

24 FileSource, 

25 HttpUrl, 

26 PermissiveFileSource, 

27 RelativeFilePath, 

28 ZipPath, 

29) 

30from bioimageio.spec.utils import download, load_array, save_array 

31 

32from .axis import AxisId, AxisLike 

33from .common import PerMember 

34from .sample import Sample 

35from .stat_measures import DatasetMeasure, MeasureValue 

36from .tensor import Tensor 

37 

38 

def load_image(
    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image (or ``.npy`` array) from **source** as a numpy array.

    Args:
        source: image source (path, URL, or path inside a zip archive)
        is_volume: deprecated; has no effect beyond emitting a warning
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    # a ZipPath is already fully resolved; anything else is parsed first
    parsed_source = (
        source if isinstance(source, ZipPath) else interprete_file_source(source)
    )
    if isinstance(parsed_source, RelativeFilePath):
        parsed_source = parsed_source.absolute()

    if parsed_source.suffix == ".npy":
        # raw numpy arrays bypass imageio entirely
        image = load_array(parsed_source)
    else:
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            download(parsed_source).read(), extension=parsed_source.suffix
        )

    assert is_ndarray(image)
    return image

69 

70 

def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    """Load a tensor from an image or ``.npy`` file at **path**.

    Args:
        path: file to read
        axes: axes to assign to the loaded array's dimensions
    """
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(path), dims=axes)

78 

79 

# Constrained type variable over the concrete file-source variants
# (local path, URL, or path inside a zip archive).
# NOTE(review): not referenced in this part of the file — presumably used
# elsewhere in the module; confirm before removing.
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

# Alias for a file-name suffix string, including the leading dot (e.g. ".npy").
Suffix = str

83 

84 

def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**; the file format is chosen by the path's suffix.

    ``.npy`` is written via `save_array`, anything else via `imageio.imwrite`.
    For ``.tif``/``.tiff`` a singleton batch axis is removed before writing.

    Args:
        path: target file path; must have a suffix.
        tensor: tensor to save.

    Raises:
        ValueError: if **path** has no suffix.
        NotImplementedError: for hdf5 targets (``.h5``, ``.hdf``, ``.hdf5``).
    """
    # TODO: save axis meta data
    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        data: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
            tensor.data.to_numpy()
        )
        assert is_ndarray(data)
        save_array(path, data)
    elif extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")
    else:
        if (
            extension in (".tif", ".tiff")
            and tensor.tagged_shape.get(ba := AxisId("batch")) == 1
        ):
            # remove singleton batch axis for saving
            tensor = tensor[{ba: 0}]
            singleton_axes_msg = "(without singleton batch axes) "
        else:
            singleton_axes_msg = ""

        # Extract the numpy array only *after* the optional squeeze above.
        # (Previously the array was taken before squeezing, so the removed
        # batch axis was still present in the data written to file.)
        data = tensor.data.to_numpy()  # pyright: ignore[reportUnknownVariableType]
        assert is_ndarray(data)

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data, extension=extension)

120 

121 

def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern
    or all sample members in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** it must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        # a single path pattern must distinguish members unless there is at most one
        placeholders = ("{member_id}", "{input_id}", "{output_id}")
        if len(sample.members) >= 2 and not any(
            ph in str(path) for ph in placeholders
        ):
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )
        path = {member: path for member in sample.members}

    for member_id, member_path in path.items():
        tensor = sample.members[member_id]
        target = Path(
            str(member_path).format(
                sample_id=sample.id,
                member_id=member_id,
                input_id=member_id,
                output_id=member_id,
            )
        )
        save_tensor(target, tensor)

149 

150 

class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    """Immutable (measure, value) pair for (de)serializing dataset statistics."""

    measure: DatasetMeasure  # the dataset measure that was evaluated
    value: MeasureValue  # the value computed for `measure`

156 

157 

# pydantic adapter for (de)serializing a sequence of dataset-statistics
# entries to/from JSON; used by `save_dataset_stat` and `load_dataset_stat`.
_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)

162 

163 

def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    """Serialize the dataset statistics **stat** to JSON and write them to **path**."""
    entries = [
        _SerializedDatasetStatsEntry(measure=measure, value=value)
        for measure, value in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(entries))

169 

170 

def load_dataset_stat(path: Path):
    """Read dataset statistics from the JSON file at **path** as a measure→value dict."""
    entries = _stat_adapter.validate_json(path.read_bytes())
    return {entry.measure: entry.value for entry in entries}

174 

175 

def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
):
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive
    otherwise copy **source** to a file in **folder**.

    Args:
        source: file source or an already opened reader.
        folder: target folder (created if missing).

    Returns:
        Path to the extraction directory (named ``<file>.unzipped``) if **source**
        was a zip archive, otherwise the path of the copied file.
    """
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    is_zip = zipfile.is_zipfile(weights_reader)
    # `zipfile.is_zipfile` seeks around in a file-like object and, on older
    # Python versions, does not restore the stream position; rewind so the
    # full content is read below.
    _ = weights_reader.seek(0)

    if is_zip:
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)
    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path

203 

204 

def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: use source.suffix instead."""
    # emit a runtime warning, consistent with the deprecation handling in
    # `load_image` (the docstring alone is easy to miss)
    warnings.warn("**get_suffix** is deprecated and will be removed soon.")
    return source.suffix