Coverage for bioimageio/core/io.py: 76%
138 statements
coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

import collections.abc
import warnings
import zipfile
from io import TextIOWrapper
from pathlib import Path, PurePosixPath
from shutil import copyfileobj
from typing import (
    Any,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import h5py  # pyright: ignore[reportMissingTypeStubs]
import numpy as np
from imageio.v3 import imread, imwrite  # type: ignore
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter
from typing_extensions import assert_never

from bioimageio.spec._internal.io import interprete_file_source
from bioimageio.spec.common import (
    HttpUrl,
    PermissiveFileSource,
    RelativeFilePath,
    ZipPath,
)
from bioimageio.spec.utils import download, load_array, save_array

from .axis import AxisLike
from .common import PerMember
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import Tensor

DEFAULT_H5_DATASET_PATH = "data"

SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5")


def load_image(
    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image as a numpy array.

    Args:
        source: image source
        is_volume: deprecated
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    if isinstance(source, ZipPath):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        src = parsed_source.absolute()
    else:
        src = parsed_source

    # FIXME: why is pyright complaining about giving the union to _split_dataset_path?
    if isinstance(src, Path):
        file_source, subpath = _split_dataset_path(src)
    elif isinstance(src, HttpUrl):
        file_source, subpath = _split_dataset_path(src)
    elif isinstance(src, ZipPath):
        file_source, subpath = _split_dataset_path(src)
    else:
        assert_never(src)

    path = download(file_source).path

    if path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}")
        return load_array(path)
    elif path.suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(path, "r") as f:
            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
                dataset_path
            )
            if not isinstance(h5_dataset, h5py.Dataset):
                raise ValueError(
                    f"Dataset '{dataset_path}' in {path} is not of type"
                    + f" {h5py.Dataset}, but has type "
                    + str(
                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
            image: NDArray[Any]
            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
            assert isinstance(image, np.ndarray), type(
                image  # pyright: ignore[reportUnknownArgumentType]
            )
            return image  # pyright: ignore[reportUnknownVariableType]
    elif isinstance(path, ZipPath):
        return imread(
            path.read_bytes(), extension=path.suffix
        )  # pyright: ignore[reportUnknownVariableType]
    else:
        return imread(path)  # pyright: ignore[reportUnknownVariableType]
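
# Usage sketch (hypothetical file names; assumes the files exist and are
# readable). `load_image` dispatches on the file suffix:
#
#     im = load_image("image.png")          # any imageio-readable format
#     arr = load_image("tensor.npy")        # raw numpy array
#     vol = load_image("volume.h5/data1")   # h5 file with explicit dataset path
#     vol = load_image("volume.h5")         # falls back to DEFAULT_H5_DATASET_PATH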


def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    # TODO: load axis meta data
    array = load_image(path)

    return Tensor.from_numpy(array, dims=axes)
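
# Usage sketch (hypothetical file name; the axis names are illustrative).
# Without **axes**, `Tensor.from_numpy` has to guess axes from the array shape:
#
#     t = load_tensor("input.npy", axes=["batch", "channel", "y", "x"])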


_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)


def _split_dataset_path(
    source: _SourceT,
) -> Tuple[_SourceT, Optional[PurePosixPath]]:
    """Split off subpath (e.g. internal h5 dataset path)
    from a file path following a file extension.

    Examples:
        >>> _split_dataset_path(Path("my_file.h5/dataset"))
        (...Path('my_file.h5'), PurePosixPath('dataset'))

        >>> _split_dataset_path(Path("my_plain_file"))
        (...Path('my_plain_file'), None)

    """
    if isinstance(source, RelativeFilePath):
        src = source.absolute()
    else:
        src = source

    del source

    def separate_pure_path(path: PurePosixPath):
        for p in path.parents:
            if p.suffix in SUFFIXES_WITH_DATAPATH:
                return p, PurePosixPath(path.relative_to(p))

        return path, None

    if isinstance(src, HttpUrl):
        file_path, data_path = separate_pure_path(PurePosixPath(src.path or ""))

        if data_path is None:
            return src, None

        return (
            # strip the dataset subpath from the full URL; `file_path` alone is
            # only the URL's path component and would not validate as HttpUrl
            HttpUrl(str(src).replace(f"/{data_path}", "")),
            data_path,
        )

    if isinstance(src, ZipPath):
        file_path, data_path = separate_pure_path(PurePosixPath(str(src)))

        if data_path is None:
            return src, None

        return (
            ZipPath(str(file_path).replace(f"/{data_path}", "")),
            data_path,
        )

    file_path, data_path = separate_pure_path(PurePosixPath(src))
    return Path(file_path), data_path


def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    # TODO: save axis meta data

    data: NDArray[Any] = tensor.data.to_numpy()
    file_path, subpath = _split_dataset_path(Path(path))
    if not file_path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    file_path.parent.mkdir(exist_ok=True, parents=True)
    if file_path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")
        save_array(file_path, data)
    elif file_path.suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(file_path, "a") as f:
            if dataset_path in f:
                del f[dataset_path]

            _ = f.create_dataset(dataset_path, data=data, chunks=True)
    else:
        # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]:
        #     tensor = tensor[{a: 0 for a in singleton_axes}]
        #     singleton_axes_msg = f"(without singleton axes {singleton_axes}) "
        # else:
        singleton_axes_msg = ""

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data)
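
# Usage sketch (hypothetical paths; `t` is an existing Tensor):
#
#     save_tensor("out.npy", t)        # numpy file
#     save_tensor("out.h5/data1", t)   # h5 file with explicit dataset path
#     save_tensor("out.h5", t)         # h5 file, written to DEFAULT_H5_DATASET_PATH
#     save_tensor("out.tif", t)        # any other suffix is passed on to imageio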


def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern
    or each sample member to its entry in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        if len(sample.members) < 2 or any(
            m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
        ):
            path = {m: path for m in sample.members}
        else:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

    for m, p in path.items():
        t = sample.members[m]
        p_formatted = Path(
            str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
        )
        save_tensor(p_formatted, t)
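
# Usage sketch (hypothetical paths; `sample` is an existing Sample with
# members "raw" and "mask"):
#
#     save_sample("out/{sample_id}/{member_id}.h5", sample)
#     # -> out/<sample.id>/raw.h5 and out/<sample.id>/mask.h5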


class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    measure: DatasetMeasure
    value: MeasureValue


_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)


def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    """Serialize dataset **stat** (measures and their values) as JSON to **path**."""
    serializable = [
        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(serializable))


def load_dataset_stat(path: Path):
    """Load dataset statistics saved with `save_dataset_stat` from **path**."""
    seq = _stat_adapter.validate_json(path.read_bytes())
    return {e.measure: e.value for e in seq}
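
# Usage sketch (hypothetical path; `stat` maps DatasetMeasure objects to their
# computed values). Saving and loading should round-trip:
#
#     save_dataset_stat(stat, Path("dataset_stat.json"))
#     stat_again = load_dataset_stat(Path("dataset_stat.json"))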


def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
    """Unzip a (downloaded) **source** to a file in **folder** if source is a zip archive.

    Returns the path to the unzipped source (which may be **source** itself).
    """
    local_weights_file = download(source).path
    if isinstance(local_weights_file, ZipPath):
        # source is inside a zip archive
        out_path = folder / local_weights_file.filename
        with local_weights_file.open("rb") as src, out_path.open("wb") as dst:
            assert not isinstance(src, TextIOWrapper)
            copyfileobj(src, dst)

        local_weights_file = out_path

    if zipfile.is_zipfile(local_weights_file):
        # source itself is a zipfile
        out_path = folder / local_weights_file.with_suffix(".unzipped").name
        with zipfile.ZipFile(local_weights_file, "r") as f:
            f.extractall(out_path)

        return out_path
    else:
        return local_weights_file
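
# Usage sketch (hypothetical URL and folder name):
#
#     weights_path = ensure_unzipped(
#         "https://example.com/weights.zip", Path("unzipped_weights")
#     )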