bioimageio.core.digest_spec

  1from __future__ import annotations
  2
  3import collections.abc
  4import importlib.util
  5import sys
  6from itertools import chain
  7from pathlib import Path
  8from tempfile import TemporaryDirectory
  9from typing import (
 10    Any,
 11    Callable,
 12    Dict,
 13    Iterable,
 14    List,
 15    Mapping,
 16    NamedTuple,
 17    Optional,
 18    Sequence,
 19    Tuple,
 20    Union,
 21)
 22from zipfile import ZipFile, is_zipfile
 23
 24import numpy as np
 25import xarray as xr
 26from bioimageio.spec._internal.io import HashKwargs
 27from bioimageio.spec.common import FileDescr, FileSource, ZipPath
 28from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 29from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
 30from bioimageio.spec.model.v0_5 import (
 31    ArchitectureFromFileDescr,
 32    ArchitectureFromLibraryDescr,
 33    ParameterizedSize_N,
 34)
 35from bioimageio.spec.utils import load_array
 36from loguru import logger
 37from numpy.typing import NDArray
 38from typing_extensions import Unpack, assert_never
 39
 40from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 41from .block_meta import split_multiple_shapes_into_blocks
 42from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
 43from .io import load_tensor
 44from .sample import (
 45    LinearSampleAxisTransform,
 46    Sample,
 47    SampleBlockMeta,
 48    sample_block_meta_generator,
 49)
 50from .stat_measures import Stat
 51from .tensor import Tensor
 52
 53TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
 54
 55
 56def import_callable(
 57    node: Union[
 58        ArchitectureFromFileDescr,
 59        ArchitectureFromLibraryDescr,
 60        CallableFromDepencency,
 61        CallableFromFile,
 62    ],
 63    /,
 64    **kwargs: Unpack[HashKwargs],
 65) -> Callable[..., Any]:
 66    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
 67    if isinstance(node, CallableFromDepencency):
 68        module = importlib.import_module(node.module_name)
 69        c = getattr(module, str(node.callable_name))
 70    elif isinstance(node, ArchitectureFromLibraryDescr):
 71        module = importlib.import_module(node.import_from)
 72        c = getattr(module, str(node.callable))
 73    elif isinstance(node, CallableFromFile):
 74        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
 75    elif isinstance(node, ArchitectureFromFileDescr):
 76        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
 77    else:
 78        assert_never(node)
 79
 80    if not callable(c):
 81        raise ValueError(f"{node} (imported: {c}) is not callable")
 82
 83    return c
 84
 85
 86tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
 87"""keep global reference to temporary directories created during import to delay cleanup"""
 88
 89
 90def _import_from_file_impl(
 91    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
 92):
 93    src_descr = FileDescr(source=source, **kwargs)
 94    # ensure sha is valid even if perform_io_checks=False
 95    # or the source has changed since last sha computation
 96    src_descr.validate_sha256(force_recompute=True)
 97    assert src_descr.sha256 is not None
 98    source_sha = src_descr.sha256
 99
100    reader = src_descr.get_reader()
101    # make sure we have unique module name
102    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"
103
104    # make sure we have a unique and valid module name
105    if not module_name.isidentifier():
106        module_name = f"custom_module_{source_sha}"
107        assert module_name.isidentifier(), module_name
108
109    source_bytes = reader.read()
110
111    module = sys.modules.get(module_name)
112    if module is None:
113        try:
114            td_kwargs: Dict[str, Any] = (
115                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
116            )
117            if sys.version_info >= (3, 12):
118                td_kwargs["delete"] = False
119
120            tmp_dir = TemporaryDirectory(**td_kwargs)
121            # keep global ref to tmp_dir to delay cleanup until program exit
122            # TODO: remove for py >= 3.12, when delete=False works
123            tmp_dirs_in_use.append(tmp_dir)
124
125            module_path = Path(tmp_dir.name) / module_name
126            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
127                module_path.mkdir()
128                ZipFile(reader).extractall(path=module_path)
129            else:
130                module_path = module_path.with_suffix(".py")
131                _ = module_path.write_bytes(source_bytes)
132
133            importlib_spec = importlib.util.spec_from_file_location(
134                module_name, str(module_path)
135            )
136
137            if importlib_spec is None:
138                raise ImportError(f"Failed to import {source}")
139
140            module = importlib.util.module_from_spec(importlib_spec)
141
142            sys.modules[module_name] = module  # cache this module
143
144            assert importlib_spec.loader is not None
145            importlib_spec.loader.exec_module(module)
146
147        except Exception as e:
148            if module_name in sys.modules:
149                del sys.modules[module_name]
150
151            raise ImportError(f"Failed to import {source}") from e
152
153    try:
154        callable_attr = getattr(module, callable_name)
155    except AttributeError as e:
156        raise AttributeError(
157            f"Imported custom module from {source} has no `{callable_name}` attribute."
158        ) from e
159    except Exception as e:
160        raise AttributeError(
161            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
162        ) from e
163
164    else:
165        return callable_attr
166
167
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Convert the axes of a tensor description into a list of `AxisInfo`.

    v0.5 axes (`v0_5.AxisBase` instances) are converted from their `id` and
    `type`; v0.4 axes must be one of the letters "b", "i", "t", "c", "z", "y", "x".
    """

    def _to_info(axis: Any) -> AxisInfo:
        if isinstance(axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=axis.id, type=axis.type))

        assert axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(axis)

    return [_to_info(axis) for axis in io_descr.axes]
186
187
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized id of a tensor description, usable as a sample member id.

    v0.4 descriptions are identified by their `name`, v0.5 descriptions by their `id`.
    """
    v0_4_descr_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    v0_5_descr_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)

    if isinstance(tensor_description, v0_4_descr_types):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, v0_5_descr_types):
        return tensor_description.id

    assert_never(tensor_description)
206
207
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Map each tensor description to its normalized sample member id (in order)."""
    return list(map(get_member_id, tensor_descriptions))
220
221
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test input sample.

    For v0.4 models the test data comes from `model.test_inputs`; otherwise the
    input descriptions themselves are passed on so their test tensors are loaded.
    """
    if isinstance(model, v0_4.ModelDescr):
        sources = model.test_inputs
    else:
        sources = model.inputs

    return _get_test_sample(model.inputs, sources)
227
228
get_test_inputs = get_test_input_sample
"""DEPRECATED alias of `get_test_input_sample`; use that function instead."""
231
232
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test output sample.

    For v0.4 models the test data comes from `model.test_outputs`; otherwise the
    output descriptions themselves are passed on so their test tensors are loaded.
    """
    if isinstance(model, v0_4.ModelDescr):
        sources = model.test_outputs
    else:
        sources = model.outputs

    return _get_test_sample(model.outputs, sources)
239
240
get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""
243
244
def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """Load the test tensors of a model's inputs or outputs into one `Sample`.

    Args:
        tensor_descrs: tensor descriptions providing member ids and axes.
        test_sources: one entry per tensor description, in the same order: either
            a file source (v0.4 `test_inputs`/`test_outputs`) or a v0.5 tensor
            description whose `test_tensor` is loaded.

    Returns:
        A sample with id "test-sample" and empty statistics.

    Raises:
        ValueError: if the number of test sources differs from the number of
            tensor descriptions, or a v0.5 tensor description has no test tensor.
    """
    if len(test_sources) != len(tensor_descrs):
        # `zip` below would otherwise silently drop the unmatched tensors
        raise ValueError(
            f"Got {len(test_sources)} test tensor source(s) for"
            + f" {len(tensor_descrs)} tensor description(s)."
        )

    member_ids = get_member_ids(tensor_descrs)
    arrays: List[NDArray[Any]] = []
    for src in test_sources:
        if isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                kind = "input" if isinstance(src, v0_5.InputTensorDescr) else "output"
                raise ValueError(
                    f"Model {kind} '{src.id}' has no test tensor defined, cannot create test sample."
                )
            arrays.append(load_array(src.test_tensor))
        else:
            arrays.append(load_array(src))

    axes = [get_axes_infos(t) for t in tensor_descrs]
    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )
278
279
class IO_SampleBlockMeta(NamedTuple):
    """Pair of block metadata: an input sample block and its matching output block."""

    input: SampleBlockMeta  # metadata of the block taken from the input sample
    output: SampleBlockMeta  # metadata of the corresponding output block
283
284
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo with which input tensors need to be divided into blocks, such
    that `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis with a halo must have its size given by a
    `v0_5.SizeReference`. The halo is rescaled by
    `output_axis.scale / referenced_axis.scale` and assigned to the referenced
    tensor and axis.

    Args:
        model: v0.5 model description.
        output_halo: halo per output tensor id and output axis id.

    Returns:
        Halo per referenced tensor id and referenced axis id.

    Raises:
        ValueError: if an output axis with a halo does not reference another
            axis' size.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        out_axes = {ax.id: ax for ax in outputs[t].axes}

        for a, ah in th.items():
            axis = out_axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ax.id: ax for ax in all_tensors[s.tensor_id].axes}[s.axis_id]

            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            # key by the *referenced* axis: the halo applies to that tensor's axis,
            # whose id may differ from the output axis id `a`
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo
317
318
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how a model's output tensor shapes relate to its input shapes.

    The entry for each output tensor and axis is one of:
    - a `LinearSampleAxisTransform` for batch axes (mapped 1:1 onto the first
      batch axis found among the inputs) and for `v0_5.SizeReference` sizes
      (scaled by the ratio of the referenced axis' scale to this axis' scale);
    - an `int`: the fixed size, or `-1` for a `v0_5.DataDependentSize`.

    Raises:
        ValueError: if an output tensor has a batch axis but no input does.
    """
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=ax.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for ax in ipt.axes
            if ax.type == "batch"
        ),
        None,
    )
    axis_scales = {
        tensor.id: {ax.id: ax.scale for ax in tensor.axes}
        for tensor in chain(model.inputs, model.outputs)
    }

    def output_axis_trf(out_id: Any, ax: Any) -> Union[LinearSampleAxisTransform, int]:
        size = ax.size
        if size is None:
            assert ax.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out_id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / ax.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    return {
        out.id: {ax.id: output_axis_trf(out.id, ax) for ax in out.axes}
        for out in model.outputs
    }
368
369
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of input blocks and an iterable yielding metadata for
    corresponding input and output sample blocks.

    Args:
        model: v0.5 model description.
        input_sample_shape: size per axis of each input tensor of the sample.
        ns: parameterized-size `n` per (tensor id, axis id); passed together with
            `batch_size` to `model.get_axis_sizes` to obtain the block shape.
        batch_size: batch size used to obtain the block shape.

    Returns:
        The total number of blocks and a lazy generator of `IO_SampleBlockMeta`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup the flat {(tensor id, axis id): size} mapping per tensor in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (t, a), s in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(t, {})[a] = s

    # output halos are symmetric: `halo` on both sides of each axis with a halo
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
403
404
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Turn a tensor source into a `Tensor` for input description `ipt`.

    `Tensor`s are returned unchanged and xarray objects keep their own dims;
    numpy arrays and paths get the axes described by `ipt`.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    axes = get_axes_infos(ipt)
    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)
419
420
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample
        inputs: the input(s) constituting a single sample; a mapping from member id
            to tensor source, or a bare tensor source for single-input models.

    Raises:
        TypeError: if `inputs` is not a mapping but the model has several inputs.
        ValueError: if `inputs` has unknown keys or lacks a non-optional input.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        (only_member_id,) = model_inputs
        given = {only_member_id: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    unknown = {k for k in given if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        k
        for k, descr in model_inputs.items()
        if k not in given
        and not (isinstance(descr, v0_5.InputTensorDescr) and descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        m: get_tensor(given[m], descr)
        for m, descr in model_inputs.items()
        if m in given
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
465
466
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: model description whose inputs define valid member ids.
        paths: file to load per input member.
        axes: optional axes per member; members without a hint use the axes of
            the model's input description.
        stat: statistics to attach to the sample (defaults to an empty dict).
        sample_id: sample id; if falsy, the sorted tuple of paths is used.

    Raises:
        ValueError: if `paths` or `axes` reference members the model has no input for.
    """
    # make sure members are keyed by MemberId, not string
    member_paths = {MemberId(k): v for k, v in paths.items()}
    axes_hints = {MemberId(k): v for k, v in ({} if axes is None else axes).items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    unknown_paths = {k for k in member_paths if k not in model_inputs}
    if unknown_paths:
        raise ValueError(f"Got unexpected paths for {unknown_paths}")

    unknown_axes = {k for k in axes_hints if k not in model_inputs}
    if unknown_axes:
        raise ValueError(f"Got unexpected axes hints for: {unknown_axes}")

    members: Dict[MemberId, Tensor] = {}
    for member_id, path in member_paths.items():
        if member_id in axes_hints:
            member_axes = axes_hints[member_id]
        else:
            member_axes = get_axes_infos(model_inputs[member_id])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                member_id,
                path,
                member_axes,
            )
        members[member_id] = load_tensor(path, member_axes)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(member_paths.values())),
    )
TensorSource = typing.Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[tuple[int, ...], numpy.dtype[typing.Any]], pathlib.Path]
def import_callable( node: Union[bioimageio.spec.model.v0_5.ArchitectureFromFileDescr, bioimageio.spec.model.v0_5.ArchitectureFromLibraryDescr, bioimageio.spec.model.v0_4.CallableFromDepencency, bioimageio.spec.model.v0_4.CallableFromFile], /, **kwargs: Unpack[bioimageio.spec._internal.io.HashKwargs]) -> Callable[..., Any]:
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import a callable (e.g. a `torch.nn.Module`) described by a spec node.

    Library/dependency nodes are resolved with `importlib.import_module`;
    file-based nodes are loaded from their source via `_import_from_file_impl`.

    Args:
        node: spec node describing where the callable lives.
        **kwargs: hash arguments forwarded when importing a `CallableFromFile`
            (not used for `ArchitectureFromFileDescr`, which passes its own
            `sha256`).

    Returns:
        The imported callable.

    Raises:
        ValueError: if the imported object is not callable.
    """
    imported: Any
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")

import a callable (e.g. a torch.nn.Module) from a spec node describing it

tmp_dirs_in_use: List[tempfile.TemporaryDirectory[str]] = []

keep global reference to temporary directories created during import to delay cleanup

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Convert the axes of a tensor description into a list of `AxisInfo`.

    v0.5 axes (`v0_5.AxisBase` instances) are converted from their `id` and
    `type`; v0.4 axes must be one of the letters "b", "i", "t", "c", "z", "y", "x".
    """

    def _to_info(axis: Any) -> AxisInfo:
        if isinstance(axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=axis.id, type=axis.type))

        assert axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(axis)

    return [_to_info(axis) for axis in io_descr.axes]

get a unified, simplified axis representation from spec axes

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized id of a tensor description, usable as a sample member id.

    v0.4 descriptions are identified by their `name`, v0.5 descriptions by their `id`.
    """
    v0_4_descr_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    v0_5_descr_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)

    if isinstance(tensor_description, v0_4_descr_types):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, v0_5_descr_types):
        return tensor_description.id

    assert_never(tensor_description)

get the normalized tensor ID, usable as a sample member ID

def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Map each tensor description to its normalized sample member id (in order)."""
    return list(map(get_member_id, tensor_descriptions))

get normalized tensor IDs to be used as sample member IDs

def get_test_input_sample( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test input sample.

    For v0.4 models the test data comes from `model.test_inputs`; otherwise the
    input descriptions themselves are passed on so their test tensors are loaded.
    """
    if isinstance(model, v0_4.ModelDescr):
        sources = model.test_inputs
    else:
        sources = model.inputs

    return _get_test_sample(model.inputs, sources)
def get_test_inputs( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test input sample.

    For v0.4 models the test data comes from `model.test_inputs`; otherwise the
    input descriptions themselves are passed on so their test tensors are loaded.
    """
    if isinstance(model, v0_4.ModelDescr):
        sources = model.test_inputs
    else:
        sources = model.inputs

    return _get_test_sample(model.inputs, sources)

DEPRECATED: use get_test_input_sample instead

def get_test_output_sample( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test output sample.

    For v0.4 models the test data comes from `model.test_outputs`; otherwise the
    output descriptions themselves are passed on so their test tensors are loaded.
    """
    if isinstance(model, v0_4.ModelDescr):
        sources = model.test_outputs
    else:
        sources = model.outputs

    return _get_test_sample(model.outputs, sources)

returns a model's test output sample

def get_test_outputs( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test output sample.

    For v0.4 models the test data comes from `model.test_outputs`; otherwise the
    output descriptions themselves are passed on so their test tensors are loaded.
    """
    if isinstance(model, v0_4.ModelDescr):
        sources = model.test_outputs
    else:
        sources = model.outputs

    return _get_test_sample(model.outputs, sources)

DEPRECATED: use get_test_output_sample instead

class IO_SampleBlockMeta(typing.NamedTuple):
class IO_SampleBlockMeta(NamedTuple):
    """Pair of block metadata: an input sample block and its matching output block."""

    input: SampleBlockMeta  # metadata of the block taken from the input sample
    output: SampleBlockMeta  # metadata of the corresponding output block

IO_SampleBlockMeta(input, output)

IO_SampleBlockMeta( input: ForwardRef('SampleBlockMeta'), output: ForwardRef('SampleBlockMeta'))

Create new instance of IO_SampleBlockMeta(input, output)

Alias for field number 0

Alias for field number 1

def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo with which input tensors need to be divided into blocks, such
    that `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis with a halo must have its size given by a
    `v0_5.SizeReference`. The halo is rescaled by
    `output_axis.scale / referenced_axis.scale` and assigned to the referenced
    tensor and axis.

    Args:
        model: v0.5 model description.
        output_halo: halo per output tensor id and output axis id.

    Returns:
        Halo per referenced tensor id and referenced axis id.

    Raises:
        ValueError: if an output axis with a halo does not reference another
            axis' size.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        out_axes = {ax.id: ax for ax in outputs[t].axes}

        for a, ah in th.items():
            axis = out_axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ax.id: ax for ax in all_tensors[s.tensor_id].axes}[s.axis_id]

            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            # key by the *referenced* axis: the halo applies to that tensor's axis,
            # whose id may differ from the output axis id `a`
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo

Returns the halo with which input tensors need to be divided into blocks, such that output_halo can be cropped from their outputs without introducing gaps.

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how a model's output tensor shapes relate to its input shapes.

    The entry for each output tensor and axis is one of:
    - a `LinearSampleAxisTransform` for batch axes (mapped 1:1 onto the first
      batch axis found among the inputs) and for `v0_5.SizeReference` sizes
      (scaled by the ratio of the referenced axis' scale to this axis' scale);
    - an `int`: the fixed size, or `-1` for a `v0_5.DataDependentSize`.

    Raises:
        ValueError: if an output tensor has a batch axis but no input does.
    """
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=ax.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for ax in ipt.axes
            if ax.type == "batch"
        ),
        None,
    )
    axis_scales = {
        tensor.id: {ax.id: ax.scale for ax in tensor.axes}
        for tensor in chain(model.inputs, model.outputs)
    }

    def output_axis_trf(out_id: Any, ax: Any) -> Union[LinearSampleAxisTransform, int]:
        size = ax.size
        if size is None:
            assert ax.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out_id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / ax.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    return {
        out.id: {ax.id: output_axis_trf(out.id, ax) for ax in out.axes}
        for out in model.outputs
    }

Returns how a model's output tensor shapes relate to its input shapes.

def get_io_sample_block_metas( model: bioimageio.spec.ModelDescr, input_sample_shape: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]], ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int = 1) -> Tuple[int, Iterable[IO_SampleBlockMeta]]:
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of input blocks and an iterable yielding metadata for
    corresponding input and output sample blocks.

    Args:
        model: v0.5 model description.
        input_sample_shape: size per axis of each input tensor of the sample.
        ns: parameterized-size `n` per (tensor id, axis id); passed together with
            `batch_size` to `model.get_axis_sizes` to obtain the block shape.
        batch_size: batch size used to obtain the block shape.

    Returns:
        The total number of blocks and a lazy generator of `IO_SampleBlockMeta`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup the flat {(tensor id, axis id): size} mapping per tensor in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (t, a), s in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(t, {})[a] = s

    # output halos are symmetric: `halo` on both sides of each axis with a halo
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

returns an iterable yielding meta data for corresponding input and output samples

def get_tensor( src: Union[zipp.Path, bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[tuple[int, ...], numpy.dtype[Any]], pathlib.Path], ipt: Union[bioimageio.spec.model.v0_4.InputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr]):
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Turn a tensor source into a `Tensor` for input description `ipt`.

    `Tensor`s are returned unchanged and xarray objects keep their own dims;
    numpy arrays and paths get the axes described by `ipt`.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    axes = get_axes_infos(ipt)
    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)

helper to cast/load various tensor sources

def create_sample_for_model( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], *, stat: Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Hashable = None, inputs: Union[Mapping[bioimageio.spec.model.v0_5.TensorId, Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[tuple[int, ...], numpy.dtype[Any]], pathlib.Path]], bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[tuple[int, ...], numpy.dtype[Any]], pathlib.Path]) -> bioimageio.core.Sample:
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
            It is passed to the sample as is (not copied); `None` yields an empty dict.
        sample_id: identifier attached to the created sample.
        inputs: the input(s) constituting a single sample. A mapping keyed by
            member id, or a single tensor source if the model has exactly one input.

    Raises:
        TypeError: `inputs` is not a mapping and the model does not have
            exactly one input.
        ValueError: `inputs` contains ids unknown to the model, or a
            non-optional model input is not provided.
    """
    descr_by_id = {get_member_id(descr): descr for descr in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(key): value for key, value in inputs.items()}
    elif len(descr_by_id) == 1:
        # a bare tensor source is unambiguous only for single-input models
        (only_id,) = descr_by_id
        given = {only_id: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(descr_by_id)}"
        )

    unknown = {key for key in given if key not in descr_by_id}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    # only v0_5 input descriptions can be marked optional
    missing = {
        mid
        for mid, descr in descr_by_id.items()
        if mid not in given
        and (not isinstance(descr, v0_5.InputTensorDescr) or not descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        mid: get_tensor(given[mid], descr)
        for mid, descr in descr_by_id.items()
        if mid in given
    }
    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )

Create a sample from a single set of input(s) for a specific bioimage.io model

Arguments:
  • model: a bioimage.io model description
  • stat: dictionary with sample and dataset statistics (may be updated in-place!)
  • inputs: the input(s) constituting a single sample.
def load_sample_for_model( *, model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], paths: Mapping[bioimageio.spec.model.v0_5.TensorId, pathlib.Path], axes: Optional[Mapping[bioimageio.spec.model.v0_5.TensorId, Sequence[Union[str, bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], bioimageio.core.axis.AxisDescrLike, Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], 
bioimageio.core.Axis]]]] = None, stat: Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Optional[Hashable] = None):
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description; its inputs define the valid
            member ids and the default axes used for loading.
        paths: file path per input member (keys are converted to `MemberId`).
        axes: optional axes hints per input member; members without a hint
            are loaded with the axes declared by the corresponding model input.
        stat: dictionary with sample and dataset statistics, passed to the
            sample as is; `None` yields an empty dict.
        sample_id: identifier of the sample. If `None`, the sorted tuple of
            all paths is used instead. Any other value (including falsy ones
            such as `0` or `""`) is kept.

    Returns:
        A `Sample` with one member per entry in `paths`.

    Raises:
        ValueError: `paths` or `axes` contain ids that are not model inputs.
    """
    # make sure members are keyed by MemberId, not string
    # (this also copies the mappings, so the caller's dicts are never mutated)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {} if axes is None else {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            # no explicit hint: fall back to the axes described by the model input
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    if sample_id is None:
        # Derive an id only when none was given. `sample_id or ...` would
        # wrongly discard valid falsy ids such as `0` or "".
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

Load a single sample from `paths` that can be processed by `model`.