Coverage for src/bioimageio/core/io.py: 80%
121 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1import collections.abc
2import json
3import warnings
4import zipfile
5from contextlib import nullcontext
6from io import BytesIO
7from pathlib import Path
8from shutil import copyfileobj
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 Dict,
13 List,
14 Mapping,
15 Optional,
16 Sequence,
17 TypeVar,
18 Union,
19)
21import xarray as xr
22from imageio.v3 import imread, imwrite # type: ignore
23from loguru import logger
24from numpy.typing import NDArray
25from pydantic import BaseModel, RootModel
26from typing_extensions import TypeAlias
27from typing_extensions import TypeAliasType as _TypeAliasType
29from bioimageio.spec._internal.io import get_reader, interprete_file_source
30from bioimageio.spec._internal.type_guards import is_ndarray
31from bioimageio.spec.common import (
32 BytesReader,
33 FileDescr,
34 FileSource,
35 HttpUrl,
36 PermissiveFileSource,
37 RelativeFilePath,
38 ZipPath,
39)
40from bioimageio.spec.utils import download, load_array, save_array
42from .axis import AxisId, AxisLike
43from .common import PerMember
44from .sample import Sample
45from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure, Stat
46from .tensor import Tensor
# `JsonValue`: recursive type alias for JSON-compatible Python values
# (bool, int, float, str, None, lists of JsonValue, str-keyed dicts of JsonValue).
if TYPE_CHECKING:
    # static type checkers see a plain recursive `Union` alias
    JsonValue: TypeAlias = Union[
        bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]
    ]  # note: order relevant for deserializing

else:
    # for pydantic validation we need to use `TypeAliasType`,
    # see https://docs.pydantic.dev/latest/concepts/types/#named-recursive-types
    # however this results in a partially unknown type with the current pyright 1.1.388
    JsonValue: TypeAlias = _TypeAliasType(
        "JsonValue",
        Union[bool, int, float, str, None, List["JsonValue"], Dict[str, "JsonValue"]],
    )
def load_image(
    source: Union[PermissiveFileSource, ZipPath], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image as numpy array.

    Args:
        source: image source. A `FileDescr` or `ZipPath` is used as-is; anything
            else is parsed with `interprete_file_source`, and relative file paths
            are made absolute. Sources with a `.npy` suffix (case-insensitive) are
            loaded with `load_array`; all other sources are downloaded/read and
            decoded with `imageio.v3.imread`, using the suffix as format hint.
        is_volume: deprecated and ignored (only triggers a warning).

    Returns:
        The loaded image.

    Raises:
        TypeError: if the loaded data is not a numpy array.
    """
    if is_volume is not None:
        # stacklevel=2 attributes the warning to the caller, not to this module
        warnings.warn(
            "**is_volume** is deprecated and will be removed soon.", stacklevel=2
        )

    if isinstance(source, (FileDescr, ZipPath)):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        parsed_source = parsed_source.absolute()

    # compare case-insensitively, consistent with `save_tensor`
    if parsed_source.suffix.lower() == ".npy":
        image = load_array(parsed_source)
    else:
        reader = download(parsed_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=parsed_source.suffix
        )

    # explicit check instead of `assert`, which is stripped under `python -O`
    if not is_ndarray(image):
        raise TypeError(
            f"Expected image loaded from {parsed_source} to be a numpy array,"
            + f" but got {type(image)}."
        )

    return image
def load_tensor(
    source: Union[PermissiveFileSource, ZipPath],
    /,
    axes: Optional[Sequence[AxisLike]] = None,
) -> Tensor:
    """Load the image at **source** and wrap it as a `Tensor`.

    Args:
        source: image source, forwarded to `load_image`.
        axes: optional axis names passed to `Tensor.from_numpy` as `dims`.

    Returns:
        The loaded tensor.
    """
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(source), dims=axes)
# Type variable constrained to the concrete source types `Path`, `HttpUrl` and `ZipPath`.
# NOTE(review): not referenced in the visible part of this module — confirm it is still used.
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)

# A file name suffix such as ".npy"; return type of `get_suffix()`.
Suffix = str
def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**, choosing the file format from the path suffix.

    `.npy` files are written with `save_array`; any other suffix is written with
    `imageio.v3.imwrite`. Before an image file is written, singleton axes are
    dropped: a singleton `batch` axis for `.tif`/`.tiff`, and every singleton
    axis for `.png`/`.jpg`/`.jpeg`. `float32`/`float64` tensors saved as
    `.png`/`.jpg`/`.jpeg` are min-max rescaled to the range 0-255 and cast to
    `uint8` (a warning is logged).

    Args:
        path: output file path; the suffix is matched case-insensitively and
            missing parent directories are created.
        tensor: tensor to save. Axis metadata is not saved.

    Raises:
        ValueError: if **path** has no suffix.
        NotImplementedError: for `.h5`, `.hdf` and `.hdf5` paths.
    """
    # TODO: save axis meta data

    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    if extension in (".h5", ".hdf", ".hdf5"):
        # reject unsupported formats before creating any directories
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")

    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        save_array(path, tensor.to_numpy())
        return

    # Candidate axes to squeeze are chosen per format. They are selected with
    # an if/elif chain rather than one merged dict keyed by axis, because there
    # a per-axis png/jpg entry would replace the tiff `batch` entry.
    if extension in (".tif", ".tiff"):
        # remove singleton batch dim for tiff files
        squeezable_axes: List[AxisId] = [AxisId("batch")]
    elif extension in (".png", ".jpg", ".jpeg"):
        # remove any singleton axis for png and jpg files
        squeezable_axes = list(tensor.dims)
    else:
        squeezable_axes = []

    removed_singleton_axes: List[AxisId] = []
    for axis in squeezable_axes:
        if tensor.tagged_shape.get(axis) == 1:
            tensor = tensor[{axis: 0}]
            removed_singleton_axes.append(axis)

    if removed_singleton_axes:
        singleton_axes_msg = f"(with removed singleton axes {list(map(str, removed_singleton_axes))}) "
    else:
        singleton_axes_msg = ""

    logger.info(
        "writing tensor {} {}to {}",
        dict(tensor.tagged_shape),
        singleton_axes_msg,
        path,
    )
    if extension in (".png", ".jpg", ".jpeg") and tensor.dtype in (
        "float32",
        "float64",
    ):
        logger.warning(
            "converting tensor of dtype {} to uint8 for saving as {}",
            tensor.dtype,
            extension,
        )
        # min-max rescale to [0, 255]; the 1e-8 floor avoids division by zero
        # for constant images
        tensor = (
            (tensor - (t_min := tensor.data.min()))
            / xr.ufuncs.maximum(tensor.data.max() - t_min, 1e-8)
            * 255
        ).astype("uint8")

    # hand imageio a plain numpy array, its documented input type
    imwrite(path, tensor.to_numpy(), extension=extension)
def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Write the members of **sample** to file(s) given by **path**.

    **path** is either a mapping from member id to file path, or a single path
    pattern. A single pattern is used for every member. If the sample has two or
    more members, the pattern must contain `{member_id}`, `{input_id}` or
    `{output_id}` so that each member gets its own file.

    Every path is formatted with `sample_id` (the sample's id) and with
    `member_id`/`input_id`/`output_id` (all set to the member id) before the
    member tensor is written with `save_tensor`.

    Raises:
        ValueError: if a single pattern without a member placeholder is given for
            a sample with multiple members.
    """
    if isinstance(path, collections.abc.Mapping):
        path_per_member = path
    else:
        placeholders = ("{member_id}", "{input_id}", "{output_id}")
        is_unambiguous = len(sample.members) < 2 or any(
            placeholder in str(path) for placeholder in placeholders
        )
        if not is_unambiguous:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

        path_per_member = {member_id: path for member_id in sample.members}

    for member_id, member_path in path_per_member.items():
        member_tensor = sample.members[member_id]
        formatted_path = str(member_path).format(
            sample_id=sample.id,
            member_id=member_id,
            input_id=member_id,
            output_id=member_id,
        )
        save_tensor(Path(formatted_path), member_tensor)
class _StatEntry(BaseModel, frozen=True, arbitrary_types_allowed=True):
    """Serializable stat entry: one measure together with its value.

    Instances are immutable (`frozen=True`). `arbitrary_types_allowed=True` lets
    pydantic accept field types that it has no built-in schema for.
    NOTE(review): presumably needed for (parts of) `MeasureValue` — confirm.
    """

    # the dataset or sample measure that was computed
    measure: Union[DatasetMeasure, SampleMeasure]
    # the computed value for `measure`
    value: MeasureValue
class _StatList(RootModel[List[_StatEntry]]):
    """Serializable stat mapping, represented as a list of `_StatEntry` items.

    Used by `serialize_stat` to dump and by `load_stat` to validate statistics.
    """

    pass
def serialize_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
) -> List[JsonValue]:
    """Serialize a stat mapping to JSON-compatible Python data.

    Note that this does not return a JSON string; pass the result to
    `json.dumps` to obtain one (as `save_stat` does).

    Args:
        stat: mapping from measure to its computed value.

    Returns:
        A list with one JSON object per mapping entry, holding the `_StatEntry`
        fields `measure` and `value`.
    """
    entries = [_StatEntry(measure=k, value=v) for k, v in stat.items()]
    return _StatList(entries).model_dump(mode="json")
def save_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
    output: Union[Path, str, BytesIO],
) -> None:
    """Save sample and dataset statistics as a JSON file.

    Args:
        stat: mapping from measure to its computed value.
        output: file path (`Path` or `str`) to (over)write, or a binary buffer
            to write the UTF-8 encoded JSON into. A given buffer is not closed.
    """
    # Serialize before touching the output, so that a serialization error does
    # not leave behind a truncated/empty file.
    data = json.dumps(serialize_stat(stat), indent=2).encode("utf-8")

    if isinstance(output, (Path, str)):
        _ = Path(output).write_bytes(data)
    else:
        _ = output.write(data)
def load_stat(source: Union[Path, str, Sequence[JsonValue]]) -> Stat:
    """Load sample and dataset statistics.

    Args:
        source: one of
            - a `Path` to a UTF-8 encoded JSON file,
            - a `str` holding JSON text (a `str` is *not* treated as a file path),
            - already decoded JSON data, e.g. the output of `serialize_stat`.

    Returns:
        Mapping from measure to its value.
    """
    if isinstance(source, Path):
        stat_list = _StatList.model_validate_json(source.read_text(encoding="utf-8"))
    elif isinstance(source, str):
        stat_list = _StatList.model_validate_json(source)
    else:
        stat_list = _StatList.model_validate(source)

    return {entry.measure: entry.value for entry in stat_list.root}
def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path) -> None:
    """DEPRECATED alias for `save_stat()`: use `save_stat()` instead."""
    # the replacement is `save_stat()`; stacklevel=2 points the warning at the caller
    warnings.warn(
        "`save_dataset_stat()` is deprecated, use `save_stat()` instead.",
        stacklevel=2,
    )
    save_stat({k: v for k, v in stat.items()}, path)
def load_dataset_stat(path: Path) -> Stat:
    """DEPRECATED alias for `load_stat()`: use `load_stat()` instead."""
    # the replacement is `load_stat()`; stacklevel=2 points the warning at the caller
    warnings.warn(
        "`load_dataset_stat()` is deprecated, use `load_stat()` instead.",
        stacklevel=2,
    )
    return load_stat(path)
def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
) -> Path:
    """Unzip a (downloaded) **source** into **folder** if it is a zip archive,
    otherwise copy **source** to a file in **folder**.

    The output name is the reader's original file name, or `file<suffix>` when
    that is unknown. Zip archives are extracted into a directory with that name
    plus `.unzipped`.

    Args:
        source: file source, or an already opened `BytesReader`.
        folder: target folder; missing parent directories are created.

    Returns:
        Path of the extracted directory (zip archives) or of the copied file.
    """
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        weights_reader = get_reader(source)

    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    # `zipfile.is_zipfile` seeks within the stream. Depending on the Python
    # version it may not restore the read position, so restore it explicitly.
    # Otherwise `copyfileobj` below could start copying at the wrong offset.
    start_position = weights_reader.tell()
    is_zip = zipfile.is_zipfile(weights_reader)
    _ = weights_reader.seek(start_position)

    if is_zip:
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)
    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path
def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """Return the file suffix of **source**.

    DEPRECATED: use `source.suffix` directly instead.
    """
    suffix: Suffix = source.suffix
    return suffix