bioimageio.core.digest_spec

  1from __future__ import annotations
  2
  3import importlib.util
  4from itertools import chain
  5from pathlib import Path
  6from typing import (
  7    Any,
  8    Callable,
  9    Dict,
 10    Iterable,
 11    List,
 12    Mapping,
 13    NamedTuple,
 14    Optional,
 15    Sequence,
 16    Tuple,
 17    Union,
 18)
 19
 20import numpy as np
 21import xarray as xr
 22from loguru import logger
 23from numpy.typing import NDArray
 24from typing_extensions import Unpack, assert_never
 25
 26from bioimageio.spec._internal.io import resolve_and_extract
 27from bioimageio.spec._internal.io_utils import HashKwargs
 28from bioimageio.spec.common import FileSource
 29from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 30from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
 31from bioimageio.spec.model.v0_5 import (
 32    ArchitectureFromFileDescr,
 33    ArchitectureFromLibraryDescr,
 34    ParameterizedSize_N,
 35)
 36from bioimageio.spec.utils import load_array
 37
 38from .axis import AxisId, AxisInfo, AxisLike, PerAxis
 39from .block_meta import split_multiple_shapes_into_blocks
 40from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
 41from .io import load_tensor
 42from .sample import (
 43    LinearSampleAxisTransform,
 44    Sample,
 45    SampleBlockMeta,
 46    sample_block_meta_generator,
 47)
 48from .stat_measures import Stat
 49from .tensor import Tensor
 50
 51
 52def import_callable(
 53    node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr],
 54    /,
 55    **kwargs: Unpack[HashKwargs],
 56) -> Callable[..., Any]:
 57    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
 58    if isinstance(node, CallableFromDepencency):
 59        module = importlib.import_module(node.module_name)
 60        c = getattr(module, str(node.callable_name))
 61    elif isinstance(node, ArchitectureFromLibraryDescr):
 62        module = importlib.import_module(node.import_from)
 63        c = getattr(module, str(node.callable))
 64    elif isinstance(node, CallableFromFile):
 65        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
 66    elif isinstance(node, ArchitectureFromFileDescr):
 67        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
 68
 69    else:
 70        assert_never(node)
 71
 72    if not callable(c):
 73        raise ValueError(f"{node} (imported: {c}) is not callable")
 74
 75    return c
 76
 77
 78def _import_from_file_impl(
 79    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
 80):
 81    local_file = resolve_and_extract(source, **kwargs)
 82    module_name = local_file.path.stem
 83    importlib_spec = importlib.util.spec_from_file_location(
 84        module_name, local_file.path
 85    )
 86    if importlib_spec is None:
 87        raise ImportError(f"Failed to import {module_name} from {source}.")
 88
 89    dep = importlib.util.module_from_spec(importlib_spec)
 90    importlib_spec.loader.exec_module(dep)  # type: ignore  # todo: possible to use "loader.load_module"?
 91    return getattr(dep, callable_name)
 92
 93
 94def get_axes_infos(
 95    io_descr: Union[
 96        v0_4.InputTensorDescr,
 97        v0_4.OutputTensorDescr,
 98        v0_5.InputTensorDescr,
 99        v0_5.OutputTensorDescr,
100    ],
101) -> List[AxisInfo]:
102    """get a unified, simplified axis representation from spec axes"""
103    return [
104        (
105            AxisInfo.create("i")
106            if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x")
107            else AxisInfo.create(a)
108        )
109        for a in io_descr.axes
110    ]
111
112
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """return the sample member ID for a tensor description

    v0_4 descriptions are identified by their `name` (wrapped in `MemberId`),
    v0_5 descriptions by their `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)
131
132
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """return the sample member IDs of the given tensor descriptions, in order"""
    return list(map(get_member_id, tensor_descriptions))
145
146
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """load the test input tensors of `model` into a sample with id "test-sample"

    v0_4 models list their test inputs on the model itself (`test_inputs`);
    otherwise each input description carries its own `test_tensor`.
    """
    member_ids = get_member_ids(model.inputs)
    test_sources = (
        model.test_inputs
        if isinstance(model, v0_4.ModelDescr)
        else [d.test_tensor for d in model.inputs]
    )
    arrays = list(map(load_array, test_sources))
    axes = list(map(get_axes_infos, model.inputs))

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(member_ids, arrays, axes):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
164
165
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """load the test output tensors of `model` into a sample with id "test-sample"

    v0_4 models list their test outputs on the model itself (`test_outputs`);
    otherwise each output description carries its own `test_tensor`.
    """
    member_ids = get_member_ids(model.outputs)
    test_sources = (
        model.test_outputs
        if isinstance(model, v0_4.ModelDescr)
        else [d.test_tensor for d in model.outputs]
    )
    arrays = list(map(load_array, test_sources))
    axes = list(map(get_axes_infos, model.outputs))

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(member_ids, arrays, axes):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
185
186
class IO_SampleBlockMeta(NamedTuple):
    """meta data of an input sample block paired with the matching output block"""

    # meta data describing a block of the input sample
    input: SampleBlockMeta
    # meta data describing the corresponding block of the output sample
    output: SampleBlockMeta
190
191
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns the halo with which input tensors need to be divided into blocks, such
    that `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: model description providing the input and output tensor axes
        output_halo: halo per output tensor and axis

    Returns:
        a symmetric halo per referenced tensor and axis

    Raises:
        ValueError: if the size of an output axis with halo is not a
            `SizeReference`, or if the halo, scaled to the referenced axis, is
            not an even integer (it could not be split into two equal sides).
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # rescale the summed halo from the output axis to the referenced axis
            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            # explicit check instead of `assert`, which is stripped under `-O`
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo for {t}.{a} corresponds to a total input halo of"
                    + f" {total_input_halo}, which is not an even integer"
                )

            half_input_halo = int(total_input_halo // 2)
            # NOTE(review): the entry is keyed by the output axis id `a`, not by
            # the referenced `s.axis_id` -- confirm this is intended for models
            # whose output axis ids differ from the referenced axis ids.
            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                half_input_halo, half_input_halo
            )

    return input_halo
222
223
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """describe, per output tensor axis, how its size follows from the inputs

    Each output axis maps to either a fixed integer size, `-1` for a
    data dependent size, or a `LinearSampleAxisTransform` referring to the
    axis of another tensor (the first input batch axis for batch axes).
    """
    # the first batch axis found among the inputs serves all output batch axes
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(
                axis=axis.id, scale=1, offset=0, member=ipt.id
            )
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )

    axis_scales = {}
    for tensor in chain(model.inputs, model.outputs):
        axis_scales[tensor.id] = {axis.id: axis.scale for axis in tensor.axes}

    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            size = a.size
            if size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                new_axes[a.id] = batch_axis_trf
            elif isinstance(size, int):
                new_axes[a.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                new_axes[a.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                ref_scale = axis_scales[size.tensor_id][size.axis_id]
                new_axes[a.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=ref_scale / a.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        ret[out.id] = new_axes

    return ret
273
274
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (v0_5 only)
        input_sample_shape: shape of the input sample per member and axis
        ns: passed on to `model.get_axis_sizes` together with `batch_size` to
            determine the block axis sizes
        batch_size: batch size used to determine the block axis sizes

    Returns:
        the number of input blocks and a lazy iterable of input/output block
        meta data pairs

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # group the flat {(tensor, axis): size} mapping by tensor in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (t, a), s in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(t, {})[a] = s

    # output axes with a halo get the same halo on both sides
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
308
309
def get_tensor(
    src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """turn `src` into a `Tensor`

    A `Tensor` is returned as is and an `xr.DataArray` is converted directly.
    numpy arrays and paths are interpreted using the axes of the input
    description `ipt`.
    """
    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        tensor = Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    elif isinstance(src, Path):
        tensor = load_tensor(src, axes=get_axes_infos(ipt))
    else:
        assert_never(src)

    return tensor
329
330
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Optional[
        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
    ] = None,  # TODO: make non-optional
    **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id given to the created sample
        inputs: the input(s) constituting a single sample.
        **kwargs: further inputs by name; entries in `inputs` take precedence.

    Raises:
        ValueError: for inputs unknown to the model or missing non-optional inputs
    """
    # keyword inputs first, so that explicitly given `inputs` override them
    merged: Dict[Any, Any] = dict(kwargs)
    merged.update(inputs or {})
    inputs = {MemberId(name): value for name, value in merged.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    unknown = {k for k in inputs if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for k, v in model_inputs.items():
        if k in inputs:
            continue
        if isinstance(v, v0_5.InputTensorDescr) and v.optional:
            continue
        missing.add(k)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for m, ipt in model_inputs.items():
        if m in inputs:
            members[m] = get_tensor(inputs[m], ipt)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
370
371
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load per model input
        axes: optional axes per model input; inputs without an entry are loaded
            with the axes of the corresponding model input description
        stat: sample and dataset statistics to attach to the sample
        sample_id: id of the sample; if None, the sorted tuple of `paths`
            values is used

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to the model
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # only fall back to a path-derived id if none was given; a plain `or` would
    # also discard valid falsy ids such as 0 or ""
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def import_callable( node: Union[bioimageio.spec.model.v0_4.CallableFromDepencency, bioimageio.spec.model.v0_5.ArchitectureFromLibraryDescr], /, **kwargs: Unpack[bioimageio.spec._internal.io.HashKwargs]) -> Callable[..., Any]:
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it

    Args:
        node: spec node naming the callable. Depending on its type the callable
            is taken from an importable module or from a source file.
        **kwargs: hash keyword arguments forwarded to the file import of a
            `CallableFromFile` node. For an `ArchitectureFromFileDescr` node
            the node's own `sha256` is passed instead and `kwargs` are unused.

    Returns:
        the imported callable

    Raises:
        ValueError: if the imported object is not callable
    """
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        # all four members of the annotated union are handled above
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c

import a callable (e.g. a torch.nn.Module) from a spec node describing it

 95def get_axes_infos(
 96    io_descr: Union[
 97        v0_4.InputTensorDescr,
 98        v0_4.OutputTensorDescr,
 99        v0_5.InputTensorDescr,
100        v0_5.OutputTensorDescr,
101    ],
102) -> List[AxisInfo]:
103    """get a unified, simplified axis representation from spec axes"""
104    return [
105        (
106            AxisInfo.create("i")
107            if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x")
108            else AxisInfo.create(a)
109        )
110        for a in io_descr.axes
111    ]

get a unified, simplified axis representation from spec axes

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """return the sample member ID for a tensor description

    v0_4 descriptions are identified by their `name` (wrapped in `MemberId`),
    v0_5 descriptions by their `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)

get the normalized tensor ID, usable as a sample member ID

def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """return the sample member IDs of the given tensor descriptions, in order"""
    return list(map(get_member_id, tensor_descriptions))

get normalized tensor IDs to be used as sample member IDs

def get_test_inputs( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)]) -> bioimageio.core.Sample:
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """load the test input tensors of `model` into a sample with id "test-sample"

    v0_4 models list their test inputs on the model itself (`test_inputs`);
    otherwise each input description carries its own `test_tensor`.
    """
    member_ids = get_member_ids(model.inputs)
    test_sources = (
        model.test_inputs
        if isinstance(model, v0_4.ModelDescr)
        else [d.test_tensor for d in model.inputs]
    )
    arrays = list(map(load_array, test_sources))
    axes = list(map(get_axes_infos, model.inputs))

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(member_ids, arrays, axes):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")

returns a model's test input sample

def get_test_outputs( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)]) -> bioimageio.core.Sample:
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """load the test output tensors of `model` into a sample with id "test-sample"

    v0_4 models list their test outputs on the model itself (`test_outputs`);
    otherwise each output description carries its own `test_tensor`.
    """
    member_ids = get_member_ids(model.outputs)
    test_sources = (
        model.test_outputs
        if isinstance(model, v0_4.ModelDescr)
        else [d.test_tensor for d in model.outputs]
    )
    arrays = list(map(load_array, test_sources))
    axes = list(map(get_axes_infos, model.outputs))

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(member_ids, arrays, axes):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")

returns a model's test output sample

class IO_SampleBlockMeta(typing.NamedTuple):
class IO_SampleBlockMeta(NamedTuple):
    """meta data of an input sample block paired with the matching output block"""

    # meta data describing a block of the input sample
    input: SampleBlockMeta
    # meta data describing the corresponding block of the output sample
    output: SampleBlockMeta

IO_SampleBlockMeta(input, output)

IO_SampleBlockMeta( input: ForwardRef('SampleBlockMeta'), output: ForwardRef('SampleBlockMeta'))

Create new instance of IO_SampleBlockMeta(input, output)

Alias for field number 0

Alias for field number 1

def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns the halo with which input tensors need to be divided into blocks, such
    that `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: model description providing the input and output tensor axes
        output_halo: halo per output tensor and axis

    Returns:
        a symmetric halo per referenced tensor and axis

    Raises:
        ValueError: if the size of an output axis with halo is not a
            `SizeReference`, or if the halo, scaled to the referenced axis, is
            not an even integer (it could not be split into two equal sides).
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # rescale the summed halo from the output axis to the referenced axis
            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            # explicit check instead of `assert`, which is stripped under `-O`
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo for {t}.{a} corresponds to a total input halo of"
                    + f" {total_input_halo}, which is not an even integer"
                )

            half_input_halo = int(total_input_halo // 2)
            # NOTE(review): the entry is keyed by the output axis id `a`, not by
            # the referenced `s.axis_id` -- confirm this is intended for models
            # whose output axis ids differ from the referenced axis ids.
            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                half_input_halo, half_input_halo
            )

    return input_halo

returns the halo with which input tensors need to be divided into blocks, such that output_halo can be cropped from their outputs without introducing gaps.

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """describe, per output tensor axis, how its size follows from the inputs

    Each output axis maps to either a fixed integer size, `-1` for a
    data dependent size, or a `LinearSampleAxisTransform` referring to the
    axis of another tensor (the first input batch axis for batch axes).
    """
    # the first batch axis found among the inputs serves all output batch axes
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(
                axis=axis.id, scale=1, offset=0, member=ipt.id
            )
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )

    axis_scales = {}
    for tensor in chain(model.inputs, model.outputs):
        axis_scales[tensor.id] = {axis.id: axis.scale for axis in tensor.axes}

    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            size = a.size
            if size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                new_axes[a.id] = batch_axis_trf
            elif isinstance(size, int):
                new_axes[a.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                new_axes[a.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                ref_scale = axis_scales[size.tensor_id][size.axis_id]
                new_axes[a.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=ref_scale / a.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        ret[out.id] = new_axes

    return ret

returns how a model's output tensor shapes relate to its input shapes

def get_io_sample_block_metas( model: bioimageio.spec.ModelDescr, input_sample_shape: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]], ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int = 1) -> Tuple[int, Iterable[IO_SampleBlockMeta]]:
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (v0_5 only)
        input_sample_shape: shape of the input sample per member and axis
        ns: passed on to `model.get_axis_sizes` together with `batch_size` to
            determine the block axis sizes
        batch_size: batch size used to determine the block axis sizes

    Returns:
        the number of input blocks and a lazy iterable of input/output block
        meta data pairs

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # group the flat {(tensor, axis): size} mapping by tensor in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (t, a), s in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(t, {})[a] = s

    # output axes with a halo get the same halo on both sides
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

returns an iterable yielding meta data for corresponding input and output samples

def get_tensor( src: Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path], ipt: Union[bioimageio.spec.model.v0_4.InputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr]):
def get_tensor(
    src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """turn `src` into a `Tensor`

    A `Tensor` is returned as is and an `xr.DataArray` is converted directly.
    numpy arrays and paths are interpreted using the axes of the input
    description `ipt`.
    """
    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        tensor = Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    elif isinstance(src, Path):
        tensor = load_tensor(src, axes=get_axes_infos(ipt))
    else:
        assert_never(src)

    return tensor

helper to cast/load various tensor sources

def create_sample_for_model( model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)], *, stat: Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Hashable = None, inputs: Optional[Mapping[bioimageio.spec.model.v0_5.TensorId, Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path]]] = None, **kwargs: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> bioimageio.core.Sample:
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Optional[
        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
    ] = None,  # TODO: make non-optional
    **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id given to the created sample
        inputs: the input(s) constituting a single sample.
        **kwargs: further inputs by name; entries in `inputs` take precedence.

    Raises:
        ValueError: for inputs unknown to the model or missing non-optional inputs
    """
    # keyword inputs first, so that explicitly given `inputs` override them
    merged: Dict[Any, Any] = dict(kwargs)
    merged.update(inputs or {})
    inputs = {MemberId(name): value for name, value in merged.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    unknown = {k for k in inputs if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for k, v in model_inputs.items():
        if k in inputs:
            continue
        if isinstance(v, v0_5.InputTensorDescr) and v.optional:
            continue
        missing.add(k)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for m, ipt in model_inputs.items():
        if m in inputs:
            members[m] = get_tensor(inputs[m], ipt)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

Create a sample from a single set of input(s) for a specific bioimage.io model

Arguments:
  • model: a bioimage.io model description
  • stat: dictionary with sample and dataset statistics (may be updated in-place!)
  • inputs: the input(s) constituting a single sample.
def load_sample_for_model( *, model: Annotated[Union[bioimageio.spec.model.v0_4.ModelDescr, bioimageio.spec.ModelDescr], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None)], paths: Mapping[bioimageio.spec.model.v0_5.TensorId, pathlib.Path], axes: Optional[Mapping[bioimageio.spec.model.v0_5.TensorId, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]]]] = None, stat: Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], 
Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Optional[Hashable] = None):
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load per model input
        axes: optional axes per model input; inputs without an entry are loaded
            with the axes of the corresponding model input description
        stat: sample and dataset statistics to attach to the sample
        sample_id: id of the sample; if None, the sorted tuple of `paths`
            values is used

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to the model
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # only fall back to a path-derived id if none was given; a plain `or` would
    # also discard valid falsy ids such as 0 or ""
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

load a single sample from paths that can be processed by model