Coverage for src / bioimageio / core / io.py: 75%
119 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1import collections.abc
2import json
3import warnings
4import zipfile
5from contextlib import nullcontext
6from io import BytesIO
7from pathlib import Path
8from shutil import copyfileobj
9from typing import (
10 Any,
11 Dict,
12 List,
13 Mapping,
14 Optional,
15 Sequence,
16 TypeVar,
17 Union,
18)
20import xarray as xr
21from imageio.v3 import imread, imwrite # type: ignore
22from loguru import logger
23from numpy.typing import NDArray
24from pydantic import BaseModel, RootModel
26from bioimageio.spec._internal.io import get_reader, interprete_file_source
27from bioimageio.spec._internal.type_guards import is_ndarray
28from bioimageio.spec.common import (
29 BytesReader,
30 FileDescr,
31 FileSource,
32 HttpUrl,
33 PermissiveFileSource,
34 RelativeFilePath,
35 ZipPath,
36)
37from bioimageio.spec.utils import download, load_array, save_array
39from .axis import AxisId, AxisLike
40from .common import PerMember
41from .sample import Sample, Stat
42from .stat_measures import DatasetMeasure, MeasureValue, SampleMeasure
43from .tensor import Tensor
# Recursive type alias for JSON-compatible data: scalars (including null),
# lists of JSON values and objects with string keys. The recursive members
# refer to the alias by name (forward references).
JsonValue = Union[
    bool,
    int,
    float,
    str,
    None,
    List["JsonValue"],
    Dict[str, "JsonValue"],
]
def load_image(
    source: Union[PermissiveFileSource, ZipPath], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image as a numpy array.

    Sources with a `.npy` suffix (matched case-insensitively) are read with
    `load_array`. Any other source is downloaded and decoded by
    `imageio.v3.imread`, with the source suffix passed as format hint.

    Args:
        source: image source; relative file paths are made absolute first.
        is_volume: deprecated and otherwise unused; passing any value only
            emits a warning.

    Returns:
        The loaded image as numpy array.
    """
    if is_volume is not None:
        # stacklevel=2 attributes the warning to the caller, not to this module
        warnings.warn(
            "**is_volume** is deprecated and will be removed soon.", stacklevel=2
        )

    if isinstance(source, (FileDescr, ZipPath)):
        parsed_source = source
    else:
        parsed_source = interprete_file_source(source)

    if isinstance(parsed_source, RelativeFilePath):
        parsed_source = parsed_source.absolute()

    suffix = parsed_source.suffix
    # compare case-insensitively, consistent with `save_tensor`
    if suffix.lower() == ".npy":
        image = load_array(parsed_source)
    else:
        reader = download(parsed_source)
        image = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(), extension=suffix
        )

    assert is_ndarray(image)
    return image
def load_tensor(
    source: Union[PermissiveFileSource, ZipPath],
    /,
    axes: Optional[Sequence[AxisLike]] = None,
) -> Tensor:
    """Load **source** with `load_image` and wrap the array in a `Tensor`.

    Args:
        source: image source (positional-only), forwarded to `load_image`.
        axes: dimension names, forwarded as `dims` to `Tensor.from_numpy`.

    Returns:
        The loaded tensor.
    """
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(source), dims=axes)
# A file suffix string such as ".npy" (the return type of `get_suffix`).
Suffix = str

# Type variable constrained to `Path`, `HttpUrl` or `ZipPath` sources.
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)
def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**; the file format is chosen by the suffix.

    The suffix is matched case-insensitively:

    - `.npy`: written with `save_array`, all axes kept.
    - `.h5`, `.hdf`, `.hdf5`: not supported.
    - anything else: written with `imageio.v3.imwrite`. Before writing, a
      singleton `batch` axis is removed for `.tif`/`.tiff`, and every
      singleton axis (including `batch`) is removed for `.png`/`.jpg`/`.jpeg`.
      Float32/float64 tensors saved as `.png`/`.jpg`/`.jpeg` are min-max
      rescaled to 0..255 and cast to uint8.

    Parent directories of **path** are created as needed.
    Axis metadata is not saved.

    Args:
        path: target file; must have a suffix.
        tensor: tensor to save.

    Raises:
        ValueError: if **path** has no suffix.
        NotImplementedError: if **path** has an HDF5 suffix.
    """
    # TODO: save axis meta data
    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    # reject unsupported formats before touching the file system
    if extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")

    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        save_array(path, tensor.to_numpy())
        return

    flat_image_exts = (".png", ".jpg", ".jpeg")  # any singleton axis is dropped
    tiff_exts = (".tif", ".tiff")  # only a singleton batch axis is dropped
    batch_axis = AxisId("batch")

    removed_singleton_axes: List[AxisId] = []
    # The batch axis qualifies under both rules, so the applicable extensions
    # are combined per axis; keying a dict by axis would let one rule
    # overwrite the other. `tensor.dims` is evaluated once, before the loop
    # rebinds `tensor`.
    for axis in tensor.dims:
        if axis == batch_axis:
            removable_for = flat_image_exts + tiff_exts
        else:
            removable_for = flat_image_exts

        if extension in removable_for and tensor.tagged_shape.get(axis) == 1:
            tensor = tensor[{axis: 0}]
            removed_singleton_axes.append(axis)

    if removed_singleton_axes:
        singleton_axes_msg = f"(with removed singleton axes {list(map(str, removed_singleton_axes))}) "
    else:
        singleton_axes_msg = ""

    logger.info(
        "writing tensor {} {}to {}",
        dict(tensor.tagged_shape),
        singleton_axes_msg,
        path,
    )
    if extension in flat_image_exts and tensor.dtype in (
        "float32",
        "float64",
    ):
        logger.warning(
            "converting tensor of dtype {} to uint8 for saving as {}",
            tensor.dtype,
            extension,
        )
        # min-max rescale to [0, 255]; the 1e-8 floor avoids division by zero
        # for constant tensors
        tensor = (
            (tensor - (t_min := tensor.data.min()))
            / xr.ufuncs.maximum(tensor.data.max() - t_min, 1e-8)
            * 255
        ).astype("uint8")

    imwrite(path, tensor, extension=extension)
def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Write each member tensor of **sample** with `save_tensor`.

    **path** is either a mapping from member id to path, or one path pattern
    used for every member. A single pattern is only accepted for a sample
    with two or more members if it contains `{member_id}`, `{input_id}` or
    `{output_id}`.

    Each path is expanded with `str.format`: `{sample_id}` becomes the
    sample's id, and `{member_id}`, `{input_id}` and `{output_id}` all become
    the member id.

    Raises:
        ValueError: if a single pattern without a member placeholder is given
            for a sample with multiple members.
    """
    if isinstance(path, collections.abc.Mapping):
        paths = path
    else:
        member_placeholders = ("{member_id}", "{input_id}", "{output_id}")
        multiple_members = not len(sample.members) < 2
        if multiple_members and not any(
            ph in str(path) for ph in member_placeholders
        ):
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )
        paths = {member_id: path for member_id in sample.members}

    for member_id, pattern in paths.items():
        member_tensor = sample.members[member_id]
        target = Path(
            str(pattern).format(
                sample_id=sample.id,
                member_id=member_id,
                input_id=member_id,
                output_id=member_id,
            )
        )
        save_tensor(target, member_tensor)
class _StatEntry(BaseModel):
    """One (measure, value) pair of a statistics mapping, in serializable form."""

    # immutable (hence hashable) model; measure/value types need not be
    # pydantic-native
    model_config = {"frozen": True, "arbitrary_types_allowed": True}

    # the statistic that was computed
    measure: Union[DatasetMeasure, SampleMeasure]
    # the computed value
    value: MeasureValue
class _StatList(RootModel[List[_StatEntry]]):
    """Serializable statistics mapping: a root-level list of `_StatEntry` items."""
def serialize_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
) -> List[JsonValue]:
    """Convert a stat mapping into JSON-compatible Python data.

    Note that this does not return a JSON string: it returns a list with one
    object per entry of **stat**, holding its `measure` and `value` as dumped
    by pydantic in "json" mode. Pass the result to `json.dumps` to encode it.

    Args:
        stat: mapping from measure to measured value.

    Returns:
        List of JSON-compatible entries.
    """
    entries = [_StatEntry(measure=m, value=v) for m, v in stat.items()]
    return _StatList(entries).model_dump(mode="json")
def save_stat(
    stat: Mapping[Union[DatasetMeasure, SampleMeasure], MeasureValue],
    output: Union[Path, str, BytesIO],
) -> None:
    """Save sample and dataset statistics as UTF-8 encoded, indented JSON.

    Args:
        stat: statistics to save, serialized with `serialize_stat`.
        output: file path (`Path` or `str`) to create or overwrite, or a
            binary buffer to write into. A buffer is left open.
    """
    # Serialize before opening the file: opening with "wb" truncates it, so a
    # serialization error would otherwise leave an empty file behind.
    payload = json.dumps(serialize_stat(stat), indent=2).encode("utf-8")

    if isinstance(output, (str, Path)):
        ctxt = Path(output).open("wb")
    else:
        ctxt = nullcontext(output)

    with ctxt as out:
        _ = out.write(payload)
def load_stat(
    source: Union[Path, str, bytes, bytearray, Sequence[JsonValue]],
) -> Stat:
    """Load sample and dataset statistics from JSON.

    Args:
        source: one of
            - a `Path` to a UTF-8 JSON file,
            - JSON text as `str`, `bytes` or `bytearray` (e.g. the buffer
              content written by `save_stat`); note that a `str` is parsed
              as JSON, not treated as a file path,
            - already parsed JSON data, i.e. the list of entries returned by
              `serialize_stat`.

    Returns:
        Mapping from measure to measured value.
    """
    if isinstance(source, Path):
        source = source.read_text(encoding="utf-8")

    # bytes/bytearray are Sequences too, so they must be routed to the JSON
    # parser explicitly instead of being validated as a list of entries
    if isinstance(source, (str, bytes, bytearray)):
        seq = _StatList.model_validate_json(source)
    else:
        seq = _StatList.model_validate(source)

    return {e.measure: e.value for e in seq.root}
def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path) -> None:
    """DEPRECATED alias for `save_stat()`: use `save_stat()` instead."""
    # stacklevel=2 attributes the warning to the caller
    warnings.warn(
        "`save_dataset_stat()` is deprecated, use `save_stat()` instead.",
        stacklevel=2,
    )
    # rebuilt as a plain dict so it matches save_stat's wider key type
    save_stat({k: v for k, v in stat.items()}, path)
def load_dataset_stat(path: Path) -> Stat:
    """DEPRECATED alias for `load_stat()`: use `load_stat()` instead."""
    # stacklevel=2 attributes the warning to the caller
    warnings.warn(
        "`load_dataset_stat()` is deprecated, use `load_stat()` instead.",
        stacklevel=2,
    )
    return load_stat(path)
def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
) -> Path:
    """Make **source** available unzipped inside **folder**.

    The output name is the reader's original file name, or `file<suffix>`
    if that is not set. If **source** is a zip archive, it is extracted into
    the directory `<name>.unzipped`; otherwise its bytes are copied to the
    file `<name>`. Parent directories are created as needed.

    Args:
        source: file source, or an already opened `BytesReader`.
        folder: directory to place the output in.

    Returns:
        Path of the extracted directory or of the copied file.
    """
    if isinstance(source, BytesReader):
        reader = source
    else:
        reader = get_reader(source)

    out_path = folder / (reader.original_file_name or f"file{reader.suffix}")

    # `zipfile.is_zipfile` seeks within a file object to look for the zip
    # end-of-central-directory record and, depending on the Python version,
    # does not restore the position. Restore it explicitly so the plain copy
    # below does not start near the end of the stream.
    start = reader.tell()
    is_zip = zipfile.is_zipfile(reader)
    _ = reader.seek(start)

    if is_zip:
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(reader, "r") as f:
            f.extractall(out_path)
    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with out_path.open("wb") as f:
            copyfileobj(reader, f)

    return out_path
def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: use `source.suffix` instead.

    Returns:
        `source.suffix`, after emitting a deprecation warning (like the
        module's other deprecated aliases).
    """
    # stacklevel=2 attributes the warning to the caller
    warnings.warn(
        "`get_suffix()` is deprecated, use `source.suffix` instead.", stacklevel=2
    )
    return source.suffix