Coverage for src/bioimageio/core/io.py: 84% (91 statements)

import collections.abc
import warnings
import zipfile
from pathlib import Path
from shutil import copyfileobj
from typing import (
    Any,
    Mapping,
    Optional,
    Sequence,
    TypeVar,
    Union,
)

from imageio.v3 import imread, imwrite  # type: ignore
from loguru import logger
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, TypeAdapter

from bioimageio.spec._internal.io import get_reader, interprete_file_source
from bioimageio.spec._internal.type_guards import is_ndarray
from bioimageio.spec.common import (
    BytesReader,
    FileSource,
    HttpUrl,
    PermissiveFileSource,
    RelativeFilePath,
    ZipPath,
)
from bioimageio.spec.utils import download, load_array, save_array

from .axis import AxisId, AxisLike
from .common import PerMember
from .sample import Sample
from .stat_measures import DatasetMeasure, MeasureValue
from .tensor import Tensor

def load_image(
    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """load a single image as numpy array

    Args:
        source: image source
        is_volume: deprecated
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    if isinstance(source, ZipPath):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        parsed_source = parsed_source.absolute()

    if parsed_source.suffix == ".npy":
        image = load_array(parsed_source)
    else:
        reader = download(parsed_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=parsed_source.suffix
        )

    assert is_ndarray(image)
    return image

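# usage sketch (hypothetical file names; `.npy` files take the load_array path,
# anything else is decoded by imageio, so any format imageio.v3.imread
# understands should work):
#   img = load_image("nuclei.png")   # -> NDArray via imageio.v3.imread
#   arr = load_image("tensor.npy")   # -> NDArray via bioimageio.spec's load_array
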
def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    # TODO: load axis meta data
    array = load_image(path)

    return Tensor.from_numpy(array, dims=axes)

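# usage sketch (hypothetical path; axis metadata is not read from the file yet,
# so axes can be passed explicitly, assuming plain strings are accepted as AxisLike):
#   t = load_tensor("mask.tif", axes=["y", "x"])  # -> Tensor with tagged dims
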
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

Suffix = str

def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    # TODO: save axis meta data

    data: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
        tensor.data.to_numpy()
    )
    assert is_ndarray(data)
    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        save_array(path, data)
    elif extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")
    else:
        if (
            extension in (".tif", ".tiff")
            and tensor.tagged_shape.get(ba := AxisId("batch")) == 1
        ):
            # remove singleton batch axis for saving
            tensor = tensor[{ba: 0}]
            # refresh data so the written array matches the sliced tensor
            data = tensor.data.to_numpy()
            singleton_axes_msg = "(without singleton batch axes) "
        else:
            singleton_axes_msg = ""
        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data, extension=extension)

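# usage sketch (hypothetical tensor `t` and output paths; the file suffix
# selects the format):
#   save_tensor("out/pred.npy", t)  # written via save_array
#   save_tensor("out/pred.tif", t)  # written via imageio; a size-1 batch axis would be dropped
#   save_tensor("out/pred", t)      # raises ValueError (no suffix to pick a format)
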
def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern,
    or save each sample member to its path in the **path** mapping.

    If **path** is a `pathlib.Path` or a string and the **sample** has multiple members,
    **path** must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        if len(sample.members) < 2 or any(
            m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
        ):
            path = {m: path for m in sample.members}
        else:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

    for m, p in path.items():
        t = sample.members[m]
        p_formatted = Path(
            str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
        )
        save_tensor(p_formatted, t)

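# usage sketch (hypothetical sample with id "s0" and members "raw" and "mask"):
#   save_sample("out/{sample_id}_{member_id}.tif", sample)
#   # -> out/s0_raw.tif and out/s0_mask.tif
#   save_sample("out/single.tif", sample)  # raises ValueError: multiple members, no {member_id}
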
class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    measure: DatasetMeasure
    value: MeasureValue


_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)

def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    """serialize a dataset statistics mapping as JSON to **path**"""
    serializable = [
        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(serializable))

def load_dataset_stat(path: Path):
    """deserialize a dataset statistics mapping from the JSON file at **path**"""
    seq = _stat_adapter.validate_json(path.read_bytes())
    return {e.measure: e.value for e in seq}

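# round-trip sketch (hypothetical `stat` mapping of DatasetMeasure to MeasureValue):
#   save_dataset_stat(stat, Path("stat.json"))
#   restored = load_dataset_stat(Path("stat.json"))
#   assert restored.keys() == stat.keys()
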
def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
) -> Path:
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive,
    otherwise copy **source** to a file in **folder**."""
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    if zipfile.is_zipfile(weights_reader):
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path

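# usage sketch (hypothetical URL; a zip archive is extracted into a
# "<name>.unzipped" directory, any other file is copied verbatim):
#   local = ensure_unzipped("https://example.com/weights.zip", Path("cache"))
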
def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: use source.suffix instead."""
    return source.suffix