bioimageio.core.digest_spec

  1from __future__ import annotations
  2
  3import collections.abc
  4import importlib.util
  5import sys
  6from itertools import chain
  7from pathlib import Path
  8from tempfile import TemporaryDirectory
  9from typing import (
 10    Any,
 11    Callable,
 12    Dict,
 13    Iterable,
 14    List,
 15    Mapping,
 16    NamedTuple,
 17    Optional,
 18    Sequence,
 19    Tuple,
 20    Union,
 21)
 22from zipfile import ZipFile, is_zipfile
 23
 24import numpy as np
 25import xarray as xr
 26from loguru import logger
 27from numpy.typing import NDArray
 28from typing_extensions import Unpack, assert_never
 29
 30from bioimageio.spec._internal.io import HashKwargs
 31from bioimageio.spec.common import FileDescr, FileSource, ZipPath
 32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
 34from bioimageio.spec.model.v0_5 import (
 35    ArchitectureFromFileDescr,
 36    ArchitectureFromLibraryDescr,
 37    ParameterizedSize_N,
 38)
 39from bioimageio.spec.utils import load_array
 40
 41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 42from .block_meta import split_multiple_shapes_into_blocks
 43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
 44from .io import load_tensor
 45from .sample import (
 46    LinearSampleAxisTransform,
 47    Sample,
 48    SampleBlockMeta,
 49    sample_block_meta_generator,
 50)
 51from .stat_measures import Stat
 52from .tensor import Tensor
 53
# Sources accepted by `get_tensor`: a ready `Tensor`, an xarray `DataArray`,
# a numpy array, or a `Path` that is handed to `load_tensor`.
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
 55
 56
 57def import_callable(
 58    node: Union[
 59        ArchitectureFromFileDescr,
 60        ArchitectureFromLibraryDescr,
 61        CallableFromDepencency,
 62        CallableFromFile,
 63    ],
 64    /,
 65    **kwargs: Unpack[HashKwargs],
 66) -> Callable[..., Any]:
 67    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
 68    if isinstance(node, CallableFromDepencency):
 69        module = importlib.import_module(node.module_name)
 70        c = getattr(module, str(node.callable_name))
 71    elif isinstance(node, ArchitectureFromLibraryDescr):
 72        module = importlib.import_module(node.import_from)
 73        c = getattr(module, str(node.callable))
 74    elif isinstance(node, CallableFromFile):
 75        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
 76    elif isinstance(node, ArchitectureFromFileDescr):
 77        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
 78    else:
 79        assert_never(node)
 80
 81    if not callable(c):
 82        raise ValueError(f"{node} (imported: {c}) is not callable")
 83
 84    return c
 85
 86
 87def _import_from_file_impl(
 88    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
 89):
 90    src_descr = FileDescr(source=source, **kwargs)
 91    # ensure sha is valid even if perform_io_checks=False
 92    # or the source has changed since last sha computation
 93    src_descr.validate_sha256(force_recompute=True)
 94    assert src_descr.sha256 is not None
 95    source_sha = src_descr.sha256
 96
 97    reader = src_descr.get_reader()
 98    # make sure we have unique module name
 99    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"
100
101    # make sure we have a unique and valid module name
102    if not module_name.isidentifier():
103        module_name = f"custom_module_{source_sha}"
104        assert module_name.isidentifier(), module_name
105
106    source_bytes = reader.read()
107
108    module = sys.modules.get(module_name)
109    if module is None:
110        try:
111            tmp_dir = TemporaryDirectory(ignore_cleanup_errors=True)
112            module_path = Path(tmp_dir.name) / module_name
113            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
114                module_path.mkdir()
115                ZipFile(reader).extractall(path=module_path)
116            else:
117                module_path = module_path.with_suffix(".py")
118                _ = module_path.write_bytes(source_bytes)
119
120            importlib_spec = importlib.util.spec_from_file_location(
121                module_name, str(module_path)
122            )
123
124            if importlib_spec is None:
125                raise ImportError(f"Failed to import {source}")
126
127            module = importlib.util.module_from_spec(importlib_spec)
128
129            sys.modules[module_name] = module  # cache this module
130
131            assert importlib_spec.loader is not None
132            importlib_spec.loader.exec_module(module)
133
134        except Exception as e:
135            del sys.modules[module_name]
136            raise ImportError(f"Failed to import {source}") from e
137
138    try:
139        callable_attr = getattr(module, callable_name)
140    except AttributeError as e:
141        raise AttributeError(
142            f"Imported custom module from {source} has no `{callable_name}` attribute."
143        ) from e
144    except Exception as e:
145        raise AttributeError(
146            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
147        ) from e
148
149    else:
150        return callable_attr
151
152
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Convert the axes of a tensor description into a list of `AxisInfo`.

    Axes that are `v0_5.AxisBase` instances are converted via an `Axis` built
    from their `id` and `type`; any other axis must be one of the single
    letters "b", "i", "t", "c", "z", "y", "x".
    """

    def _as_axis_info(spec_axis: Any) -> AxisInfo:
        if isinstance(spec_axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=spec_axis.id, type=spec_axis.type))

        assert spec_axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(spec_axis)

    return [_as_axis_info(spec_axis) for spec_axis in io_descr.axes]
171
172
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the sample member ID for a tensor description.

    v0_4 descriptions are identified by their `name` (wrapped in `MemberId`),
    v0_5 descriptions by their `id`.
    """
    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
        return tensor_description.id

    assert_never(tensor_description)
191
192
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the sample member IDs of `tensor_descriptions`, in order."""
    return list(map(get_member_id, tensor_descriptions))
205
206
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Load the test input tensors of `model` into a sample with id "test-sample".

    For v0_4 models the arrays come from `model.test_inputs`, otherwise from the
    `test_tensor` of each input description.
    """
    input_descrs = model.inputs
    ids = get_member_ids(input_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_sources = model.test_inputs
    else:
        test_sources = [descr.test_tensor for descr in input_descrs]

    # load everything up front, then wrap the arrays
    loaded = [load_array(src) for src in test_sources]
    dims_per_input = [get_axes_infos(descr) for descr in input_descrs]

    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_input):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
224
225
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Load the test output tensors of `model` into a sample with id "test-sample".

    For v0_4 models the arrays come from `model.test_outputs`, otherwise from
    the `test_tensor` of each output description.
    """
    output_descrs = model.outputs
    ids = get_member_ids(output_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_sources = model.test_outputs
    else:
        test_sources = [descr.test_tensor for descr in output_descrs]

    # load everything up front, then wrap the arrays
    loaded = [load_array(src) for src in test_sources]
    dims_per_output = [get_axes_infos(descr) for descr in output_descrs]

    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_output):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
245
246
class IO_SampleBlockMeta(NamedTuple):
    """Meta data of an input sample block paired with that of its output block."""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block
    output: SampleBlockMeta
250
251
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Compute the halo input tensors need to be blocked with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Each output axis with a halo must have a `v0_5.SizeReference` size. The
    summed output halo is rescaled by the ratio of the output axis scale to the
    referenced axis scale and split evenly onto both sides.

    Returns:
        A mapping from referenced tensor id to per-axis `Halo`.

    Raises:
        ValueError: if the size of an output axis with halo is not a
            `v0_5.SizeReference`.
    """
    outputs_by_id = {t.id: t for t in model.outputs}
    tensors_by_id = {**{t.id: t for t in model.inputs}, **outputs_by_id}
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}

    for tensor_id, halo_per_axis in output_halo.items():
        out_axes = {ax.id: ax for ax in outputs_by_id[tensor_id].axes}

        for axis_id, axis_halo in halo_per_axis.items():
            out_axis = out_axes[axis_id]
            size = out_axis.size
            if not isinstance(size, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {tensor_id}.{axis_id} to an input axis"
                )

            ref_axes = {ax.id: ax for ax in tensors_by_id[size.tensor_id].axes}
            ref_axis = ref_axes[size.axis_id]

            total_input_halo = sum(axis_halo) * out_axis.scale / ref_axis.scale
            # the rescaled halo has to be an even integer to be split symmetrically
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            )
            one_sided = int(total_input_halo // 2)
            # NOTE(review): the entry is keyed by the output axis id, not by
            # `size.axis_id` -- confirm both ids always coincide.
            input_halo.setdefault(size.tensor_id, {})[axis_id] = Halo(
                one_sided, one_sided
            )

    return input_halo
282
283
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how the output tensor shapes of `model` relate to its input shapes.

    For every output tensor and axis the result holds one of:

    - the batch transform (built from the first batch axis found among the
      inputs) for axes whose size is `None`,
    - the integer size for fixed size axes,
    - `-1` for data dependent sizes,
    - a `LinearSampleAxisTransform` onto the referenced tensor axis for size
      references.

    Raises:
        ValueError: if an output has a batch axis, but no input has one.
    """
    # first batch axis among all inputs (or None if there is none)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=ax.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for ax in ipt.axes
            if ax.type == "batch"
        ),
        None,
    )
    axis_scales = {
        tensor.id: {ax.id: ax.scale for ax in tensor.axes}
        for tensor in chain(model.inputs, model.outputs)
    }

    def _size_or_transform(out: Any, ax: Any) -> Union[LinearSampleAxisTransform, int]:
        size = ax.size
        if size is None:
            assert ax.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out.id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / ax.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        ret[out.id] = {ax.id: _size_or_transform(out, ax) for ax in out.axes}

    return ret
333
334
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of blocks and an iterable yielding meta data for
    corresponding input and output sample blocks.

    Args:
        model: the (v0_5) model description.
        input_sample_shape: shape of the full input sample per member and axis.
        ns: passed on to `model.get_axis_sizes` together with `batch_size` to
            determine the block shape.
        batch_size: batch size used to determine the block shape.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat {(tensor, axis): size} mapping per tensor in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (tensor_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(tensor_id, {})[axis_id] = size

    # symmetric halo for every output axis that declares one
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
368
369
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Turn `src` into a `Tensor`.

    A `Tensor` is returned as is and an xarray `DataArray` is converted
    directly. Numpy arrays are wrapped using the axes of `ipt`; anything else is
    loaded with `load_tensor` using the axes of `ipt`.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    # the remaining sources carry no axis information of their own
    axes = get_axes_infos(ipt)
    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)
384
385
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample
        inputs: the input(s) constituting a single sample; either a mapping keyed
            by member id or, for models with exactly one input, a single source.

    Raises:
        TypeError: if `inputs` is not a mapping although the model has several inputs.
        ValueError: on inputs unknown to the model or missing non-optional inputs.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(key): value for key, value in inputs.items()}
    elif len(model_inputs) == 1:
        # a bare source is only unambiguous for single input models
        inputs = {next(iter(model_inputs)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    unknown = {key for key in inputs if key not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for key, descr in model_inputs.items():
        if key in inputs:
            continue
        if isinstance(descr, v0_5.InputTensorDescr) and descr.optional:
            continue
        missing.add(key)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for member_id, descr in model_inputs.items():
        if member_id in inputs:
            members[member_id] = get_tensor(inputs[member_id], descr)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
430
431
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: file path per model input
        axes: optional axes hints per model input; inputs without a hint are
            loaded with the axes of the model's input description.
        stat: sample and dataset statistics to attach to the sample
        sample_id: id of the sample; if `None` the sorted tuple of `paths`
            values is used.

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to `model`.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # only fall back to the paths if no id was given; a falsy id such as 0 or ""
    # is a valid id and must be kept
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
TensorSource = typing.Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[typing.Any, numpy.dtype[typing.Any]], pathlib.Path]
def import_callable( node: Union[bioimageio.spec.model.v0_5.ArchitectureFromFileDescr, bioimageio.spec.model.v0_5.ArchitectureFromLibraryDescr, bioimageio.spec.model.v0_4.CallableFromDepencency, bioimageio.spec.model.v0_4.CallableFromFile], /, **kwargs: Unpack[bioimageio.spec._internal.io.HashKwargs]) -> Callable[..., Any]:
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Resolve the callable (e.g. a torch.nn.Module) described by a spec node.

    Args:
        node: spec node naming either an importable module plus attribute or a
            source file plus attribute.
        **kwargs: hash keyword arguments; they are only forwarded for a
            `CallableFromFile` node (an `ArchitectureFromFileDescr` passes its
            own `sha256` instead).

    Returns:
        The imported attribute.

    Raises:
        ValueError: if the imported attribute is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")

import a callable (e.g. a torch.nn.Module) from a spec node describing it

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Convert the axes of a tensor description into a list of `AxisInfo`.

    Axes that are `v0_5.AxisBase` instances are converted via an `Axis` built
    from their `id` and `type`; any other axis must be one of the single
    letters "b", "i", "t", "c", "z", "y", "x".
    """

    def _as_axis_info(spec_axis: Any) -> AxisInfo:
        if isinstance(spec_axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=spec_axis.id, type=spec_axis.type))

        assert spec_axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(spec_axis)

    return [_as_axis_info(spec_axis) for spec_axis in io_descr.axes]

get a unified, simplified axis representation from spec axes

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the sample member ID for a tensor description.

    v0_4 descriptions are identified by their `name` (wrapped in `MemberId`),
    v0_5 descriptions by their `id`.
    """
    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
        return tensor_description.id

    assert_never(tensor_description)

get the normalized tensor ID, usable as a sample member ID

def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the sample member IDs of `tensor_descriptions`, in order."""
    return list(map(get_member_id, tensor_descriptions))

get normalized tensor IDs to be used as sample member IDs

def get_test_inputs( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Load the test input tensors of `model` into a sample with id "test-sample".

    For v0_4 models the arrays come from `model.test_inputs`, otherwise from the
    `test_tensor` of each input description.
    """
    input_descrs = model.inputs
    ids = get_member_ids(input_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_sources = model.test_inputs
    else:
        test_sources = [descr.test_tensor for descr in input_descrs]

    # load everything up front, then wrap the arrays
    loaded = [load_array(src) for src in test_sources]
    dims_per_input = [get_axes_infos(descr) for descr in input_descrs]

    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_input):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")

returns a model's test input sample

def get_test_outputs( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]) -> bioimageio.core.Sample:
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Load the test output tensors of `model` into a sample with id "test-sample".

    For v0_4 models the arrays come from `model.test_outputs`, otherwise from
    the `test_tensor` of each output description.
    """
    output_descrs = model.outputs
    ids = get_member_ids(output_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_sources = model.test_outputs
    else:
        test_sources = [descr.test_tensor for descr in output_descrs]

    # load everything up front, then wrap the arrays
    loaded = [load_array(src) for src in test_sources]
    dims_per_output = [get_axes_infos(descr) for descr in output_descrs]

    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_output):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")

returns a model's test output sample

class IO_SampleBlockMeta(typing.NamedTuple):
class IO_SampleBlockMeta(NamedTuple):
    """Meta data of an input sample block paired with that of its output block."""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block
    output: SampleBlockMeta

IO_SampleBlockMeta(input, output)

IO_SampleBlockMeta( input: ForwardRef('SampleBlockMeta'), output: ForwardRef('SampleBlockMeta'))

Create new instance of IO_SampleBlockMeta(input, output)

Alias for field number 0

Alias for field number 1

def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Compute the halo input tensors need to be blocked with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Each output axis with a halo must have a `v0_5.SizeReference` size. The
    summed output halo is rescaled by the ratio of the output axis scale to the
    referenced axis scale and split evenly onto both sides.

    Returns:
        A mapping from referenced tensor id to per-axis `Halo`.

    Raises:
        ValueError: if the size of an output axis with halo is not a
            `v0_5.SizeReference`.
    """
    outputs_by_id = {t.id: t for t in model.outputs}
    tensors_by_id = {**{t.id: t for t in model.inputs}, **outputs_by_id}
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}

    for tensor_id, halo_per_axis in output_halo.items():
        out_axes = {ax.id: ax for ax in outputs_by_id[tensor_id].axes}

        for axis_id, axis_halo in halo_per_axis.items():
            out_axis = out_axes[axis_id]
            size = out_axis.size
            if not isinstance(size, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {tensor_id}.{axis_id} to an input axis"
                )

            ref_axes = {ax.id: ax for ax in tensors_by_id[size.tensor_id].axes}
            ref_axis = ref_axes[size.axis_id]

            total_input_halo = sum(axis_halo) * out_axis.scale / ref_axis.scale
            # the rescaled halo has to be an even integer to be split symmetrically
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            )
            one_sided = int(total_input_halo // 2)
            # NOTE(review): the entry is keyed by the output axis id, not by
            # `size.axis_id` -- confirm both ids always coincide.
            input_halo.setdefault(size.tensor_id, {})[axis_id] = Halo(
                one_sided, one_sided
            )

    return input_halo

returns which halo input tensors need to be divided into blocks with, such that output_halo can be cropped from their outputs without introducing gaps.

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how the output tensor shapes of `model` relate to its input shapes.

    For every output tensor and axis the result holds one of:

    - the batch transform (built from the first batch axis found among the
      inputs) for axes whose size is `None`,
    - the integer size for fixed size axes,
    - `-1` for data dependent sizes,
    - a `LinearSampleAxisTransform` onto the referenced tensor axis for size
      references.

    Raises:
        ValueError: if an output has a batch axis, but no input has one.
    """
    # first batch axis among all inputs (or None if there is none)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=ax.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for ax in ipt.axes
            if ax.type == "batch"
        ),
        None,
    )
    axis_scales = {
        tensor.id: {ax.id: ax.scale for ax in tensor.axes}
        for tensor in chain(model.inputs, model.outputs)
    }

    def _size_or_transform(out: Any, ax: Any) -> Union[LinearSampleAxisTransform, int]:
        size = ax.size
        if size is None:
            assert ax.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out.id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / ax.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        ret[out.id] = {ax.id: _size_or_transform(out, ax) for ax in out.axes}

    return ret

returns how a model's output tensor shapes relates to its input shapes

def get_io_sample_block_metas( model: bioimageio.spec.ModelDescr, input_sample_shape: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]], ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int = 1) -> Tuple[int, Iterable[IO_SampleBlockMeta]]:
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of blocks and an iterable yielding meta data for
    corresponding input and output sample blocks.

    Args:
        model: the (v0_5) model description.
        input_sample_shape: shape of the full input sample per member and axis.
        ns: passed on to `model.get_axis_sizes` together with `batch_size` to
            determine the block shape.
        batch_size: batch size used to determine the block shape.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat {(tensor, axis): size} mapping per tensor in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (tensor_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(tensor_id, {})[axis_id] = size

    # symmetric halo for every output axis that declares one
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

returns an iterable yielding meta data for corresponding input and output samples

def get_tensor( src: Union[zipp.Path, bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path], ipt: Union[bioimageio.spec.model.v0_4.InputTensorDescr, bioimageio.spec.model.v0_5.InputTensorDescr]):
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Turn `src` into a `Tensor`.

    A `Tensor` is returned as is and an xarray `DataArray` is converted
    directly. Numpy arrays are wrapped using the axes of `ipt`; anything else is
    loaded with `load_tensor` using the axes of `ipt`.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    # the remaining sources carry no axis information of their own
    axes = get_axes_infos(ipt)
    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)

helper to cast/load various tensor sources

def create_sample_for_model( model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], *, stat: Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Hashable = None, inputs: Union[Mapping[bioimageio.spec.model.v0_5.TensorId, Union[bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path]], bioimageio.core.Tensor, xarray.core.dataarray.DataArray, numpy.ndarray[Any, numpy.dtype[Any]], pathlib.Path]) -> bioimageio.core.Sample:
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample
        inputs: the input(s) constituting a single sample; either a mapping
            from member id to tensor source or, if the model has exactly one
            input, a single tensor source.

    Raises:
        TypeError: if `inputs` is not a mapping although the model does not
            have exactly one input.
        ValueError: if `inputs` has keys that are not model inputs, or if a
            non-optional model input is missing.
    """
    descr_by_member = {get_member_id(descr): descr for descr in model.inputs}

    # normalize `inputs` to a dict keyed by member id
    if isinstance(inputs, collections.abc.Mapping):
        sources = {MemberId(key): src for key, src in inputs.items()}
    elif len(descr_by_member) == 1:
        # a bare tensor source is only unambiguous for single-input models
        sources = {next(iter(descr_by_member)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(descr_by_member)}"
        )

    unknown = {member for member in sources if member not in descr_by_member}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    # only v0_5 input descriptions are checked for being optional
    missing = {
        member
        for member, descr in descr_by_member.items()
        if member not in sources
        and (not isinstance(descr, v0_5.InputTensorDescr) or not descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    # keep the order of the model inputs; omitted optional inputs are skipped
    members = {}
    for member, descr in descr_by_member.items():
        if member in sources:
            members[member] = get_tensor(sources[member], descr)

    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )

Create a sample from a single set of input(s) for a specific bioimage.io model

Arguments:
  • model: a bioimage.io model description
  • stat: dictionary with sample and dataset statistics (may be updated in-place!)
  • sample_id: id assigned to the created sample
  • inputs: the input(s) constituting a single sample.
def load_sample_for_model( *, model: Annotated[Union[Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], Annotated[bioimageio.spec.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')], paths: Mapping[bioimageio.spec.model.v0_5.TensorId, pathlib.Path], axes: Optional[Mapping[bioimageio.spec.model.v0_5.TensorId, Sequence[Union[bioimageio.spec.model.v0_5.AxisId, Literal['b', 'i', 't', 'c', 'z', 'y', 'x'], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexInputAxis, bioimageio.spec.model.v0_5.TimeInputAxis, bioimageio.spec.model.v0_5.SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.spec.model.v0_5.BatchAxis, bioimageio.spec.model.v0_5.ChannelAxis, bioimageio.spec.model.v0_5.IndexOutputAxis, Annotated[Union[Annotated[bioimageio.spec.model.v0_5.TimeOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxis, Tag(tag='wo_halo')], Annotated[bioimageio.spec.model.v0_5.SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], bioimageio.core.Axis]]]] = None, stat: 
Optional[Dict[Annotated[Union[Annotated[Union[bioimageio.core.stat_measures.SampleMean, bioimageio.core.stat_measures.SampleStd, bioimageio.core.stat_measures.SampleVar, bioimageio.core.stat_measures.SampleQuantile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Annotated[Union[bioimageio.core.stat_measures.DatasetMean, bioimageio.core.stat_measures.DatasetStd, bioimageio.core.stat_measures.DatasetVar, bioimageio.core.stat_measures.DatasetPercentile], Discriminator(discriminator='name', custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='scope', custom_error_type=None, custom_error_message=None, custom_error_context=None)], Union[float, Annotated[bioimageio.core.Tensor, BeforeValidator(func=<function tensor_custom_before_validator>, json_schema_input_type=PydanticUndefined), PlainSerializer(func=<function tensor_custom_serializer>, return_type=PydanticUndefined, when_used='always')]]]] = None, sample_id: Optional[Hashable] = None):
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: the path to load each member's tensor from, keyed by the
            member id of a model input
        axes: optional axes per member, passed on to `load_tensor`; members
            without an entry use the axes of the corresponding model input
            description (see `get_axes_infos`)
        stat: dictionary with sample and dataset statistics used as the
            sample's `stat` (an empty dict is used if `None`)
        sample_id: id of the returned sample. If `None`, the sorted tuple of
            the given paths is used instead. Falsy but valid ids
            (e.g. `0` or `""`) are kept as given.

    Raises:
        ValueError: if `paths` or `axes` have keys that are not model inputs
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (these are fresh dicts; the caller's mappings are not modified below)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            # no axes hint given: fall back to the model's input description
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # Only fall back to a path-derived id if no id was given at all.
    # Testing truthiness (`sample_id or ...`) would discard valid ids like 0.
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

Load a single sample from `paths` that can be processed by `model`.