bioimageio.core.digest_spec

  1from __future__ import annotations
  2
  3import collections.abc
  4import importlib.util
  5import sys
  6from itertools import chain
  7from pathlib import Path
  8from tempfile import TemporaryDirectory
  9from typing import (
 10    Any,
 11    Callable,
 12    Dict,
 13    Iterable,
 14    List,
 15    Mapping,
 16    NamedTuple,
 17    Optional,
 18    Sequence,
 19    Tuple,
 20    Union,
 21)
 22from zipfile import ZipFile, is_zipfile
 23
 24import numpy as np
 25import xarray as xr
 26from loguru import logger
 27from numpy.typing import NDArray
 28from typing_extensions import Unpack, assert_never
 29
 30from bioimageio.spec._internal.io import HashKwargs
 31from bioimageio.spec.common import FileDescr, FileSource, ZipPath
 32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
 34from bioimageio.spec.model.v0_5 import (
 35    ArchitectureFromFileDescr,
 36    ArchitectureFromLibraryDescr,
 37    ParameterizedSize_N,
 38)
 39from bioimageio.spec.utils import load_array
 40
 41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 42from .block_meta import split_multiple_shapes_into_blocks
 43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
 44from .io import load_tensor
 45from .sample import (
 46    LinearSampleAxisTransform,
 47    Sample,
 48    SampleBlockMeta,
 49    sample_block_meta_generator,
 50)
 51from .stat_measures import Stat
 52from .tensor import Tensor
 53
 54TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
 55
 56
 57def import_callable(
 58    node: Union[
 59        ArchitectureFromFileDescr,
 60        ArchitectureFromLibraryDescr,
 61        CallableFromDepencency,
 62        CallableFromFile,
 63    ],
 64    /,
 65    **kwargs: Unpack[HashKwargs],
 66) -> Callable[..., Any]:
 67    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
 68    if isinstance(node, CallableFromDepencency):
 69        module = importlib.import_module(node.module_name)
 70        c = getattr(module, str(node.callable_name))
 71    elif isinstance(node, ArchitectureFromLibraryDescr):
 72        module = importlib.import_module(node.import_from)
 73        c = getattr(module, str(node.callable))
 74    elif isinstance(node, CallableFromFile):
 75        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
 76    elif isinstance(node, ArchitectureFromFileDescr):
 77        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
 78    else:
 79        assert_never(node)
 80
 81    if not callable(c):
 82        raise ValueError(f"{node} (imported: {c}) is not callable")
 83
 84    return c
 85
 86
 87tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
 88"""keep global reference to temporary directories created during import to delay cleanup"""
 89
 90
 91def _import_from_file_impl(
 92    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
 93):
 94    src_descr = FileDescr(source=source, **kwargs)
 95    # ensure sha is valid even if perform_io_checks=False
 96    # or the source has changed since last sha computation
 97    src_descr.validate_sha256(force_recompute=True)
 98    assert src_descr.sha256 is not None
 99    source_sha = src_descr.sha256
100
101    reader = src_descr.get_reader()
102    # make sure we have unique module name
103    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"
104
105    # make sure we have a unique and valid module name
106    if not module_name.isidentifier():
107        module_name = f"custom_module_{source_sha}"
108        assert module_name.isidentifier(), module_name
109
110    source_bytes = reader.read()
111
112    module = sys.modules.get(module_name)
113    if module is None:
114        try:
115            td_kwargs: Dict[str, Any] = (
116                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
117            )
118            if sys.version_info >= (3, 12):
119                td_kwargs["delete"] = False
120
121            tmp_dir = TemporaryDirectory(**td_kwargs)
122            # keep global ref to tmp_dir to delay cleanup until program exit
123            # TODO: remove for py >= 3.12, when delete=False works
124            tmp_dirs_in_use.append(tmp_dir)
125
126            module_path = Path(tmp_dir.name) / module_name
127            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
128                module_path.mkdir()
129                ZipFile(reader).extractall(path=module_path)
130            else:
131                module_path = module_path.with_suffix(".py")
132                _ = module_path.write_bytes(source_bytes)
133
134            importlib_spec = importlib.util.spec_from_file_location(
135                module_name, str(module_path)
136            )
137
138            if importlib_spec is None:
139                raise ImportError(f"Failed to import {source}")
140
141            module = importlib.util.module_from_spec(importlib_spec)
142
143            sys.modules[module_name] = module  # cache this module
144
145            assert importlib_spec.loader is not None
146            importlib_spec.loader.exec_module(module)
147
148        except Exception as e:
149            del sys.modules[module_name]
150            raise ImportError(f"Failed to import {source}") from e
151
152    try:
153        callable_attr = getattr(module, callable_name)
154    except AttributeError as e:
155        raise AttributeError(
156            f"Imported custom module from {source} has no `{callable_name}` attribute."
157        ) from e
158    except Exception as e:
159        raise AttributeError(
160            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
161        ) from e
162
163    else:
164        return callable_attr
165
166
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """convert the axes of a tensor description to a list of `AxisInfo` objects

    Axes deriving from `v0_5.AxisBase` are wrapped in an `Axis` (id and type);
    any other axis is expected to be one of the letters "bitczyx".
    """
    infos: List[AxisInfo] = []
    for axis in io_descr.axes:
        if not isinstance(axis, v0_5.AxisBase):
            # single letter axis
            assert axis in ("b", "i", "t", "c", "z", "y", "x")
            axis_like = axis
        else:
            axis_like = Axis(id=axis.id, type=axis.type)

        infos.append(AxisInfo.create(axis_like))

    return infos
185
186
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the sample member ID for a tensor description

    For v0.4 tensor descriptions the tensor `name` is converted to a `MemberId`;
    for v0.5 tensor descriptions the tensor `id` is returned as is.
    """
    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
        return tensor_description.id

    assert_never(tensor_description)
205
206
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get the sample member IDs for a sequence of tensor descriptions

    The returned list has one entry per description, in the given order.
    """
    member_ids: List[MemberId] = []
    for descr in tensor_descriptions:
        member_ids.append(get_member_id(descr))

    return member_ids
219
220
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0.4 models reference their test inputs via `test_inputs`
        test_sources = model.test_inputs
    else:
        # otherwise the input descriptions themselves are passed on
        # (their `test_tensor` is read when the sample is created)
        test_sources = model.inputs

    return _get_test_sample(model.inputs, test_sources)


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""
230
231
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0.4 models reference their test outputs via `test_outputs`
        test_sources = model.test_outputs
    else:
        # otherwise the output descriptions themselves are passed on
        # (their `test_tensor` is read when the sample is created)
        test_sources = model.outputs

    return _get_test_sample(model.outputs, test_sources)


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""
242
243
def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """create the test sample (id "test-sample") for a model's inputs or outputs

    Args:
        tensor_descrs: tensor descriptions providing member IDs and axes.
        test_sources: one source per tensor description; either a file source
            or a v0.5 tensor description whose `test_tensor` is loaded.

    Raises:
        ValueError: if a v0.5 tensor description has no test tensor.
    """
    member_ids = get_member_ids(tensor_descrs)

    arrays: List[NDArray[Any]] = []
    for src in test_sources:
        if not isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            array_source = src
        elif src.test_tensor is None:
            raise ValueError(
                f"Model input '{src.id}' has no test tensor defined, cannot create test sample."
            )
        else:
            array_source = src.test_tensor

        arrays.append(load_array(array_source))

    axes_per_tensor = [get_axes_infos(descr) for descr in tensor_descrs]

    members = {}
    for member_id, array, axes in zip(member_ids, arrays, axes_per_tensor):
        members[member_id] = Tensor.from_numpy(array, dims=axes)

    return Sample(members=members, stat={}, id="test-sample")
277
278
class IO_SampleBlockMeta(NamedTuple):
    """block meta data of an input sample paired with the corresponding output block meta data"""

    # meta data of a block of the input sample
    input: SampleBlockMeta
    # meta data of the corresponding block of the output sample
    output: SampleBlockMeta
282
283
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns the halo with which input tensors need to be divided into blocks,
    such that `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: model description providing the input and output tensor axes.
        output_halo: halo per output tensor and output axis.

    Returns:
        A symmetric halo per referenced tensor and referenced axis.

    Raises:
        ValueError: if the size of an output axis with halo is not a `SizeReference`.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # rescale the halo from the output axis to the referenced axis
            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            ), f"expected an even, integer total halo, but got {total_input_halo}"

            # key the halo by the id of the *referenced* axis; the output axis id
            # `a` is only valid for the output tensor and may differ from it
            half_input_halo = int(total_input_halo // 2)
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                half_input_halo, half_input_halo
            )

    return input_halo
314
315
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relate to its input shapes

    For every output axis the result holds either a `LinearSampleAxisTransform`
    (batch axes and axes whose size references another axis), the fixed integer
    size, or -1 for a data dependent size.

    Raises:
        ValueError: if an output has a batch axis, but no input has one.
    """
    # the first batch axis found among the inputs serves all output batch axes
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(
                axis=axis.id, scale=1, offset=0, member=tensor.id
            )
            for tensor in model.inputs
            for axis in tensor.axes
            if axis.type == "batch"
        ),
        None,
    )

    axis_scales = {}
    for tensor in chain(model.inputs, model.outputs):
        axis_scales[tensor.id] = {axis.id: axis.scale for axis in tensor.axes}

    block_transform: Dict[
        MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]
    ] = {}
    for out in model.outputs:
        out_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                trf = batch_axis_trf
            elif isinstance(a.size, int):
                trf = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                trf = -1
            elif isinstance(a.size, v0_5.SizeReference):
                ref_scale = axis_scales[a.size.tensor_id][a.size.axis_id]
                trf = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=ref_scale / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            out_axes[a.id] = trf

        block_transform[out.id] = out_axes

    return block_transform
365
366
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns the number of blocks and an iterable of block meta data pairs
    for corresponding input and output sample blocks

    Args:
        model: a v0.5 model description.
        input_sample_shape: shape of the input sample per member and axis.
        ns: passed on to `model.get_axis_sizes` to determine the block axis sizes.
        batch_size: passed on to `model.get_axis_sizes`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat (tensor, axis) -> size mapping per input tensor
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (tensor_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(tensor_id, {})[axis_id] = size

    # symmetric halo for every output axis that has one
    output_halo = {}
    for tensor in model.outputs:
        output_halo[tensor.id] = {
            axis.id: Halo(axis.halo, axis.halo)
            for axis in tensor.axes
            if isinstance(axis, v0_5.WithHalo)
        }

    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    input_block_metas = sample_block_meta_generator(
        input_blocks, sample_shape=input_sample_shape, sample_id=None
    )
    io_block_metas = (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in input_block_metas
    )
    return n_input_blocks, io_block_metas
400
401
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to turn various tensor sources into a `Tensor`

    A `Tensor` is returned unchanged and an `xr.DataArray` is converted directly;
    numpy arrays and any other source (loaded with `load_tensor`) are
    interpreted using the axes of the input description `ipt`.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    axes = get_axes_infos(ipt)
    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)
417
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: the id assigned to the created sample
        inputs: the input(s) constituting a single sample.
            A single tensor source is only accepted if the model has exactly one input.

    Raises:
        TypeError: if `inputs` is not a mapping although the model does not
            have exactly one input.
        ValueError: if `inputs` has unknown members or lacks non-optional ones.
    """
    model_inputs = {}
    for descr in model.inputs:
        model_inputs[get_member_id(descr)] = descr

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) != 1:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )
    else:
        # a bare tensor source belongs to the model's only input
        (only_member,) = model_inputs
        given = {only_member: inputs}

    unknown = {k for k in given if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for member_id, descr in model_inputs.items():
        if member_id in given:
            continue
        if isinstance(descr, v0_5.InputTensorDescr) and descr.optional:
            continue
        missing.add(member_id)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for member_id, descr in model_inputs.items():
        if member_id in given:
            members[member_id] = get_tensor(given[member_id], descr)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
462
463
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: the path to load each sample member from
        axes: optional axes per member; members without an entry are loaded
            with the axes of the corresponding model input.
        stat: sample and dataset statistics to attach to the sample
        sample_id: the sample id; if None, the sorted tuple of paths is used.

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to `model`.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (this also creates new dicts, so the arguments are not modified below)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # only fall back to the paths if no id was given at all;
    # falsy ids like `0` or `""` are valid and must be kept
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
TensorSource = typing.Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], pathlib.Path]
def import_callable( node: Union[bioimageio.spec.model.v0_5.ArchitectureFromFileDescr, bioimageio.spec.model.v0_5.ArchitectureFromLibraryDescr, bioimageio.spec.model.v0_4.CallableFromDepencency, bioimageio.spec.model.v0_4.CallableFromFile], /, **kwargs: Unpack[bioimageio.spec._internal.io.HashKwargs]) -> Callable[..., Any]:
58def import_callable(
59    node: Union[
60        ArchitectureFromFileDescr,
61        ArchitectureFromLibraryDescr,
62        CallableFromDepencency,
63        CallableFromFile,
64    ],
65    /,
66    **kwargs: Unpack[HashKwargs],
67) -> Callable[..., Any]:
68    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
69    if isinstance(node, CallableFromDepencency):
70        module = importlib.import_module(node.module_name)
71        c = getattr(module, str(node.callable_name))
72    elif isinstance(node, ArchitectureFromLibraryDescr):
73        module = importlib.import_module(node.import_from)
74        c = getattr(module, str(node.callable))
75    elif isinstance(node, CallableFromFile):
76        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
77    elif isinstance(node, ArchitectureFromFileDescr):
78        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
79    else:
80        assert_never(node)
81
82    if not callable(c):
83        raise ValueError(f"{node} (imported: {c}) is not callable")
84
85    return c

import a callable (e.g. a torch.nn.Module) from a spec node describing it

tmp_dirs_in_use: List[tempfile.TemporaryDirectory[str]] = []

keep global reference to temporary directories created during import to delay cleanup

168def get_axes_infos(
169    io_descr: Union[
170        v0_4.InputTensorDescr,
171        v0_4.OutputTensorDescr,
172        v0_5.InputTensorDescr,
173        v0_5.OutputTensorDescr,
174    ],
175) -> List[AxisInfo]:
176    """get a unified, simplified axis representation from spec axes"""
177    ret: List[AxisInfo] = []
178    for a in io_descr.axes:
179        if isinstance(a, v0_5.AxisBase):
180            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
181        else:
182            assert a in ("b", "i", "t", "c", "z", "y", "x")
183            ret.append(AxisInfo.create(a))
184
185    return ret

get a unified, simplified axis representation from spec axes

188def get_member_id(
189    tensor_description: Union[
190        v0_4.InputTensorDescr,
191        v0_4.OutputTensorDescr,
192        v0_5.InputTensorDescr,
193        v0_5.OutputTensorDescr,
194    ],
195) -> MemberId:
196    """get the normalized tensor ID, usable as a sample member ID"""
197
198    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
199        return MemberId(tensor_description.name)
200    elif isinstance(
201        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
202    ):
203        return tensor_description.id
204    else:
205        assert_never(tensor_description)

get the normalized tensor ID, usable as a sample member ID

208def get_member_ids(
209    tensor_descriptions: Sequence[
210        Union[
211            v0_4.InputTensorDescr,
212            v0_4.OutputTensorDescr,
213            v0_5.InputTensorDescr,
214            v0_5.OutputTensorDescr,
215        ]
216    ],
217) -> List[MemberId]:
218    """get normalized tensor IDs to be used as sample member IDs"""
219    return [get_member_id(descr) for descr in tensor_descriptions]

get normalized tensor IDs to be used as sample member IDs

def get_test_input_sample( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
222def get_test_input_sample(model: AnyModelDescr) -> Sample:
223    return _get_test_sample(
224        model.inputs,
225        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
226    )
def get_test_inputs( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
222def get_test_input_sample(model: AnyModelDescr) -> Sample:
223    return _get_test_sample(
224        model.inputs,
225        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
226    )

DEPRECATED: use get_test_input_sample instead

def get_test_output_sample( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
233def get_test_output_sample(model: AnyModelDescr) -> Sample:
234    """returns a model's test output sample"""
235    return _get_test_sample(
236        model.outputs,
237        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
238    )

returns a model's test output sample

def get_test_outputs( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
233def get_test_output_sample(model: AnyModelDescr) -> Sample:
234    """returns a model's test output sample"""
235    return _get_test_sample(
236        model.outputs,
237        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
238    )

DEPRECATED: use get_test_output_sample instead

class IO_SampleBlockMeta(typing.NamedTuple):
280class IO_SampleBlockMeta(NamedTuple):
281    input: SampleBlockMeta
282    output: SampleBlockMeta

IO_SampleBlockMeta(input, output)

IO_SampleBlockMeta( input: ForwardRef('SampleBlockMeta'), output: ForwardRef('SampleBlockMeta'))

Create new instance of IO_SampleBlockMeta(input, output)

Alias for field number 0

Alias for field number 1

285def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
286    """returns which halo input tensors need to be divided into blocks with, such that
287    `output_halo` can be cropped from their outputs without introducing gaps."""
288    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
289    outputs = {t.id: t for t in model.outputs}
290    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
291
292    for t, th in output_halo.items():
293        axes = {a.id: a for a in outputs[t].axes}
294
295        for a, ah in th.items():
296            s = axes[a].size
297            if not isinstance(s, v0_5.SizeReference):
298                raise ValueError(
299                    f"Unable to map output halo for {t}.{a} to an input axis"
300                )
301
302            axis = axes[a]
303            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]
304
305            total_output_halo = sum(ah)
306            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
307            assert (
308                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
309            )
310            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
311                int(total_input_halo // 2), int(total_input_halo // 2)
312            )
313
314    return input_halo

returns the halo with which input tensors need to be divided into blocks, such that output_halo can be cropped from their outputs without introducing gaps.

317def get_block_transform(
318    model: v0_5.ModelDescr,
319) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
320    """returns how a model's output tensor shapes relates to its input shapes"""
321    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
322    batch_axis_trf = None
323    for ipt in model.inputs:
324        for a in ipt.axes:
325            if a.type == "batch":
326                batch_axis_trf = LinearSampleAxisTransform(
327                    axis=a.id, scale=1, offset=0, member=ipt.id
328                )
329                break
330        if batch_axis_trf is not None:
331            break
332    axis_scales = {
333        t.id: {a.id: a.scale for a in t.axes}
334        for t in chain(model.inputs, model.outputs)
335    }
336    for out in model.outputs:
337        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
338        for a in out.axes:
339            if a.size is None:
340                assert a.type == "batch"
341                if batch_axis_trf is None:
342                    raise ValueError(
343                        "no batch axis found in any input tensor, but output tensor"
344                        + f" '{out.id}' has one."
345                    )
346                s = batch_axis_trf
347            elif isinstance(a.size, int):
348                s = a.size
349            elif isinstance(a.size, v0_5.DataDependentSize):
350                s = -1
351            elif isinstance(a.size, v0_5.SizeReference):
352                s = LinearSampleAxisTransform(
353                    axis=a.size.axis_id,
354                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
355                    offset=a.size.offset,
356                    member=a.size.tensor_id,
357                )
358            else:
359                assert_never(a.size)
360
361            new_axes[a.id] = s
362
363        ret[out.id] = new_axes
364
365    return ret

returns how a model's output tensor shapes relate to its input shapes

def get_io_sample_block_metas( model: bioimageio.spec.ModelDescr, input_sample_shape: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]], ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int = 1) -> Tuple[int, Iterable[IO_SampleBlockMeta]]:
368def get_io_sample_block_metas(
369    model: v0_5.ModelDescr,
370    input_sample_shape: PerMember[PerAxis[int]],
371    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
372    batch_size: int = 1,
373) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
374    """returns an iterable yielding meta data for corresponding input and output samples"""
375    if not isinstance(model, v0_5.ModelDescr):
376        raise TypeError(f"get_block_meta() not implemented for {type(model)}")
377
378    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
379    input_block_shape = {
380        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
381        for t in {tt for tt, _ in block_axis_sizes.inputs}
382    }
383    output_halo = {
384        t.id: {
385            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
386        }
387        for t in model.outputs
388    }
389    input_halo = get_input_halo(model, output_halo)
390
391    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
392        input_sample_shape, input_block_shape, halo=input_halo
393    )
394    block_transform = get_block_transform(model)
395    return n_input_blocks, (
396        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
397        for ipt in sample_block_meta_generator(
398            input_blocks, sample_shape=input_sample_shape, sample_id=None
399        )
400    )

returns an iterable yielding meta data for corresponding input and output samples

def get_tensor( src: Union[zipp.Path, bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path], ipt: Union[bioimageio.spec.model.v0_4.InputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr]):
403def get_tensor(
404    src: Union[ZipPath, TensorSource],
405    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
406):
407    """helper to cast/load various tensor sources"""
408
409    if isinstance(src, Tensor):
410        return src
411    elif isinstance(src, xr.DataArray):
412        return Tensor.from_xarray(src)
413    elif isinstance(src, np.ndarray):
414        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
415    else:
416        return load_tensor(src, axes=get_axes_infos(ipt))

helper to cast/load various tensor sources

def create_sample_for_model( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], *, stat: Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Hashable = None, inputs: Union[Mapping[bioimageio.spec.model.v0_5.TensorId, Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path]], bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path]) -> bioimageio.core.Sample:
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Build a `Sample` from one set of input(s) for a given bioimage.io model.

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics
            (may be updated in-place!); an empty dict is used if omitted
        sample_id: passed on as the `id` of the created sample
        inputs: the input(s) constituting a single sample; either a mapping
            keyed by member id or, for models with exactly one input, a
            single tensor source.

    Raises:
        TypeError: if `inputs` is not a mapping and the model does not have
            exactly one input.
        ValueError: if `inputs` contains keys that are not model inputs, or if
            an input is missing that is not an optional
            `v0_5.InputTensorDescr`.
    """
    descr_by_member = {get_member_id(descr): descr for descr in model.inputs}

    # Normalize `inputs` to a dict keyed by MemberId.
    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(key): value for key, value in inputs.items()}
    elif len(descr_by_member) == 1:
        # A bare tensor source is only unambiguous for single-input models.
        given = {next(iter(descr_by_member)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(descr_by_member)}"
        )

    unknown = {member for member in given if member not in descr_by_member}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for member, descr in descr_by_member.items():
        if member in given:
            continue
        # Only v0_5 input descriptions flagged as optional may be left out.
        if isinstance(descr, v0_5.InputTensorDescr) and descr.optional:
            continue
        missing.add(member)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    # Members follow the order of the model's inputs, skipping omitted ones.
    members = {
        member: get_tensor(given[member], descr)
        for member, descr in descr_by_member.items()
        if member in given
    }
    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )

Create a sample from a single set of input(s) for a specific bioimage.io model

Arguments:
  • model: a bioimage.io model description
  • stat: dictionary with sample and dataset statistics (may be updated in-place!)
  • inputs: the input(s) constituting a single sample.
def load_sample_for_model( *, model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], paths: Mapping[bioimageio.spec.model.v0_5.TensorId, pathlib.Path], axes: Optional[Mapping[bioimageio.spec.model.v0_5.TensorId, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]]]] = None, stat: 
Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Optional[Hashable] = None):
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: a bioimage.io model description
        paths: file path per model input member to load the tensor from
        axes: optional axes per member used when loading that member's
            tensor; members without an entry fall back to the axes derived
            from the model's input description via `get_axes_infos`.
        stat: dictionary with sample and dataset statistics; an empty dict is
            used if omitted
        sample_id: id of the created sample. If None, the sorted tuple of the
            given paths is used as id.

    Raises:
        ValueError: if `paths` or `axes` contain members that are not inputs
            of `model`.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (this also copies `axes`, so the caller's mapping is never modified below)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # Compare against None explicitly: falsy ids such as 0 or "" are valid
    # sample ids and must not be replaced by the path-derived default.
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

Load a single sample from `paths` that can be processed by `model`.