Coverage for bioimageio/core/io.py: 72%
151 statements
coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
import collections.abc
import warnings
import zipfile
from pathlib import Path, PurePosixPath
from shutil import copyfileobj
from typing import (
    Any,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import h5py  # pyright: ignore[reportMissingTypeStubs]
from imageio.v3 import imread, imwrite  # type: ignore
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter
from typing_extensions import assert_never

from bioimageio.spec._internal.io import get_reader, interprete_file_source
from bioimageio.spec._internal.type_guards import is_ndarray
from bioimageio.spec.common import (
    BytesReader,
    FileSource,
    HttpUrl,
    PermissiveFileSource,
    RelativeFilePath,
    ZipPath,
)
from bioimageio.spec.utils import download, load_array, save_array

from .axis import AxisLike
from .common import PerMember
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import Tensor

DEFAULT_H5_DATASET_PATH = "data"


SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5")


def load_image(
    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """load a single image as numpy array

    Args:
        source: image source
        is_volume: deprecated
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    if isinstance(source, ZipPath):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        src = parsed_source.absolute()
    else:
        src = parsed_source

    if isinstance(src, Path):
        file_source, suffix, subpath = _split_dataset_path(src)
    elif isinstance(src, HttpUrl):
        file_source, suffix, subpath = _split_dataset_path(src)
    elif isinstance(src, ZipPath):
        file_source, suffix, subpath = _split_dataset_path(src)
    else:
        assert_never(src)

    if suffix == ".npy":
        if subpath is not None:
            logger.warning(
                "Unexpected subpath {} for .npy source {}", subpath, file_source
            )

        image = load_array(file_source)
    elif suffix in SUFFIXES_WITH_DATAPATH:
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        reader = download(file_source)

        with h5py.File(reader, "r") as f:
            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
                dataset_path
            )
            if not isinstance(h5_dataset, h5py.Dataset):
                raise ValueError(
                    f"{file_source} did not load as {h5py.Dataset}, but has type "
                    + str(
                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
                    )
                )
            image: NDArray[Any]
            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
    else:
        reader = download(file_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=suffix
        )

    assert is_ndarray(image)
    return image
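
# Usage sketch (illustrative only, not part of the module; file names and
# dataset names are hypothetical):
#
#   arr = load_image("sample.npy")      # plain numpy array
#   vol = load_image("stack.h5/raw")    # dataset "raw" inside stack.h5
#   vol = load_image("stack.h5")        # falls back to the "data" dataset
#   img = load_image("cells.tif")       # any imageio-readable format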


def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    # TODO: load axis meta data
    array = load_image(path)

    return Tensor.from_numpy(array, dims=axes)
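
# Usage sketch (illustrative only; the file name and axis ids are assumptions):
#
#   t = load_tensor("image.tif", axes=["y", "x", "c"])  # axes forwarded as dims
#   t = load_tensor("volume.h5/raw")                     # dims handling left to Tensor.from_numpy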


_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

Suffix = str


def _split_dataset_path(
    source: _SourceT,
) -> Tuple[_SourceT, Suffix, Optional[PurePosixPath]]:
    """Split off subpath (e.g. internal h5 dataset path)
    from a file path following a file extension.

    Examples:
        >>> _split_dataset_path(Path("my_file.h5/dataset"))
        (...Path('my_file.h5'), '.h5', PurePosixPath('dataset'))

        >>> _split_dataset_path(Path("my_plain_file"))
        (...Path('my_plain_file'), '', None)

    """
    if isinstance(source, RelativeFilePath):
        src = source.absolute()
    else:
        src = source

    del source

    def separate_pure_path(path: PurePosixPath):
        for p in path.parents:
            if p.suffix in SUFFIXES_WITH_DATAPATH:
                return p, p.suffix, PurePosixPath(path.relative_to(p))

        return path, path.suffix, None

    if isinstance(src, HttpUrl):
        file_path, suffix, data_path = separate_pure_path(PurePosixPath(src.path or ""))

        if data_path is None:
            return src, suffix, None

        return (
            HttpUrl(str(file_path).replace(f"/{data_path}", "")),
            suffix,
            data_path,
        )

    if isinstance(src, ZipPath):
        file_path, suffix, data_path = separate_pure_path(PurePosixPath(str(src)))

        if data_path is None:
            return src, suffix, None

        return (
            ZipPath(str(file_path).replace(f"/{data_path}", "")),
            suffix,
            data_path,
        )

    file_path, suffix, data_path = separate_pure_path(PurePosixPath(src))
    return Path(file_path), suffix, data_path


def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    # TODO: save axis meta data
    data: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
        tensor.data.to_numpy()
    )
    assert is_ndarray(data)
    file_path, suffix, subpath = _split_dataset_path(Path(path))
    if not suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    file_path.parent.mkdir(exist_ok=True, parents=True)
    if file_path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")
        save_array(file_path, data)
    elif suffix in (".h5", ".hdf", ".hdf5"):
        if subpath is None:
            dataset_path = DEFAULT_H5_DATASET_PATH
        else:
            dataset_path = str(subpath)

        with h5py.File(file_path, "a") as f:
            if dataset_path in f:
                del f[dataset_path]

            _ = f.create_dataset(dataset_path, data=data, chunks=True)
    else:
        # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]:
        #     tensor = tensor[{a: 0 for a in singleton_axes}]
        #     singleton_axes_msg = f"(without singleton axes {singleton_axes}) "
        # else:
        singleton_axes_msg = ""

        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data)
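
# Usage sketch (illustrative only; assumes an existing Tensor `t` and
# hypothetical output file names):
#
#   save_tensor("out.npy", t)             # numpy file
#   save_tensor("out.h5/prediction", t)   # HDF5 dataset "prediction"
#   save_tensor("out.tif", t)             # written via imageio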


def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern,
    or save each sample member to its entry in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        if len(sample.members) < 2 or any(
            m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
        ):
            path = {m: path for m in sample.members}
        else:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

    for m, p in path.items():
        t = sample.members[m]
        p_formatted = Path(
            str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
        )
        save_tensor(p_formatted, t)
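
# Usage sketch (illustrative only; assumes an existing Sample `sample` and a
# hypothetical "outputs" folder):
#
#   save_sample("outputs/{sample_id}_{member_id}.h5", sample)
#   save_sample({m: f"outputs/{m}.npy" for m in sample.members}, sample)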


class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    measure: DatasetMeasure
    value: MeasureValue


_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)


def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    serializable = [
        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(serializable))


def load_dataset_stat(path: Path):
    seq = _stat_adapter.validate_json(path.read_bytes())
    return {e.measure: e.value for e in seq}
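
# Usage sketch (illustrative only; assumes an existing mapping `stat` of
# DatasetMeasure to MeasureValue, e.g. precomputed dataset statistics):
#
#   save_dataset_stat(stat, Path("dataset_stat.json"))
#   stat_again = load_dataset_stat(Path("dataset_stat.json"))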


def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
):
    """unzip a (downloaded) **source** into a directory in **folder** if source is a zip archive,
    otherwise copy **source** to a file in **folder**."""
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    if zipfile.is_zipfile(weights_reader):
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path
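
# Usage sketch (illustrative only; URL and cache folder are hypothetical, and the
# resulting name depends on the reader's original_file_name):
#
#   path = ensure_unzipped("https://example.com/weights.zip", Path("cache"))
#   # -> e.g. cache/weights.zip.unzipped/ (extracted archive), or a copied file
#   #    if the source is not a zip archive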


def get_suffix(source: Union[ZipPath, FileSource]) -> str:
    if isinstance(source, Path):
        return source.suffix
    elif isinstance(source, ZipPath):
        return source.suffix
    elif isinstance(source, RelativeFilePath):
        return source.path.suffix
    elif isinstance(source, HttpUrl):
        if source.path is None:
            return ""
        else:
            return PurePosixPath(source.path).suffix
    else:
        assert_never(source)
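
# Usage sketch (illustrative only; file name and URL are hypothetical):
#
#   get_suffix(Path("weights.pt"))                        # ".pt"
#   get_suffix(HttpUrl("https://example.com/model.zip"))  # ".zip"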