bioimageio.core.digest_spec

  1from __future__ import annotations
  2
  3import collections.abc
  4import hashlib
  5import importlib.util
  6import sys
  7from itertools import chain
  8from pathlib import Path
  9from typing import (
 10    Any,
 11    Callable,
 12    Dict,
 13    Iterable,
 14    List,
 15    Mapping,
 16    NamedTuple,
 17    Optional,
 18    Sequence,
 19    Tuple,
 20    Union,
 21)
 22
 23import numpy as np
 24import xarray as xr
 25from loguru import logger
 26from numpy.typing import NDArray
 27from typing_extensions import Unpack, assert_never
 28
 29from bioimageio.spec._internal.io import HashKwargs
 30from bioimageio.spec.common import FileDescr, FileSource, ZipPath
 31from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 32from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
 33from bioimageio.spec.model.v0_5 import (
 34    ArchitectureFromFileDescr,
 35    ArchitectureFromLibraryDescr,
 36    ParameterizedSize_N,
 37)
 38from bioimageio.spec.utils import download, load_array
 39
 40from ._settings import settings
 41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
 42from .block_meta import split_multiple_shapes_into_blocks
 43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
 44from .io import load_tensor
 45from .sample import (
 46    LinearSampleAxisTransform,
 47    Sample,
 48    SampleBlockMeta,
 49    sample_block_meta_generator,
 50)
 51from .stat_measures import Stat
 52from .tensor import Tensor
 53
 54TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
 55
 56
 57def import_callable(
 58    node: Union[
 59        ArchitectureFromFileDescr,
 60        ArchitectureFromLibraryDescr,
 61        CallableFromDepencency,
 62        CallableFromFile,
 63    ],
 64    /,
 65    **kwargs: Unpack[HashKwargs],
 66) -> Callable[..., Any]:
 67    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
 68    if isinstance(node, CallableFromDepencency):
 69        module = importlib.import_module(node.module_name)
 70        c = getattr(module, str(node.callable_name))
 71    elif isinstance(node, ArchitectureFromLibraryDescr):
 72        module = importlib.import_module(node.import_from)
 73        c = getattr(module, str(node.callable))
 74    elif isinstance(node, CallableFromFile):
 75        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
 76    elif isinstance(node, ArchitectureFromFileDescr):
 77        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
 78    else:
 79        assert_never(node)
 80
 81    if not callable(c):
 82        raise ValueError(f"{node} (imported: {c}) is not callable")
 83
 84    return c
 85
 86
 87def _import_from_file_impl(
 88    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
 89):
 90    src_descr = FileDescr(source=source, **kwargs)
 91    # ensure sha is valid even if perform_io_checks=False
 92    src_descr.validate_sha256()
 93    assert src_descr.sha256 is not None
 94
 95    local_source = src_descr.download()
 96
 97    source_bytes = local_source.path.read_bytes()
 98    assert isinstance(source_bytes, bytes)
 99    source_sha = hashlib.sha256(source_bytes).hexdigest()
100
101    # make sure we have a unique module name
102    module_name = f"{local_source.path.stem}_{source_sha}"
103
104    # make sure we have a valid module name
105    if not module_name.isidentifier():
106        module_name = f"custom_module_{source_sha}"
107        assert module_name.isidentifier(), module_name
108
109    module = sys.modules.get(module_name)
110    if module is None:
111        try:
112            if isinstance(local_source.path, Path):
113                module_path = local_source.path
114            elif isinstance(local_source.path, ZipPath):
115                # save extracted source to cache
116                # loading from a file on disk ensures we get readable tracebacks
117                # if any errors occur
118                module_path = (
119                    settings.cache_path / f"{source_sha}-{local_source.path.name}"
120                )
121                _ = module_path.write_bytes(source_bytes)
122            else:
123                assert_never(local_source.path)
124
125            importlib_spec = importlib.util.spec_from_file_location(
126                module_name, module_path
127            )
128
129            if importlib_spec is None:
130                raise ImportError(f"Failed to import {source}")
131
132            module = importlib.util.module_from_spec(importlib_spec)
133            assert importlib_spec.loader is not None
134            importlib_spec.loader.exec_module(module)
135
136        except Exception as e:
137            raise ImportError(f"Failed to import {source}") from e
138        else:
139            sys.modules[module_name] = module  # cache this module
140
141    try:
142        callable_attr = getattr(module, callable_name)
143    except AttributeError as e:
144        raise AttributeError(
145            f"Imported custom module from {source} has no `{callable_name}` attribute."
146        ) from e
147    except Exception as e:
148        raise AttributeError(
149            f"Failed to access `{callable_name}` attribute from custom module imported from {source}."
150        ) from e
151
152    else:
153        return callable_attr
154
155
156def get_axes_infos(
157    io_descr: Union[
158        v0_4.InputTensorDescr,
159        v0_4.OutputTensorDescr,
160        v0_5.InputTensorDescr,
161        v0_5.OutputTensorDescr,
162    ],
163) -> List[AxisInfo]:
164    """get a unified, simplified axis representation from spec axes"""
165    ret: List[AxisInfo] = []
166    for a in io_descr.axes:
167        if isinstance(a, v0_5.AxisBase):
168            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
169        else:
170            assert a in ("b", "i", "t", "c", "z", "y", "x")
171            ret.append(AxisInfo.create(a))
172
173    return ret
174
175
176def get_member_id(
177    tensor_description: Union[
178        v0_4.InputTensorDescr,
179        v0_4.OutputTensorDescr,
180        v0_5.InputTensorDescr,
181        v0_5.OutputTensorDescr,
182    ],
183) -> MemberId:
184    """get the normalized tensor ID, usable as a sample member ID"""
185
186    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
187        return MemberId(tensor_description.name)
188    elif isinstance(
189        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
190    ):
191        return tensor_description.id
192    else:
193        assert_never(tensor_description)
194
195
196def get_member_ids(
197    tensor_descriptions: Sequence[
198        Union[
199            v0_4.InputTensorDescr,
200            v0_4.OutputTensorDescr,
201            v0_5.InputTensorDescr,
202            v0_5.OutputTensorDescr,
203        ]
204    ],
205) -> List[MemberId]:
206    """get normalized tensor IDs to be used as sample member IDs"""
207    return [get_member_id(descr) for descr in tensor_descriptions]
208
209
210def get_test_inputs(model: AnyModelDescr) -> Sample:
211    """returns a model's test input sample"""
212    member_ids = get_member_ids(model.inputs)
213    if isinstance(model, v0_4.ModelDescr):
214        arrays = [load_array(tt) for tt in model.test_inputs]
215    else:
216        arrays = [load_array(d.test_tensor) for d in model.inputs]
217
218    axes = [get_axes_infos(t) for t in model.inputs]
219    return Sample(
220        members={
221            m: Tensor.from_numpy(arr, dims=ax)
222            for m, arr, ax in zip(member_ids, arrays, axes)
223        },
224        stat={},
225        id="test-sample",
226    )
227
228
229def get_test_outputs(model: AnyModelDescr) -> Sample:
230    """returns a model's test output sample"""
231    member_ids = get_member_ids(model.outputs)
232
233    if isinstance(model, v0_4.ModelDescr):
234        arrays = [load_array(tt) for tt in model.test_outputs]
235    else:
236        arrays = [load_array(d.test_tensor) for d in model.outputs]
237
238    axes = [get_axes_infos(t) for t in model.outputs]
239
240    return Sample(
241        members={
242            m: Tensor.from_numpy(arr, dims=ax)
243            for m, arr, ax in zip(member_ids, arrays, axes)
244        },
245        stat={},
246        id="test-sample",
247    )
248
249
250class IO_SampleBlockMeta(NamedTuple):
251    input: SampleBlockMeta
252    output: SampleBlockMeta
253
254
255def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
256    """returns the halo with which input tensors need to be divided into blocks, such that
257    `output_halo` can be cropped from the resulting output blocks without introducing gaps."""
258    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
259    outputs = {t.id: t for t in model.outputs}
260    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
261
262    for t, th in output_halo.items():
263        axes = {a.id: a for a in outputs[t].axes}
264
265        for a, ah in th.items():
266            s = axes[a].size
267            if not isinstance(s, v0_5.SizeReference):
268                raise ValueError(
269                    f"Unable to map output halo for {t}.{a} to an input axis"
270                )
271
272            axis = axes[a]
273            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]
274
275            total_output_halo = sum(ah)
276            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
277            assert (
278                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
279            )
280            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
281                int(total_input_halo // 2), int(total_input_halo // 2)
282            )
283
284    return input_halo
285
286
287def get_block_transform(
288    model: v0_5.ModelDescr,
289) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
290    """returns how a model's output tensor shapes relate to its input shapes"""
291    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
292    batch_axis_trf = None
293    for ipt in model.inputs:
294        for a in ipt.axes:
295            if a.type == "batch":
296                batch_axis_trf = LinearSampleAxisTransform(
297                    axis=a.id, scale=1, offset=0, member=ipt.id
298                )
299                break
300        if batch_axis_trf is not None:
301            break
302    axis_scales = {
303        t.id: {a.id: a.scale for a in t.axes}
304        for t in chain(model.inputs, model.outputs)
305    }
306    for out in model.outputs:
307        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
308        for a in out.axes:
309            if a.size is None:
310                assert a.type == "batch"
311                if batch_axis_trf is None:
312                    raise ValueError(
313                        "no batch axis found in any input tensor, but output tensor"
314                        + f" '{out.id}' has one."
315                    )
316                s = batch_axis_trf
317            elif isinstance(a.size, int):
318                s = a.size
319            elif isinstance(a.size, v0_5.DataDependentSize):
320                s = -1
321            elif isinstance(a.size, v0_5.SizeReference):
322                s = LinearSampleAxisTransform(
323                    axis=a.size.axis_id,
324                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
325                    offset=a.size.offset,
326                    member=a.size.tensor_id,
327                )
328            else:
329                assert_never(a.size)
330
331            new_axes[a.id] = s
332
333        ret[out.id] = new_axes
334
335    return ret
336
337
338def get_io_sample_block_metas(
339    model: v0_5.ModelDescr,
340    input_sample_shape: PerMember[PerAxis[int]],
341    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
342    batch_size: int = 1,
343) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
344    """returns the total number of blocks and an iterable yielding metadata for corresponding input and output sample blocks"""
345    if not isinstance(model, v0_5.ModelDescr):
346        raise TypeError(f"get_io_sample_block_metas() not implemented for {type(model)}")
347
348    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
349    input_block_shape = {
350        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
351        for t in {tt for tt, _ in block_axis_sizes.inputs}
352    }
353    output_halo = {
354        t.id: {
355            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
356        }
357        for t in model.outputs
358    }
359    input_halo = get_input_halo(model, output_halo)
360
361    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
362        input_sample_shape, input_block_shape, halo=input_halo
363    )
364    block_transform = get_block_transform(model)
365    return n_input_blocks, (
366        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
367        for ipt in sample_block_meta_generator(
368            input_blocks, sample_shape=input_sample_shape, sample_id=None
369        )
370    )
371
372
373def get_tensor(
374    src: Union[ZipPath, TensorSource],
375    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
376):
377    """helper to cast/load various tensor sources"""
378
379    if isinstance(src, Tensor):
380        return src
381
382    if isinstance(src, xr.DataArray):
383        return Tensor.from_xarray(src)
384
385    if isinstance(src, np.ndarray):
386        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
387
388    if isinstance(src, FileDescr):
389        src = download(src).path
390
391    if isinstance(src, (ZipPath, Path, str)):
392        return load_tensor(src, axes=get_axes_infos(ipt))
393
394    assert_never(src)
395
396
397def create_sample_for_model(
398    model: AnyModelDescr,
399    *,
400    stat: Optional[Stat] = None,
401    sample_id: SampleId = None,
402    inputs: Union[PerMember[TensorSource], TensorSource],
403) -> Sample:
404    """Create a sample from a single set of input(s) for a specific bioimage.io model
405
406    Args:
407        model: a bioimage.io model description
408        stat: dictionary with sample and dataset statistics (may be updated in-place!)
409        inputs: the input(s) constituting a single sample.
410    """
411
412    model_inputs = {get_member_id(d): d for d in model.inputs}
413    if isinstance(inputs, collections.abc.Mapping):
414        inputs = {MemberId(k): v for k, v in inputs.items()}
415    elif len(model_inputs) == 1:
416        inputs = {list(model_inputs)[0]: inputs}
417    else:
418        raise TypeError(
419            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
420        )
421
422    if unknown := {k for k in inputs if k not in model_inputs}:
423        raise ValueError(f"Got unexpected inputs: {unknown}")
424
425    if missing := {
426        k
427        for k, v in model_inputs.items()
428        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
429    }:
430        raise ValueError(f"Missing non-optional model inputs: {missing}")
431
432    return Sample(
433        members={
434            m: get_tensor(inputs[m], ipt)
435            for m, ipt in model_inputs.items()
436            if m in inputs
437        },
438        stat={} if stat is None else stat,
439        id=sample_id,
440    )
441
442
443def load_sample_for_model(
444    *,
445    model: AnyModelDescr,
446    paths: PerMember[Path],
447    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
448    stat: Optional[Stat] = None,
449    sample_id: Optional[SampleId] = None,
450):
451    """load a single sample from `paths` that can be processed by `model`"""
452
453    if axes is None:
454        axes = {}
455
456    # make sure members are keyed by MemberId, not string
457    paths = {MemberId(k): v for k, v in paths.items()}
458    axes = {MemberId(k): v for k, v in axes.items()}
459
460    model_inputs = {get_member_id(d): d for d in model.inputs}
461
462    if unknown := {k for k in paths if k not in model_inputs}:
463        raise ValueError(f"Got unexpected paths for {unknown}")
464
465    if unknown := {k for k in axes if k not in model_inputs}:
466        raise ValueError(f"Got unexpected axes hints for: {unknown}")
467
468    members: Dict[MemberId, Tensor] = {}
469    for m, p in paths.items():
470        if m not in axes:
471            axes[m] = get_axes_infos(model_inputs[m])
472            logger.debug(
473                "loading '{}' from {} with default input axes {} ",
474                m,
475                p,
476                axes[m],
477            )
478        members[m] = load_tensor(p, axes[m])
479
480    return Sample(
481        members=members,
482        stat={} if stat is None else stat,
483        id=sample_id or tuple(sorted(paths.values())),
484    )
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
def import_callable( node: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr, CallableFromDepencency, CallableFromFile], /, **kwargs: Unpack[HashKwargs]) -> Callable[..., Any]:

import a callable (e.g. a torch.nn.Module) from a spec node describing it
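
For example, to materialize the architecture class referenced by a model's PyTorch state-dict weights (a minimal sketch; it assumes a loaded v0.5 `ModelDescr` bound to `model` that actually declares `pytorch_state_dict` weights):

from bioimageio.core.digest_spec import import_callable

arch = model.weights.pytorch_state_dict.architecture  # ArchitectureFrom*Descr node
net_class = import_callable(arch)                      # e.g. a torch.nn.Module subclass
net = net_class(**dict(arch.kwargs))                   # instantiate with the declared kwargs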

def get_axes_infos( io_descr: Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr]) -> List[AxisInfo]:

get a unified, simplified axis representation from spec axes
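
A minimal sketch of typical use (assuming `model` is a model description loaded elsewhere):

from bioimageio.core.digest_spec import get_axes_infos

for axis in get_axes_infos(model.inputs[0]):
    print(axis.id, axis.type)  # unified view, e.g. batch/channel/space axes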

def get_member_id( tensor_description: Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr]) -> MemberId:

get the normalized tensor ID, usable as a sample member ID

def get_member_ids( tensor_descriptions: Sequence[Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr]]) -> List[MemberId]:

get normalized tensor IDs to be used as sample member IDs
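
A short sketch (assuming a loaded model description in `model`); the normalization hides whether ids come from v0.4 `name` or v0.5 `id` fields:

from bioimageio.core.digest_spec import get_member_ids

input_ids = get_member_ids(model.inputs)    # e.g. [MemberId('raw')]
output_ids = get_member_ids(model.outputs)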

def get_test_inputs( model: AnyModelDescr) -> Sample:

returns a model's test input sample
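
A sketch for a quick smoke test against the test data bundled with a model description (assumes `model` is already loaded):

from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs

input_sample = get_test_inputs(model)
expected = get_test_outputs(model)  # reference outputs to compare predictions against
for member_id, tensor in input_sample.members.items():
    print(member_id, tensor.dims)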

def get_test_outputs( model: AnyModelDescr) -> Sample:

returns a model's test output sample

class IO_SampleBlockMeta(typing.NamedTuple):
    input: SampleBlockMeta
    output: SampleBlockMeta

IO_SampleBlockMeta(input, output)

def get_input_halo( model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):

returns the halo with which input tensors need to be divided into blocks, such that `output_halo` can be cropped from the resulting output blocks without introducing gaps.
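
A sketch for v0.5 models only; the `output_halo` construction mirrors the one used internally by `get_io_sample_block_metas`:

from bioimageio.spec.model import v0_5
from bioimageio.core.common import Halo
from bioimageio.core.digest_spec import get_input_halo

output_halo = {
    t.id: {a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)}
    for t in model.outputs
}
input_halo = get_input_halo(model, output_halo)  # PerMember[PerAxis[Halo]] for the inputs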

def get_block_transform( model: v0_5.ModelDescr) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:

returns how a model's output tensor shapes relate to its input shapes
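
A sketch (v0.5 models only, `model` loaded elsewhere) that inspects how each output axis maps back to the inputs:

from bioimageio.core.digest_spec import get_block_transform

for out_id, per_axis in get_block_transform(model).items():
    for axis_id, transform in per_axis.items():
        # a LinearSampleAxisTransform referencing an input axis, a fixed integer size,
        # or -1 for data-dependent sizes
        print(out_id, axis_id, transform)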

def get_io_sample_block_metas( model: v0_5.ModelDescr, input_sample_shape: PerMember[PerAxis[int]], ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N], batch_size: int = 1) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:

returns the total number of blocks and an iterable yielding metadata for corresponding input and output sample blocks
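
A sketch for planning a tiled run of a v0.5 model; the member/axis ids, sample shape, and `ns` values below are hypothetical and must match the model's parameterized axes:

from bioimageio.core.axis import AxisId
from bioimageio.core.common import MemberId
from bioimageio.core.digest_spec import get_io_sample_block_metas

sample_shape = {MemberId("raw"): {AxisId("batch"): 1, AxisId("channel"): 1,
                                  AxisId("y"): 512, AxisId("x"): 512}}
n_blocks, block_metas = get_io_sample_block_metas(
    model,
    input_sample_shape=sample_shape,
    ns={(MemberId("raw"), AxisId("y")): 2, (MemberId("raw"), AxisId("x")): 2},
)
for io_meta in block_metas:
    print(io_meta.input, "->", io_meta.output)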

def get_tensor( src: Union[ZipPath, TensorSource], ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr]):

helper to cast/load various tensor sources
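
A sketch that coerces a plain numpy array using the axes declared for the first model input (the array shape is hypothetical and must match the declared axes):

import numpy as np
from bioimageio.core.digest_spec import get_tensor

arr = np.zeros((1, 1, 256, 256), dtype="float32")  # hypothetical b, c, y, x data
tensor = get_tensor(arr, model.inputs[0])           # -> bioimageio.core.Tensor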

def create_sample_for_model( model: AnyModelDescr, *, stat: Optional[Stat] = None, sample_id: SampleId = None, inputs: Union[PerMember[TensorSource], TensorSource]) -> Sample:

Create a sample from a single set of input(s) for a specific bioimage.io model

Arguments:
  • model: a bioimage.io model description
  • stat: dictionary with sample and dataset statistics (may be updated in-place!)
  • inputs: the input(s) constituting a single sample.
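
A sketch for a single-input model (multi-input models need `inputs` as a mapping keyed by member id); the array shape is hypothetical:

import numpy as np
from bioimageio.core.digest_spec import create_sample_for_model

arr = np.random.rand(1, 1, 256, 256).astype("float32")  # hypothetical b, c, y, x input
sample = create_sample_for_model(model, inputs=arr, sample_id="my-sample")
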
def load_sample_for_model( *, model: AnyModelDescr, paths: PerMember[Path], axes: Optional[PerMember[Sequence[AxisLike]]] = None, stat: Optional[Stat] = None, sample_id: Optional[SampleId] = None):

load a single sample from `paths` that can be processed by `model`
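
A sketch loading a sample from disk; the member id, file name, and axes hint below are hypothetical:

from pathlib import Path
from bioimageio.core.digest_spec import load_sample_for_model

sample = load_sample_for_model(
    model=model,
    paths={"raw": Path("raw.tif")},
    axes={"raw": ["b", "c", "y", "x"]},  # optional hint; omitted members use the spec's axes
    sample_id="sample-0",
)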