Coverage for src / bioimageio / core / io.py: 84%
91 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 23:26 +0000
1import collections.abc
2import warnings
3import zipfile
4from pathlib import Path
5from shutil import copyfileobj
6from typing import (
7 Any,
8 Mapping,
9 Optional,
10 Sequence,
11 TypeVar,
12 Union,
13)
15from imageio.v3 import imread, imwrite # type: ignore
16from loguru import logger
17from numpy.typing import NDArray
18from pydantic import BaseModel, ConfigDict, TypeAdapter
20from bioimageio.spec._internal.io import get_reader, interprete_file_source
21from bioimageio.spec._internal.type_guards import is_ndarray
22from bioimageio.spec.common import (
23 BytesReader,
24 FileDescr,
25 FileSource,
26 HttpUrl,
27 PermissiveFileSource,
28 RelativeFilePath,
29 ZipPath,
30)
31from bioimageio.spec.utils import download, load_array, save_array
33from .axis import AxisId, AxisLike
34from .common import PerMember
35from .sample import Sample
36from .stat_measures import DatasetMeasure, MeasureValue
37from .tensor import Tensor
def load_image(
    source: Union[PermissiveFileSource, ZipPath], is_volume: Optional[bool] = None
) -> NDArray[Any]:
    """Load a single image file as a numpy array.

    Args:
        source: image source (file path, URL, zip member, or file descriptor)
        is_volume: deprecated; has no effect

    Returns:
        the image data as a numpy array
    """
    if is_volume is not None:
        warnings.warn("**is_volume** is deprecated and will be removed soon.")

    # already-parsed sources are used as-is; anything else is interpreted first
    parsed = (
        source
        if isinstance(source, (FileDescr, ZipPath))
        else interprete_file_source(source)
    )
    if isinstance(parsed, RelativeFilePath):
        parsed = parsed.absolute()

    if parsed.suffix == ".npy":
        # numpy arrays bypass imageio entirely
        img = load_array(parsed)
    else:
        img = imread(  # pyright: ignore[reportUnknownVariableType]
            download(parsed).read(), extension=parsed.suffix
        )

    assert is_ndarray(img)
    return img
def load_tensor(
    source: Union[PermissiveFileSource, ZipPath],
    /,
    axes: Optional[Sequence[AxisLike]] = None,
) -> Tensor:
    """Load **source** as a `Tensor`, labeling its dimensions with **axes**."""
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(source), dims=axes)
# constrained type variable over the concrete file-source types handled here
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)
# alias for a file suffix string (including the leading dot, e.g. ".npy")
Suffix = str
def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**; the file format is chosen by the path suffix.

    Args:
        path: target file path; must have a suffix (e.g. ".npy", ".tif", ".png")
        tensor: the tensor to write

    Raises:
        ValueError: if **path** has no suffix.
        NotImplementedError: for hdf5 targets (".h5", ".hdf", ".hdf5").
    """
    # TODO: save axis meta data
    path = Path(path)
    if not path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    extension = path.suffix.lower()
    path.parent.mkdir(exist_ok=True, parents=True)
    if extension == ".npy":
        data: NDArray[Any] = (  # pyright: ignore[reportUnknownVariableType]
            tensor.data.to_numpy()
        )
        assert is_ndarray(data)
        save_array(path, data)
    elif extension in (".h5", ".hdf", ".hdf5"):
        raise NotImplementedError("Saving to h5 with dataset path is not implemented.")
    else:
        if (
            extension in (".tif", ".tiff")
            and tensor.tagged_shape.get(ba := AxisId("batch")) == 1
        ):
            # remove singleton batch axis for saving
            tensor = tensor[{ba: 0}]
            singleton_axes_msg = "(without singleton batch axes) "
        else:
            singleton_axes_msg = ""

        # BUGFIX: extract the numpy data *after* the (possible) batch-axis
        # removal above; previously `data` was taken from the original tensor,
        # so the singleton batch axis was logged as removed but still written.
        data = (  # pyright: ignore[reportUnknownVariableType]
            tensor.data.to_numpy()
        )
        assert is_ndarray(data)
        logger.debug(
            "writing tensor {} {}to {}",
            dict(tensor.tagged_shape),
            singleton_axes_msg,
            path,
        )
        imwrite(path, data, extension=extension)
def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern
    or all sample members in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
    **path** it must contain `{member_id}` (or `{input_id}` or `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        placeholders = ("{member_id}", "{input_id}", "{output_id}")
        has_placeholder = any(ph in str(path) for ph in placeholders)
        # a single path pattern needs a member placeholder once there is
        # more than one member to distinguish
        if len(sample.members) >= 2 and not has_placeholder:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

        path = {member: path for member in sample.members}

    for member_id, member_path in path.items():
        tensor = sample.members[member_id]
        target = Path(
            str(member_path).format(
                sample_id=sample.id,
                member_id=member_id,
                input_id=member_id,
                output_id=member_id,
            )
        )
        save_tensor(target, tensor)
class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    """A single (measure, value) pair for (de)serializing dataset statistics."""

    # the dataset measure descriptor this entry belongs to
    measure: DatasetMeasure
    # the computed value for **measure**
    value: MeasureValue
# pydantic adapter to dump/validate a sequence of dataset-stat entries as JSON
_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)
def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    """Serialize the dataset statistics **stat** as JSON to **path**."""
    entries = [
        _SerializedDatasetStatsEntry(measure=measure, value=value)
        for measure, value in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(entries))
def load_dataset_stat(path: Path):
    """Load dataset statistics written by `save_dataset_stat` from **path**."""
    entries = _stat_adapter.validate_json(path.read_bytes())
    return {entry.measure: entry.value for entry in entries}
def ensure_unzipped(
    source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
):
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive
    otherwise copy **source** to a file in **folder**."""
    if isinstance(source, BytesReader):
        weights_reader = source
    else:
        # resolve/download the source into a readable byte stream
        weights_reader = get_reader(source)

    # prefer the original file name; fall back to a generic name with the same suffix
    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    if zipfile.is_zipfile(weights_reader):
        # mark extracted directory so it is distinguishable from the archive name
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # source itself is a zipfile
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)

    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # NOTE(review): `zipfile.is_zipfile` seeks within a file-like object and
        # is not documented to restore its position; this copy assumes the
        # reader is (still) positioned at the start — confirm BytesReader's
        # seek semantics.
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path
def get_suffix(source: Union[ZipPath, FileSource]) -> Suffix:
    """DEPRECATED: use source.suffix instead."""
    # emit a runtime warning so callers notice the deprecation, consistent
    # with the deprecation handling in `load_image`
    warnings.warn("**get_suffix** is deprecated, use `source.suffix` instead.")
    return source.suffix