bioimageio.core.io

  1import collections.abc
  2import warnings
  3import zipfile
  4from io import TextIOWrapper
  5from pathlib import Path, PurePosixPath
  6from shutil import copyfileobj
  7from typing import (
  8    Any,
  9    Mapping,
 10    Optional,
 11    Sequence,
 12    Tuple,
 13    TypeVar,
 14    Union,
 15)
 16
 17import h5py  # pyright: ignore[reportMissingTypeStubs]
 18import numpy as np
 19from imageio.v3 import imread, imwrite  # type: ignore
 20from loguru import logger
 21from numpy.typing import NDArray
 22from pydantic import BaseModel, ConfigDict, TypeAdapter
 23from typing_extensions import assert_never
 24
 25from bioimageio.spec._internal.io import interprete_file_source
 26from bioimageio.spec.common import (
 27    HttpUrl,
 28    PermissiveFileSource,
 29    RelativeFilePath,
 30    ZipPath,
 31)
 32from bioimageio.spec.utils import download, load_array, save_array
 33
 34from .axis import AxisLike
 35from .common import PerMember
 36from .sample import Sample
 37from .stat_measures import DatasetMeasure, MeasureValue
 38from .tensor import Tensor
 39
 40DEFAULT_H5_DATASET_PATH = "data"
 41
 42
 43SUFFIXES_WITH_DATAPATH = (".h5", ".hdf", ".hdf5")
 44
 45
 46def load_image(
 47    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
 48) -> NDArray[Any]:
 49    """load a single image as numpy array
 50
 51    Args:
 52        source: image source
 53        is_volume: deprecated
 54    """
 55    if is_volume is not None:
 56        warnings.warn("**is_volume** is deprecated and will be removed soon.")
 57
 58    if isinstance(source, ZipPath):
 59        parsed_source = source
 60    else:
 61        parsed_source = interprete_file_source(source)
 62
 63    if isinstance(parsed_source, RelativeFilePath):
 64        src = parsed_source.absolute()
 65    else:
 66        src = parsed_source
 67
 68    # FIXME: why is pyright complaining about giving the union to _split_dataset_path?
 69    if isinstance(src, Path):
 70        file_source, subpath = _split_dataset_path(src)
 71    elif isinstance(src, HttpUrl):
 72        file_source, subpath = _split_dataset_path(src)
 73    elif isinstance(src, ZipPath):
 74        file_source, subpath = _split_dataset_path(src)
 75    else:
 76        assert_never(src)
 77
 78    path = download(file_source).path
 79
 80    if path.suffix == ".npy":
 81        if subpath is not None:
 82            raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}")
 83        return load_array(path)
 84    elif path.suffix in SUFFIXES_WITH_DATAPATH:
 85        if subpath is None:
 86            dataset_path = DEFAULT_H5_DATASET_PATH
 87        else:
 88            dataset_path = str(subpath)
 89
 90        with h5py.File(path, "r") as f:
 91            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
 92                dataset_path
 93            )
 94            if not isinstance(h5_dataset, h5py.Dataset):
 95                raise ValueError(
 96                    f"{path} is not of type {h5py.Dataset}, but has type "
 97                    + str(
 98                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
 99                    )
100                )
101            image: NDArray[Any]
102            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
103            assert isinstance(image, np.ndarray), type(
104                image  # pyright: ignore[reportUnknownArgumentType]
105            )
106            return image  # pyright: ignore[reportUnknownVariableType]
107    elif isinstance(path, ZipPath):
108        return imread(
109            path.read_bytes(), extension=path.suffix
110        )  # pyright: ignore[reportUnknownVariableType]
111    else:
112        return imread(path)  # pyright: ignore[reportUnknownVariableType]
113
114
def load_tensor(
    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
) -> Tensor:
    """Load a tensor from **path** (see `load_image` for supported formats).

    Args:
        path: file path, optionally with an internal hdf5 dataset path appended
        axes: axis descriptions for the loaded array
    """
    # TODO: load axis meta data
    return Tensor.from_numpy(load_image(path), dims=axes)
122
123
# constrained TypeVar: _split_dataset_path returns the same source type it received
_SourceT = TypeVar("_SourceT", Path, HttpUrl, ZipPath)
125
126
def _split_dataset_path(
    source: _SourceT,
) -> Tuple[_SourceT, Optional[PurePosixPath]]:
    """Split off subpath (e.g. internal h5 dataset path)
    from a file path following a file extension.

    Returns:
        The file source (same type as **source**) and the dataset subpath,
        or ``None`` if **source** contains no suffix from
        `SUFFIXES_WITH_DATAPATH` followed by a subpath.

    Examples:
        >>> _split_dataset_path(Path("my_file.h5/dataset"))
        (...Path('my_file.h5'), PurePosixPath('dataset'))

        >>> _split_dataset_path(Path("my_plain_file"))
        (...Path('my_plain_file'), None)

    """
    if isinstance(source, RelativeFilePath):
        src = source.absolute()
    else:
        src = source

    del source

    def separate_pure_path(path: PurePosixPath):
        # the first parent ending in a dataset-capable suffix is the actual
        # file; the remainder is the internal dataset path
        for p in path.parents:
            if p.suffix in SUFFIXES_WITH_DATAPATH:
                return p, PurePosixPath(path.relative_to(p))

        return path, None

    if isinstance(src, HttpUrl):
        file_path, data_path = separate_pure_path(PurePosixPath(src.path or ""))

        if data_path is None:
            return src, None

        # strip the dataset subpath from the *full* URL string: `file_path`
        # is only the URL's path component (and already excludes
        # `data_path`), so rebuilding from it would drop scheme and host
        return (
            HttpUrl(str(src).replace(f"/{data_path}", "")),
            data_path,
        )

    if isinstance(src, ZipPath):
        file_path, data_path = separate_pure_path(PurePosixPath(str(src)))

        if data_path is None:
            return src, None

        # NOTE(review): `file_path` never contains `data_path`, so this
        # `replace` is a no-op, and constructing a ZipPath from the joined
        # string loses the original archive root object — verify against
        # callers that pass ZipPaths with dataset subpaths.
        return (
            ZipPath(str(file_path).replace(f"/{data_path}", "")),
            data_path,
        )

    file_path, data_path = separate_pure_path(PurePosixPath(src))
    return Path(file_path), data_path
179
180
def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
    """Save **tensor** to **path**.

    The format is chosen by the suffix of **path**: `.npy`, hdf5
    (`.h5`/`.hdf`/`.hdf5`, optionally with an internal dataset path appended,
    defaulting to `DEFAULT_H5_DATASET_PATH`), or any format writable by
    imageio. Parent directories are created as needed.

    Raises:
        ValueError: if **path** has no suffix, or if a dataset subpath is
            given for a `.npy` file.
    """
    # TODO: save axis meta data

    data: NDArray[Any] = tensor.data.to_numpy()
    file_path, subpath = _split_dataset_path(Path(path))
    if not file_path.suffix:
        raise ValueError(f"No suffix (needed to decide file format) found in {path}")

    file_path.parent.mkdir(exist_ok=True, parents=True)
    if file_path.suffix == ".npy":
        if subpath is not None:
            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")

        save_array(file_path, data)
    elif file_path.suffix in SUFFIXES_WITH_DATAPATH:
        # use the module constant so load_tensor/save_tensor stay in sync
        dataset_path = DEFAULT_H5_DATASET_PATH if subpath is None else str(subpath)

        with h5py.File(file_path, "a") as f:
            if dataset_path in f:
                # overwrite an existing dataset of the same name
                del f[dataset_path]

            _ = f.create_dataset(dataset_path, data=data, chunks=True)
    else:
        logger.debug(
            "writing tensor {} to {}",
            dict(tensor.tagged_shape),
            path,
        )
        imwrite(path, data)
219
220
def save_sample(
    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
) -> None:
    """Save a **sample** to a **path** pattern
    or all sample members in the **path** mapping.

    If **path** is a pathlib.Path or a string and the **sample** has multiple
    members, **path** must contain `{member_id}` (or `{input_id}` or
    `{output_id}`).

    (Each) **path** may contain `{sample_id}` to be formatted with the
    **sample** object.
    """
    if not isinstance(path, collections.abc.Mapping):
        placeholders = ("{member_id}", "{input_id}", "{output_id}")
        has_member_placeholder = any(ph in str(path) for ph in placeholders)
        if len(sample.members) >= 2 and not has_member_placeholder:
            raise ValueError(
                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
            )

        # single member or placeholder present: use the same pattern for all
        path = {member_id: path for member_id in sample.members}

    for member_id, member_path in path.items():
        tensor = sample.members[member_id]
        formatted = str(member_path).format(
            sample_id=sample.id,
            member_id=member_id,
            input_id=member_id,
            output_id=member_id,
        )
        save_tensor(Path(formatted), tensor)
248
249
class _SerializedDatasetStatsEntry(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    """a single (measure, value) pair for (de)serializing dataset statistics"""

    # the dataset measure (mean/std/var/percentile) this entry describes
    measure: DatasetMeasure
    # computed value of **measure**; a float or a tensor
    value: MeasureValue
255
256
# (de)serializes a sequence of dataset statistics entries to/from JSON;
# arbitrary_types_allowed because MeasureValue may hold a Tensor
_stat_adapter = TypeAdapter(
    Sequence[_SerializedDatasetStatsEntry],
    config=ConfigDict(arbitrary_types_allowed=True),
)
261
262
def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
    """Serialize dataset statistics **stat** as JSON to **path**."""
    entries = [
        _SerializedDatasetStatsEntry(measure=measure, value=value)
        for measure, value in stat.items()
    ]
    _ = path.write_bytes(_stat_adapter.dump_json(entries))
268
269
def load_dataset_stat(path: Path):
    """Load dataset statistics written by `save_dataset_stat` from **path**."""
    entries = _stat_adapter.validate_json(path.read_bytes())
    return {entry.measure: entry.value for entry in entries}
273
274
def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive.
    Always returns the path to the unzipped source (maybe source itself)"""
    local_weights_file = download(source).path
    if isinstance(local_weights_file, ZipPath):
        # source is inside a zip archive: copy it out into **folder**.
        # Use `.name` (base name inside the archive): `.filename` is the
        # archive's own path joined with the internal path, and — being an
        # absolute path — would discard **folder** in the `/` join below.
        out_path = folder / local_weights_file.name
        with local_weights_file.open("rb") as src, out_path.open("wb") as dst:
            assert not isinstance(src, TextIOWrapper)
            copyfileobj(src, dst)

        local_weights_file = out_path

    if zipfile.is_zipfile(local_weights_file):
        # source itself is a zipfile: extract into a sibling directory
        out_path = folder / local_weights_file.with_suffix(".unzipped").name
        with zipfile.ZipFile(local_weights_file, "r") as f:
            f.extractall(out_path)

        return out_path
    else:
        return local_weights_file
DEFAULT_H5_DATASET_PATH = 'data'
SUFFIXES_WITH_DATAPATH = ('.h5', '.hdf', '.hdf5')
def load_image( source: Union[zipp.Path, Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])], str, Annotated[pydantic_core._pydantic_core.Url, UrlConstraints(max_length=2083, allowed_schemes=['http', 'https'], host_required=None, default_host=None, default_port=None, default_path=None)]], is_volume: Optional[bool] = None) -> numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]:
 47def load_image(
 48    source: Union[ZipPath, PermissiveFileSource], is_volume: Optional[bool] = None
 49) -> NDArray[Any]:
 50    """load a single image as numpy array
 51
 52    Args:
 53        source: image source
 54        is_volume: deprecated
 55    """
 56    if is_volume is not None:
 57        warnings.warn("**is_volume** is deprecated and will be removed soon.")
 58
 59    if isinstance(source, ZipPath):
 60        parsed_source = source
 61    else:
 62        parsed_source = interprete_file_source(source)
 63
 64    if isinstance(parsed_source, RelativeFilePath):
 65        src = parsed_source.absolute()
 66    else:
 67        src = parsed_source
 68
 69    # FIXME: why is pyright complaining about giving the union to _split_dataset_path?
 70    if isinstance(src, Path):
 71        file_source, subpath = _split_dataset_path(src)
 72    elif isinstance(src, HttpUrl):
 73        file_source, subpath = _split_dataset_path(src)
 74    elif isinstance(src, ZipPath):
 75        file_source, subpath = _split_dataset_path(src)
 76    else:
 77        assert_never(src)
 78
 79    path = download(file_source).path
 80
 81    if path.suffix == ".npy":
 82        if subpath is not None:
 83            raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}")
 84        return load_array(path)
 85    elif path.suffix in SUFFIXES_WITH_DATAPATH:
 86        if subpath is None:
 87            dataset_path = DEFAULT_H5_DATASET_PATH
 88        else:
 89            dataset_path = str(subpath)
 90
 91        with h5py.File(path, "r") as f:
 92            h5_dataset = f.get(  # pyright: ignore[reportUnknownVariableType]
 93                dataset_path
 94            )
 95            if not isinstance(h5_dataset, h5py.Dataset):
 96                raise ValueError(
 97                    f"{path} is not of type {h5py.Dataset}, but has type "
 98                    + str(
 99                        type(h5_dataset)  # pyright: ignore[reportUnknownArgumentType]
100                    )
101                )
102            image: NDArray[Any]
103            image = h5_dataset[:]  # pyright: ignore[reportUnknownVariableType]
104            assert isinstance(image, np.ndarray), type(
105                image  # pyright: ignore[reportUnknownArgumentType]
106            )
107            return image  # pyright: ignore[reportUnknownVariableType]
108    elif isinstance(path, ZipPath):
109        return imread(
110            path.read_bytes(), extension=path.suffix
111        )  # pyright: ignore[reportUnknownVariableType]
112    else:
113        return imread(path)  # pyright: ignore[reportUnknownVariableType]

load a single image as numpy array

Arguments:
  • source: image source
  • is_volume: deprecated
def load_tensor( path: Union[zipp.Path, pathlib.Path, str], axes: Optional[Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]]] = None) -> bioimageio.core.Tensor:
116def load_tensor(
117    path: Union[ZipPath, Path, str], axes: Optional[Sequence[AxisLike]] = None
118) -> Tensor:
119    # TODO: load axis meta data
120    array = load_image(path)
121
122    return Tensor.from_numpy(array, dims=axes)
def save_tensor( path: Union[pathlib.Path, str], tensor: bioimageio.core.Tensor) -> None:
182def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
183    # TODO: save axis meta data
184
185    data: NDArray[Any] = tensor.data.to_numpy()
186    file_path, subpath = _split_dataset_path(Path(path))
187    if not file_path.suffix:
188        raise ValueError(f"No suffix (needed to decide file format) found in {path}")
189
190    file_path.parent.mkdir(exist_ok=True, parents=True)
191    if file_path.suffix == ".npy":
192        if subpath is not None:
193            raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}")
194        save_array(file_path, data)
195    elif file_path.suffix in (".h5", ".hdf", ".hdf5"):
196        if subpath is None:
197            dataset_path = DEFAULT_H5_DATASET_PATH
198        else:
199            dataset_path = str(subpath)
200
201        with h5py.File(file_path, "a") as f:
202            if dataset_path in f:
203                del f[dataset_path]
204
205            _ = f.create_dataset(dataset_path, data=data, chunks=True)
206    else:
207        # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]:
208        #     tensor = tensor[{a: 0 for a in singleton_axes}]
209        #     singleton_axes_msg = f"(without singleton axes {singleton_axes}) "
210        # else:
211        singleton_axes_msg = ""
212
213        logger.debug(
214            "writing tensor {} {}to {}",
215            dict(tensor.tagged_shape),
216            singleton_axes_msg,
217            path,
218        )
219        imwrite(path, data)
def save_sample( path: Union[pathlib.Path, str, Mapping[bioimageio.spec.model.v0_5.TensorId, Union[pathlib.Path, str]]], sample: bioimageio.core.Sample) -> None:
222def save_sample(
223    path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
224) -> None:
225    """Save a **sample** to a **path** pattern
226    or all sample members in the **path** mapping.
227
228    If **path** is a pathlib.Path or a string and the **sample** has multiple members,
229    **path** it must contain `{member_id}` (or `{input_id}` or `{output_id}`).
230
231    (Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
232    """
233    if not isinstance(path, collections.abc.Mapping):
234        if len(sample.members) < 2 or any(
235            m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
236        ):
237            path = {m: path for m in sample.members}
238        else:
239            raise ValueError(
240                f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
241            )
242
243    for m, p in path.items():
244        t = sample.members[m]
245        p_formatted = Path(
246            str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
247        )
248        save_tensor(p_formatted, t)

Save a sample to a path pattern or all sample members in the path mapping.

If path is a pathlib.Path or a string and the sample has multiple members, path must contain {member_id} (or {input_id} or {output_id}).

(Each) path may contain {sample_id} to be formatted with the sample object.

def save_dataset_stat( stat: Mapping[Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]], path: pathlib.Path):
264def save_dataset_stat(stat: Mapping[DatasetMeasure, MeasureValue], path: Path):
265    serializable = [
266        _SerializedDatasetStatsEntry(measure=k, value=v) for k, v in stat.items()
267    ]
268    _ = path.write_bytes(_stat_adapter.dump_json(serializable))
def load_dataset_stat(path: pathlib.Path):
271def load_dataset_stat(path: Path):
272    seq = _stat_adapter.validate_json(path.read_bytes())
273    return {e.measure: e.value for e in seq}
def ensure_unzipped( source: Union[Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])], str, Annotated[pydantic_core._pydantic_core.Url, UrlConstraints(max_length=2083, allowed_schemes=['http', 'https'], host_required=None, default_host=None, default_port=None, default_path=None)], zipp.Path], folder: pathlib.Path):
276def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
277    """unzip a (downloaded) **source** to a file in **folder** if source is a zip archive.
278    Always returns the path to the unzipped source (maybe source itself)"""
279    local_weights_file = download(source).path
280    if isinstance(local_weights_file, ZipPath):
281        # source is inside a zip archive
282        out_path = folder / local_weights_file.filename
283        with local_weights_file.open("rb") as src, out_path.open("wb") as dst:
284            assert not isinstance(src, TextIOWrapper)
285            copyfileobj(src, dst)
286
287        local_weights_file = out_path
288
289    if zipfile.is_zipfile(local_weights_file):
290        # source itself is a zipfile
291        out_path = folder / local_weights_file.with_suffix(".unzipped").name
292        with zipfile.ZipFile(local_weights_file, "r") as f:
293            f.extractall(out_path)
294
295        return out_path
296    else:
297        return local_weights_file

unzip a (downloaded) source to a file in folder if source is a zip archive. Always returns the path to the unzipped source (maybe source itself)