bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from itertools import chain
  10from math import ceil
  11from pathlib import Path, PurePosixPath
  12from tempfile import mkdtemp
  13from typing import (
  14    TYPE_CHECKING,
  15    Any,
  16    Callable,
  17    ClassVar,
  18    Dict,
  19    Generic,
  20    List,
  21    Literal,
  22    Mapping,
  23    NamedTuple,
  24    Optional,
  25    Sequence,
  26    Set,
  27    Tuple,
  28    Type,
  29    TypeVar,
  30    Union,
  31    cast,
  32)
  33
  34import numpy as np
  35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  36from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  37from loguru import logger
  38from numpy.typing import NDArray
  39from pydantic import (
  40    AfterValidator,
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    SerializationInfo,
  45    SerializerFunctionWrapHandler,
  46    StrictInt,
  47    Tag,
  48    ValidationInfo,
  49    WrapSerializer,
  50    field_validator,
  51    model_serializer,
  52    model_validator,
  53)
  54from typing_extensions import Annotated, Self, assert_never, get_args
  55
  56from .._internal.common_nodes import (
  57    InvalidDescr,
  58    Node,
  59    NodeWithExplicitlySetFields,
  60)
  61from .._internal.constants import DTYPE_LIMITS
  62from .._internal.field_warning import issue_warning, warn
  63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  64from .._internal.io import FileDescr as FileDescr
  65from .._internal.io import (
  66    FileSource,
  67    WithSuffix,
  68    YamlValue,
  69    extract_file_name,
  70    get_reader,
  71    wo_special_file_name,
  72)
  73from .._internal.io_basics import Sha256 as Sha256
  74from .._internal.io_packaging import (
  75    FileDescr_,
  76    FileSource_,
  77    package_file_descr_serializer,
  78)
  79from .._internal.io_utils import load_array
  80from .._internal.node_converter import Converter
  81from .._internal.type_guards import is_dict, is_sequence
  82from .._internal.types import (
  83    FAIR,
  84    AbsoluteTolerance,
  85    LowerCaseIdentifier,
  86    LowerCaseIdentifierAnno,
  87    MismatchedElementsPerMillion,
  88    RelativeTolerance,
  89)
  90from .._internal.types import Datetime as Datetime
  91from .._internal.types import Identifier as Identifier
  92from .._internal.types import NotEmpty as NotEmpty
  93from .._internal.types import SiUnit as SiUnit
  94from .._internal.url import HttpUrl as HttpUrl
  95from .._internal.validation_context import get_validation_context
  96from .._internal.validator_annotations import RestrictCharacters
  97from .._internal.version_type import Version as Version
  98from .._internal.warning_levels import INFO
  99from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
 100from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
 101from ..dataset.v0_3 import DatasetDescr as DatasetDescr
 102from ..dataset.v0_3 import DatasetId as DatasetId
 103from ..dataset.v0_3 import LinkedDataset as LinkedDataset
 104from ..dataset.v0_3 import Uploader as Uploader
 105from ..generic.v0_3 import (
 106    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
 107)
 108from ..generic.v0_3 import Author as Author
 109from ..generic.v0_3 import BadgeDescr as BadgeDescr
 110from ..generic.v0_3 import CiteEntry as CiteEntry
 111from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
 112from ..generic.v0_3 import Doi as Doi
 113from ..generic.v0_3 import (
 114    FileSource_documentation,
 115    GenericModelDescrBase,
 116    LinkedResourceBase,
 117    _author_conv,  # pyright: ignore[reportPrivateUsage]
 118    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
 119)
 120from ..generic.v0_3 import LicenseId as LicenseId
 121from ..generic.v0_3 import LinkedResource as LinkedResource
 122from ..generic.v0_3 import Maintainer as Maintainer
 123from ..generic.v0_3 import OrcidId as OrcidId
 124from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 125from ..generic.v0_3 import ResourceId as ResourceId
 126from .v0_4 import Author as _Author_v0_4
 127from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 128from .v0_4 import CallableFromDepencency as CallableFromDepencency
 129from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 130from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 131from .v0_4 import ClipDescr as _ClipDescr_v0_4
 132from .v0_4 import ClipKwargs as ClipKwargs
 133from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 134from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 135from .v0_4 import KnownRunMode as KnownRunMode
 136from .v0_4 import ModelDescr as _ModelDescr_v0_4
 137from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 138from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 139from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 140from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 141from .v0_4 import ProcessingKwargs as ProcessingKwargs
 142from .v0_4 import RunMode as RunMode
 143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 147from .v0_4 import TensorName as _TensorName_v0_4
 148from .v0_4 import WeightsFormat as WeightsFormat
 149from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 150from .v0_4 import package_weights
 151
 152SpaceUnit = Literal[
 153    "attometer",
 154    "angstrom",
 155    "centimeter",
 156    "decimeter",
 157    "exameter",
 158    "femtometer",
 159    "foot",
 160    "gigameter",
 161    "hectometer",
 162    "inch",
 163    "kilometer",
 164    "megameter",
 165    "meter",
 166    "micrometer",
 167    "mile",
 168    "millimeter",
 169    "nanometer",
 170    "parsec",
 171    "petameter",
 172    "picometer",
 173    "terameter",
 174    "yard",
 175    "yoctometer",
 176    "yottameter",
 177    "zeptometer",
 178    "zettameter",
 179]
 180"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 181
 182TimeUnit = Literal[
 183    "attosecond",
 184    "centisecond",
 185    "day",
 186    "decisecond",
 187    "exasecond",
 188    "femtosecond",
 189    "gigasecond",
 190    "hectosecond",
 191    "hour",
 192    "kilosecond",
 193    "megasecond",
 194    "microsecond",
 195    "millisecond",
 196    "minute",
 197    "nanosecond",
 198    "petasecond",
 199    "picosecond",
 200    "second",
 201    "terasecond",
 202    "yoctosecond",
 203    "yottasecond",
 204    "zeptosecond",
 205    "zettasecond",
 206]
 207"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 208
 209AxisType = Literal["batch", "channel", "index", "time", "space"]
 210
 211_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
 212    "b": "batch",
 213    "t": "time",
 214    "i": "index",
 215    "c": "channel",
 216    "x": "space",
 217    "y": "space",
 218    "z": "space",
 219}
 220
 221_AXIS_ID_MAP = {
 222    "b": "batch",
 223    "t": "time",
 224    "i": "index",
 225    "c": "channel",
 226}
 227
 228
 229class TensorId(LowerCaseIdentifier):
 230    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 231        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 232    ]
 233
 234
 235def _normalize_axis_id(a: str):
 236    a = str(a)
 237    normalized = _AXIS_ID_MAP.get(a, a)
 238    if a != normalized:
 239        logger.opt(depth=3).warning(
 240            "Normalized axis id from '{}' to '{}'.", a, normalized
 241        )
 242    return normalized
 243
 244
 245class AxisId(LowerCaseIdentifier):
 246    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 247        Annotated[
 248            LowerCaseIdentifierAnno,
 249            MaxLen(16),
 250            AfterValidator(_normalize_axis_id),
 251        ]
 252    ]
 253
 254
 255def _is_batch(a: str) -> bool:
 256    return str(a) == "batch"
 257
 258
 259def _is_not_batch(a: str) -> bool:
 260    return not _is_batch(a)
 261
 262
 263NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
 264
 265PreprocessingId = Literal[
 266    "binarize",
 267    "clip",
 268    "ensure_dtype",
 269    "fixed_zero_mean_unit_variance",
 270    "scale_linear",
 271    "scale_range",
 272    "sigmoid",
 273    "softmax",
 274]
 275PostprocessingId = Literal[
 276    "binarize",
 277    "clip",
 278    "ensure_dtype",
 279    "fixed_zero_mean_unit_variance",
 280    "scale_linear",
 281    "scale_mean_variance",
 282    "scale_range",
 283    "sigmoid",
 284    "softmax",
 285    "zero_mean_unit_variance",
 286]
 287
 288
 289SAME_AS_TYPE = "<same as type>"
 290
 291
 292ParameterizedSize_N = int
 293"""
 294An integer `n` used to calculate a concrete axis size from a `ParameterizedSize` as `size = min + n*step`.
 295"""
 296
 297
 298class ParameterizedSize(Node):
 299    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
 300
 301    - **min** and **step** are given by the model description.
 302    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
 303    - A greater blocksize parameter n results in a greater **size**.
 304      This allows adjusting the axis size more generically.
 305    """
 306
 307    N: ClassVar[Type[int]] = ParameterizedSize_N
 308    """Positive integer to parameterize this axis"""
 309
 310    min: Annotated[int, Gt(0)]
 311    step: Annotated[int, Gt(0)]
 312
 313    def validate_size(self, size: int) -> int:
 314        if size < self.min:
 315            raise ValueError(f"size {size} < {self.min}")
 316        if (size - self.min) % self.step != 0:
 317            raise ValueError(
 318                f"axis of size {size} is not parameterized by `min + n*step` ="
 319                + f" `{self.min} + n*{self.step}`"
 320            )
 321
 322        return size
 323
 324    def get_size(self, n: ParameterizedSize_N) -> int:
 325        return self.min + self.step * n
 326
 327    def get_n(self, s: int) -> ParameterizedSize_N:
 328        """return smallest n parameterizing a size greater or equal than `s`"""
 329        return ceil((s - self.min) / self.step)
 330
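# Illustrative usage sketch (the concrete values are arbitrary examples):
# a parameterized axis with min=32 and step=16 accepts the sizes 32, 48, 64, ...
# >>> p = ParameterizedSize(min=32, step=16)
# >>> p.get_size(2)
# 64
# >>> p.get_n(50)  # smallest n such that min + n*step >= 50
# 2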
 331
 332class DataDependentSize(Node):
 333    min: Annotated[int, Gt(0)] = 1
 334    max: Annotated[Optional[int], Gt(1)] = None
 335
 336    @model_validator(mode="after")
 337    def _validate_max_gt_min(self):
 338        if self.max is not None and self.min >= self.max:
 339            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 340
 341        return self
 342
 343    def validate_size(self, size: int) -> int:
 344        if size < self.min:
 345            raise ValueError(f"size {size} < {self.min}")
 346
 347        if self.max is not None and size > self.max:
 348            raise ValueError(f"size {size} > {self.max}")
 349
 350        return size
 351
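# Illustrative usage sketch (values are arbitrary examples): an output axis whose
# extent is only known after inference, here bounded to at most 1024 elements.
# >>> DataDependentSize(max=1024).validate_size(512)
# 512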
 352
 353class SizeReference(Node):
 354    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 355
 356    `axis.size = reference.size * reference.scale / axis.scale + offset`
 357
 358    Note:
 359    1. The axis and the referenced axis need to have the same unit (or no unit).
 360    2. Batch axes may not be referenced.
 361    3. Fractions are rounded down.
 362    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 363        `concatenable` as well with the same block order.
 364
 365    Example:
 366    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
 367    Let's assume that we want to express the image height h in relation to its width w
 368    instead of only accepting input images of exactly 100*49 pixels
 369    (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).
 370
 371    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 372    >>> h = SpaceInputAxis(
 373    ...     id=AxisId("h"),
 374    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 375    ...     unit="millimeter",
 376    ...     scale=4,
 377    ... )
 378    >>> print(h.size.get_size(h, w))
 379    49
 380
 381    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 382    """
 383
 384    tensor_id: TensorId
 385    """tensor id of the reference axis"""
 386
 387    axis_id: AxisId
 388    """axis id of the reference axis"""
 389
 390    offset: StrictInt = 0
 391
 392    def get_size(
 393        self,
 394        axis: Union[
 395            ChannelAxis,
 396            IndexInputAxis,
 397            IndexOutputAxis,
 398            TimeInputAxis,
 399            SpaceInputAxis,
 400            TimeOutputAxis,
 401            TimeOutputAxisWithHalo,
 402            SpaceOutputAxis,
 403            SpaceOutputAxisWithHalo,
 404        ],
 405        ref_axis: Union[
 406            ChannelAxis,
 407            IndexInputAxis,
 408            IndexOutputAxis,
 409            TimeInputAxis,
 410            SpaceInputAxis,
 411            TimeOutputAxis,
 412            TimeOutputAxisWithHalo,
 413            SpaceOutputAxis,
 414            SpaceOutputAxisWithHalo,
 415        ],
 416        n: ParameterizedSize_N = 0,
 417        ref_size: Optional[int] = None,
 418    ):
 419        """Compute the concrete size for a given axis and its reference axis.
 420
 421        Args:
 422            axis: The axis this `SizeReference` is the size of.
 423            ref_axis: The reference axis to compute the size from.
 424            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 425                and no fixed **ref_size** is given,
 426                **n** is used to compute the size of the parameterized **ref_axis**.
 427            ref_size: Overwrite the reference size instead of deriving it from
 428                **ref_axis**
 429                (**ref_axis.scale** is still used; any given **n** is ignored).
 430        """
 431        assert axis.size == self, (
 432            "Given `axis.size` is not defined by this `SizeReference`"
 433        )
 434
 435        assert ref_axis.id == self.axis_id, (
 436            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 437        )
 438
 439        assert axis.unit == ref_axis.unit, (
 440            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 441            f" but {axis.unit}!={ref_axis.unit}"
 442        )
 443        if ref_size is None:
 444            if isinstance(ref_axis.size, (int, float)):
 445                ref_size = ref_axis.size
 446            elif isinstance(ref_axis.size, ParameterizedSize):
 447                ref_size = ref_axis.size.get_size(n)
 448            elif isinstance(ref_axis.size, DataDependentSize):
 449                raise ValueError(
 450                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 451                )
 452            elif isinstance(ref_axis.size, SizeReference):
 453                raise ValueError(
 454                    "Reference axis referenced in `SizeReference` may not be sized by a"
 455                    + " `SizeReference` itself."
 456                )
 457            else:
 458                assert_never(ref_axis.size)
 459
 460        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 461
 462    @staticmethod
 463    def _get_unit(
 464        axis: Union[
 465            ChannelAxis,
 466            IndexInputAxis,
 467            IndexOutputAxis,
 468            TimeInputAxis,
 469            SpaceInputAxis,
 470            TimeOutputAxis,
 471            TimeOutputAxisWithHalo,
 472            SpaceOutputAxis,
 473            SpaceOutputAxisWithHalo,
 474        ],
 475    ):
 476        return axis.unit
 477
 478
 479class AxisBase(NodeWithExplicitlySetFields):
 480    id: AxisId
 481    """An axis id unique across all axes of one tensor."""
 482
 483    description: Annotated[str, MaxLen(128)] = ""
 484    """A short description of this axis beyond its type and id."""
 485
 486
 487class WithHalo(Node):
 488    halo: Annotated[int, Ge(1)]
 489    """The halo should be cropped from the output tensor to avoid boundary effects.
 490    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 491    To document a halo that is already cropped by the model use `size.offset` instead."""
 492
 493    size: Annotated[
 494        SizeReference,
 495        Field(
 496            examples=[
 497                10,
 498                SizeReference(
 499                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 500                ).model_dump(mode="json"),
 501            ]
 502        ),
 503    ]
 504    """reference to another axis with an optional offset (see `SizeReference`)"""
 505
 506
 507BATCH_AXIS_ID = AxisId("batch")
 508
 509
 510class BatchAxis(AxisBase):
 511    implemented_type: ClassVar[Literal["batch"]] = "batch"
 512    if TYPE_CHECKING:
 513        type: Literal["batch"] = "batch"
 514    else:
 515        type: Literal["batch"]
 516
 517    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 518    size: Optional[Literal[1]] = None
 519    """The batch size may be fixed to 1,
 520    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
 521
 522    @property
 523    def scale(self):
 524        return 1.0
 525
 526    @property
 527    def concatenable(self):
 528        return True
 529
 530    @property
 531    def unit(self):
 532        return None
 533
 534
 535class ChannelAxis(AxisBase):
 536    implemented_type: ClassVar[Literal["channel"]] = "channel"
 537    if TYPE_CHECKING:
 538        type: Literal["channel"] = "channel"
 539    else:
 540        type: Literal["channel"]
 541
 542    id: NonBatchAxisId = AxisId("channel")
 543
 544    channel_names: NotEmpty[List[Identifier]]
 545
 546    @property
 547    def size(self) -> int:
 548        return len(self.channel_names)
 549
 550    @property
 551    def concatenable(self):
 552        return False
 553
 554    @property
 555    def scale(self) -> float:
 556        return 1.0
 557
 558    @property
 559    def unit(self):
 560        return None
 561
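# Illustrative usage sketch (channel names are arbitrary; assumes plain strings
# are coerced to `Identifier` during validation): the axis size equals the
# number of named channels.
# >>> rgb = ChannelAxis(channel_names=["red", "green", "blue"])
# >>> rgb.size
# 3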
 562
 563class IndexAxisBase(AxisBase):
 564    implemented_type: ClassVar[Literal["index"]] = "index"
 565    if TYPE_CHECKING:
 566        type: Literal["index"] = "index"
 567    else:
 568        type: Literal["index"]
 569
 570    id: NonBatchAxisId = AxisId("index")
 571
 572    @property
 573    def scale(self) -> float:
 574        return 1.0
 575
 576    @property
 577    def unit(self):
 578        return None
 579
 580
 581class _WithInputAxisSize(Node):
 582    size: Annotated[
 583        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 584        Field(
 585            examples=[
 586                10,
 587                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 588                SizeReference(
 589                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 590                ).model_dump(mode="json"),
 591            ]
 592        ),
 593    ]
 594    """The size/length of this axis can be specified as
 595    - fixed integer
 596    - parameterized series of valid sizes (`ParameterizedSize`)
 597    - reference to another axis with an optional offset (`SizeReference`)
 598    """
 599
 600
 601class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 602    concatenable: bool = False
 603    """If a model has a `concatenable` input axis, it can be processed blockwise,
 604    splitting a longer sample axis into blocks matching its input tensor description.
 605    Output axes are concatenable if they have a `SizeReference` to a concatenable
 606    input axis.
 607    """
 608
 609
 610class IndexOutputAxis(IndexAxisBase):
 611    size: Annotated[
 612        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 613        Field(
 614            examples=[
 615                10,
 616                SizeReference(
 617                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 618                ).model_dump(mode="json"),
 619            ]
 620        ),
 621    ]
 622    """The size/length of this axis can be specified as
 623    - fixed integer
 624    - reference to another axis with an optional offset (`SizeReference`)
 625    - data dependent size using `DataDependentSize` (size is only known after model inference)
 626    """
 627
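# Illustrative usage sketch (axis id and bound are arbitrary examples): an output
# axis for a variable number of detected objects.
# >>> detections = IndexOutputAxis(
# ...     id=AxisId("object"), size=DataDependentSize(max=1000)
# ... )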
 628
 629class TimeAxisBase(AxisBase):
 630    implemented_type: ClassVar[Literal["time"]] = "time"
 631    if TYPE_CHECKING:
 632        type: Literal["time"] = "time"
 633    else:
 634        type: Literal["time"]
 635
 636    id: NonBatchAxisId = AxisId("time")
 637    unit: Optional[TimeUnit] = None
 638    scale: Annotated[float, Gt(0)] = 1.0
 639
 640
 641class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 642    concatenable: bool = False
 643    """If a model has a `concatenable` input axis, it can be processed blockwise,
 644    splitting a longer sample axis into blocks matching its input tensor description.
 645    Output axes are concatenable if they have a `SizeReference` to a concatenable
 646    input axis.
 647    """
 648
 649
 650class SpaceAxisBase(AxisBase):
 651    implemented_type: ClassVar[Literal["space"]] = "space"
 652    if TYPE_CHECKING:
 653        type: Literal["space"] = "space"
 654    else:
 655        type: Literal["space"]
 656
 657    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 658    unit: Optional[SpaceUnit] = None
 659    scale: Annotated[float, Gt(0)] = 1.0
 660
 661
 662class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 663    concatenable: bool = False
 664    """If a model has a `concatenable` input axis, it can be processed blockwise,
 665    splitting a longer sample axis into blocks matching its input tensor description.
 666    Output axes are concatenable if they have a `SizeReference` to a concatenable
 667    input axis.
 668    """
 669
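# Illustrative usage sketch (values are arbitrary examples): a spatial input axis
# accepting the sizes 64, 96, 128, ...
# >>> x_in = SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=32))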
 670
 671INPUT_AXIS_TYPES = (
 672    BatchAxis,
 673    ChannelAxis,
 674    IndexInputAxis,
 675    TimeInputAxis,
 676    SpaceInputAxis,
 677)
 678"""intended for isinstance comparisons in py<3.10"""
 679
 680_InputAxisUnion = Union[
 681    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 682]
 683InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 684
 685
 686class _WithOutputAxisSize(Node):
 687    size: Annotated[
 688        Union[Annotated[int, Gt(0)], SizeReference],
 689        Field(
 690            examples=[
 691                10,
 692                SizeReference(
 693                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 694                ).model_dump(mode="json"),
 695            ]
 696        ),
 697    ]
 698    """The size/length of this axis can be specified as
 699    - fixed integer
 700    - reference to another axis with an optional offset (see `SizeReference`)
 701    """
 702
 703
 704class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 705    pass
 706
 707
 708class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 709    pass
 710
 711
 712def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 713    if isinstance(v, dict):
 714        return "with_halo" if "halo" in v else "wo_halo"
 715    else:
 716        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 717
 718
 719_TimeOutputAxisUnion = Annotated[
 720    Union[
 721        Annotated[TimeOutputAxis, Tag("wo_halo")],
 722        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 723    ],
 724    Discriminator(_get_halo_axis_discriminator_value),
 725]
 726
 727
 728class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 729    pass
 730
 731
 732class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 733    pass
 734
 735
 736_SpaceOutputAxisUnion = Annotated[
 737    Union[
 738        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 739        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 740    ],
 741    Discriminator(_get_halo_axis_discriminator_value),
 742]
 743
 744
 745_OutputAxisUnion = Union[
 746    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 747]
 748OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 749
 750OUTPUT_AXIS_TYPES = (
 751    BatchAxis,
 752    ChannelAxis,
 753    IndexOutputAxis,
 754    TimeOutputAxis,
 755    TimeOutputAxisWithHalo,
 756    SpaceOutputAxis,
 757    SpaceOutputAxisWithHalo,
 758)
 759"""intended for isinstance comparisons in py<3.10"""
 760
 761
 762AnyAxis = Union[InputAxis, OutputAxis]
 763
 764ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
 765"""intended for isinstance comparisons in py<3.10"""
 766
 767TVs = Union[
 768    NotEmpty[List[int]],
 769    NotEmpty[List[float]],
 770    NotEmpty[List[bool]],
 771    NotEmpty[List[str]],
 772]
 773
 774
 775NominalOrOrdinalDType = Literal[
 776    "float32",
 777    "float64",
 778    "uint8",
 779    "int8",
 780    "uint16",
 781    "int16",
 782    "uint32",
 783    "int32",
 784    "uint64",
 785    "int64",
 786    "bool",
 787]
 788
 789
 790class NominalOrOrdinalDataDescr(Node):
 791    values: TVs
 792    """A fixed set of nominal or an ascending sequence of ordinal values.
 793    String `values` are interpreted as labels for tensor values 0, ..., N;
 794    in this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
 795    Note: as YAML 1.2 does not natively support a "set" datatype,
 796    nominal values should be given as a sequence (aka list/array) as well.
 797    """
 798
 799    type: Annotated[
 800        NominalOrOrdinalDType,
 801        Field(
 802            examples=[
 803                "float32",
 804                "uint8",
 805                "uint16",
 806                "int64",
 807                "bool",
 808            ],
 809        ),
 810    ] = "uint8"
 811
 812    @model_validator(mode="after")
 813    def _validate_values_match_type(
 814        self,
 815    ) -> Self:
 816        incompatible: List[Any] = []
 817        for v in self.values:
 818            if self.type == "bool":
 819                if not isinstance(v, bool):
 820                    incompatible.append(v)
 821            elif self.type in DTYPE_LIMITS:
 822                if (
 823                    isinstance(v, (int, float))
 824                    and (
 825                        v < DTYPE_LIMITS[self.type].min
 826                        or v > DTYPE_LIMITS[self.type].max
 827                    )
 828                    or (isinstance(v, str) and "uint" not in self.type)
 829                    or (isinstance(v, float) and "int" in self.type)
 830                ):
 831                    incompatible.append(v)
 832            else:
 833                incompatible.append(v)
 834
 835            if len(incompatible) == 5:
 836                incompatible.append("...")
 837                break
 838
 839        if incompatible:
 840            raise ValueError(
 841                f"data type '{self.type}' incompatible with values {incompatible}"
 842            )
 843
 844        return self
 845
 846    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 847
 848    @property
 849    def range(self):
 850        if isinstance(self.values[0], str):
 851            return 0, len(self.values) - 1
 852        else:
 853            return min(self.values), max(self.values)
 854
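# Illustrative usage sketch (labels are arbitrary examples): string `values` act
# as labels for tensor values 0, ..., N and imply an integer `range`.
# >>> labels = NominalOrOrdinalDataDescr(
# ...     values=["background", "cell", "membrane"], type="uint8"
# ... )
# >>> labels.range
# (0, 2)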
 855
 856IntervalOrRatioDType = Literal[
 857    "float32",
 858    "float64",
 859    "uint8",
 860    "int8",
 861    "uint16",
 862    "int16",
 863    "uint32",
 864    "int32",
 865    "uint64",
 866    "int64",
 867]
 868
 869
 870class IntervalOrRatioDataDescr(Node):
 871    type: Annotated[  # TODO: rename to dtype
 872        IntervalOrRatioDType,
 873        Field(
 874            examples=["float32", "float64", "uint8", "uint16"],
 875        ),
 876    ] = "float32"
 877    range: Tuple[Optional[float], Optional[float]] = (
 878        None,
 879        None,
 880    )
 881    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 882    `None` corresponds to min/max of what can be expressed by **type**."""
 883    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 884    scale: float = 1.0
 885    """Scale for data on an interval (or ratio) scale."""
 886    offset: Optional[float] = None
 887    """Offset for data on a ratio scale."""
 888
 889    @model_validator(mode="before")
 890    def _replace_inf(cls, data: Any):
 891        if is_dict(data):
 892            if "range" in data and is_sequence(data["range"]):
 893                forbidden = (
 894                    "inf",
 895                    "-inf",
 896                    ".inf",
 897                    "-.inf",
 898                    float("inf"),
 899                    float("-inf"),
 900                )
 901                if any(v in forbidden for v in data["range"]):
 902                    issue_warning("replaced 'inf' value", value=data["range"])
 903
 904                data["range"] = tuple(
 905                    (None if v in forbidden else v) for v in data["range"]
 906                )
 907
 908        return data
 909
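# Illustrative usage sketch (values are arbitrary examples): float32 data
# restricted to the unit interval; open bounds are expressed with `None`.
# >>> IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0)).range
# (0.0, 1.0)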
 910
 911TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 912
 913
 914class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 915    """processing base class"""
 916
 917
 918class BinarizeKwargs(ProcessingKwargs):
 919    """key word arguments for `BinarizeDescr`"""
 920
 921    threshold: float
 922    """The fixed threshold"""
 923
 924
 925class BinarizeAlongAxisKwargs(ProcessingKwargs):
 926    """key word arguments for `BinarizeDescr`"""
 927
 928    threshold: NotEmpty[List[float]]
 929    """The fixed threshold values along `axis`"""
 930
 931    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 932    """The `threshold` axis"""
 933
 934
 935class BinarizeDescr(ProcessingDescrBase):
 936    """Binarize the tensor with a fixed threshold.
 937
 938    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 939    will be set to one, values below the threshold to zero.
 940
 941    Examples:
 942    - in YAML
 943        ```yaml
 944        postprocessing:
 945          - id: binarize
 946            kwargs:
 947              axis: 'channel'
 948              threshold: [0.25, 0.5, 0.75]
 949        ```
 950    - in Python:
 951        >>> postprocessing = [BinarizeDescr(
 952        ...   kwargs=BinarizeAlongAxisKwargs(
 953        ...       axis=AxisId('channel'),
 954        ...       threshold=[0.25, 0.5, 0.75],
 955        ...   )
 956        ... )]
 957    """
 958
 959    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
 960    if TYPE_CHECKING:
 961        id: Literal["binarize"] = "binarize"
 962    else:
 963        id: Literal["binarize"]
 964    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 965
 966
 967class ClipDescr(ProcessingDescrBase):
 968    """Set tensor values below min to min and above max to max.
 969
 970    See `ScaleRangeDescr` for examples.
 971    """
 972
 973    implemented_id: ClassVar[Literal["clip"]] = "clip"
 974    if TYPE_CHECKING:
 975        id: Literal["clip"] = "clip"
 976    else:
 977        id: Literal["clip"]
 978
 979    kwargs: ClipKwargs
 980
 981
 982class EnsureDtypeKwargs(ProcessingKwargs):
 983    """key word arguments for `EnsureDtypeDescr`"""
 984
 985    dtype: Literal[
 986        "float32",
 987        "float64",
 988        "uint8",
 989        "int8",
 990        "uint16",
 991        "int16",
 992        "uint32",
 993        "int32",
 994        "uint64",
 995        "int64",
 996        "bool",
 997    ]
 998
 999
1000class EnsureDtypeDescr(ProcessingDescrBase):
1001    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1002
1003    This can for example be used to ensure the inner neural network model gets a
1004    different input tensor data type than the fully described bioimage.io model does.
1005
1006    Examples:
1007        The described bioimage.io model (incl. preprocessing) accepts any
1008        float32-compatible tensor, normalizes it with percentiles and clipping and then
1009        casts it to uint8, which is what the neural network in this example expects.
1010        - in YAML
1011            ```yaml
1012            inputs:
1013            - data:
1014                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1015              preprocessing:
1016              - id: scale_range
1017                kwargs:
1018                  axes: ['y', 'x']
1019                  max_percentile: 99.8
1020                  min_percentile: 5.0
1021              - id: clip
1022                kwargs:
1023                  min: 0.0
1024                  max: 1.0
1025              - id: ensure_dtype  # the neural network of the model requires uint8
1026                kwargs:
1027                  dtype: uint8
1028            ```
1029        - in Python:
1030            >>> preprocessing = [
1031            ...     ScaleRangeDescr(
1032            ...         kwargs=ScaleRangeKwargs(
1033            ...           axes= (AxisId('y'), AxisId('x')),
1034            ...           max_percentile= 99.8,
1035            ...           min_percentile= 5.0,
1036            ...         )
1037            ...     ),
1038            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1039            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1040            ... ]
1041    """
1042
1043    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1044    if TYPE_CHECKING:
1045        id: Literal["ensure_dtype"] = "ensure_dtype"
1046    else:
1047        id: Literal["ensure_dtype"]
1048
1049    kwargs: EnsureDtypeKwargs
1050
1051
1052class ScaleLinearKwargs(ProcessingKwargs):
1053    """Key word arguments for `ScaleLinearDescr`"""
1054
1055    gain: float = 1.0
1056    """multiplicative factor"""
1057
1058    offset: float = 0.0
1059    """additive term"""
1060
1061    @model_validator(mode="after")
1062    def _validate(self) -> Self:
1063        if self.gain == 1.0 and self.offset == 0.0:
1064            raise ValueError(
1065                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1066                + " != 0.0."
1067            )
1068
1069        return self
1070
1071
1072class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1073    """Key word arguments for `ScaleLinearDescr`"""
1074
1075    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1076    """The axis of gain and offset values."""
1077
1078    gain: Union[float, NotEmpty[List[float]]] = 1.0
1079    """multiplicative factor"""
1080
1081    offset: Union[float, NotEmpty[List[float]]] = 0.0
1082    """additive term"""
1083
1084    @model_validator(mode="after")
1085    def _validate(self) -> Self:
1086        if isinstance(self.gain, list):
1087            if isinstance(self.offset, list):
1088                if len(self.gain) != len(self.offset):
1089                    raise ValueError(
1090                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1091                    )
1092            else:
1093                self.offset = [float(self.offset)] * len(self.gain)
1094        elif isinstance(self.offset, list):
1095            self.gain = [float(self.gain)] * len(self.offset)
1096        else:
1097            raise ValueError(
1098                "Do not specify an `axis` for scalar gain and offset values."
1099            )
1100
1101        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1102            raise ValueError(
1103                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1104                + " != 0.0."
1105            )
1106
1107        return self
1108
1109
1110class ScaleLinearDescr(ProcessingDescrBase):
1111    """Fixed linear scaling.
1112
1113    Examples:
1114      1. Scale with scalar gain and offset
1115        - in YAML
1116        ```yaml
1117        preprocessing:
1118          - id: scale_linear
1119            kwargs:
1120              gain: 2.0
1121              offset: 3.0
1122        ```
1123        - in Python:
1124        >>> preprocessing = [
1125        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1126        ... ]
1127
1128      2. Independent scaling along an axis
1129        - in YAML
1130        ```yaml
1131        preprocessing:
1132          - id: scale_linear
1133            kwargs:
1134              axis: 'channel'
1135              gain: [1.0, 2.0, 3.0]
1136        ```
1137        - in Python:
1138        >>> preprocessing = [
1139        ...     ScaleLinearDescr(
1140        ...         kwargs=ScaleLinearAlongAxisKwargs(
1141        ...             axis=AxisId("channel"),
1142        ...             gain=[1.0, 2.0, 3.0],
1143        ...         )
1144        ...     )
1145        ... ]
1146
1147    """
1148
1149    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1150    if TYPE_CHECKING:
1151        id: Literal["scale_linear"] = "scale_linear"
1152    else:
1153        id: Literal["scale_linear"]
1154    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1155
1156
1157class SigmoidDescr(ProcessingDescrBase):
1158    """The logistic sigmoid function, a.k.a. expit function.
1159
1160    Examples:
1161    - in YAML
1162        ```yaml
1163        postprocessing:
1164          - id: sigmoid
1165        ```
1166    - in Python:
1167        >>> postprocessing = [SigmoidDescr()]
1168    """
1169
1170    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1171    if TYPE_CHECKING:
1172        id: Literal["sigmoid"] = "sigmoid"
1173    else:
1174        id: Literal["sigmoid"]
1175
1176    @property
1177    def kwargs(self) -> ProcessingKwargs:
1178        """empty kwargs"""
1179        return ProcessingKwargs()
1180
1181
1182class SoftmaxKwargs(ProcessingKwargs):
1183    """key word arguments for `SoftmaxDescr`"""
1184
1185    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1186    """The axis to apply the softmax function along.
1187    Note:
1188        Defaults to 'channel' axis
1189        (which may not exist, in which case
1190        a different axis id has to be specified).
1191    """
1192
1193
1194class SoftmaxDescr(ProcessingDescrBase):
1195    """The softmax function.
1196
1197    Examples:
1198    - in YAML
1199        ```yaml
1200        postprocessing:
1201          - id: softmax
1202            kwargs:
1203              axis: channel
1204        ```
1205    - in Python:
1206        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1207    """
1208
1209    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1210    if TYPE_CHECKING:
1211        id: Literal["softmax"] = "softmax"
1212    else:
1213        id: Literal["softmax"]
1214
1215    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
1216
1217
1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1219    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1220
1221    mean: float
1222    """The mean value to normalize with."""
1223
1224    std: Annotated[float, Ge(1e-6)]
1225    """The standard deviation value to normalize with."""
1226
1227
1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1229    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1230
1231    mean: NotEmpty[List[float]]
1232    """The mean value(s) to normalize with."""
1233
1234    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1235    """The standard deviation value(s) to normalize with.
1236    Size must match `mean` values."""
1237
1238    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1239    """The axis of the mean/std values to normalize each entry along that dimension
1240    separately."""
1241
1242    @model_validator(mode="after")
1243    def _mean_and_std_match(self) -> Self:
1244        if len(self.mean) != len(self.std):
1245            raise ValueError(
1246                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1247                + " must match."
1248            )
1249
1250        return self
1251
1252
1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1254    """Subtract a given mean and divide by the standard deviation.
1255
1256    Normalize with fixed, precomputed values for
1257    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
1258    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1259    axes.
1260
1261    Examples:
1262    1. scalar value for whole tensor
1263        - in YAML
1264        ```yaml
1265        preprocessing:
1266          - id: fixed_zero_mean_unit_variance
1267            kwargs:
1268              mean: 103.5
1269              std: 13.7
1270        ```
1271        - in Python
1272        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1273        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1274        ... )]
1275
1276    2. independently along an axis
1277        - in YAML
1278        ```yaml
1279        preprocessing:
1280          - id: fixed_zero_mean_unit_variance
1281            kwargs:
1282              axis: channel
1283              mean: [101.5, 102.5, 103.5]
1284              std: [11.7, 12.7, 13.7]
1285        ```
1286        - in Python
1287        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1288        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1289        ...     axis=AxisId("channel"),
1290        ...     mean=[101.5, 102.5, 103.5],
1291        ...     std=[11.7, 12.7, 13.7],
1292        ...   )
1293        ... )]
1294    """
1295
1296    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1297        "fixed_zero_mean_unit_variance"
1298    )
1299    if TYPE_CHECKING:
1300        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1301    else:
1302        id: Literal["fixed_zero_mean_unit_variance"]
1303
1304    kwargs: Union[
1305        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1306    ]
1307
1308
1309class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1310    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1311
1312    axes: Annotated[
1313        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1314    ] = None
1315    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1316    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1317    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1318    To normalize each sample independently leave out the 'batch' axis.
1319    Default: Scale all axes jointly."""
1320
1321    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1322    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1323
1324
1325class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1326    """Subtract mean and divide by variance.
1327
1328    Examples:
1329        Subtract the tensor mean and divide by the tensor standard deviation
1330        - in YAML
1331        ```yaml
1332        preprocessing:
1333          - id: zero_mean_unit_variance
1334        ```
1335        - in Python
1336        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1337    """
1338
1339    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1340        "zero_mean_unit_variance"
1341    )
1342    if TYPE_CHECKING:
1343        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1344    else:
1345        id: Literal["zero_mean_unit_variance"]
1346
1347    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1348        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1349    )
1350
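# Illustrative usage sketch (axis ids are examples): normalize each channel
# separately by reducing only over the batch and spatial axes.
# >>> per_channel = ZeroMeanUnitVarianceDescr(
# ...     kwargs=ZeroMeanUnitVarianceKwargs(
# ...         axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
# ...     )
# ... )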
1351
1352class ScaleRangeKwargs(ProcessingKwargs):
1353    """key word arguments for `ScaleRangeDescr`
1354
1355    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1356    this processing step normalizes data to the [0, 1] intervall.
1357    For other percentiles the normalized values will partially be outside the [0, 1]
1358    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1359    normalized values to a range.
1360    """
1361
1362    axes: Annotated[
1363        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1364    ] = None
1365    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1366    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1367    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1368    To normalize samples independently, leave out the "batch" axis.
1369    Default: Scale all axes jointly."""
1370
1371    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1372    """The lower percentile used to determine the value to align with zero."""
1373
1374    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1375    """The upper percentile used to determine the value to align with one.
1376    Has to be bigger than `min_percentile`.
1377    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1378    accepting percentiles specified in the range 0.0 to 1.0."""
1379
1380    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1381    """Epsilon for numeric stability.
1382    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1383    with `v_lower,v_upper` values at the respective percentiles."""
1384
1385    reference_tensor: Optional[TensorId] = None
1386    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1387    For any tensor in `inputs` only input tensor references are allowed."""
1388
1389    @field_validator("max_percentile", mode="after")
1390    @classmethod
1391    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1392        if (min_p := info.data["min_percentile"]) >= value:
1393            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1394
1395        return value
1396
1397
1398class ScaleRangeDescr(ProcessingDescrBase):
1399    """Scale with percentiles.
1400
1401    Examples:
1402    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1403        - in YAML
1404        ```yaml
1405        preprocessing:
1406          - id: scale_range
1407            kwargs:
1408              axes: ['y', 'x']
1409              max_percentile: 99.8
1410              min_percentile: 5.0
1411        ```
1412        - in Python
1413        >>> preprocessing = [
1414        ...     ScaleRangeDescr(
1415        ...         kwargs=ScaleRangeKwargs(
1416        ...           axes= (AxisId('y'), AxisId('x')),
1417        ...           max_percentile= 99.8,
1418        ...           min_percentile= 5.0,
1419        ...         )
1420        ...     ),
1421        ... ]
1428
1429    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1430        - in YAML
1431        ```yaml
1432        preprocessing:
1433          - id: scale_range
1434            kwargs:
1435              axes: ['y', 'x']
1436              max_percentile: 99.8
1437              min_percentile: 5.0
1439          - id: clip
1440            kwargs:
1441              min: 0.0
1442              max: 1.0
1443        ```
1444        - in Python
1445        >>> preprocessing = [
1446        ...     ScaleRangeDescr(kwargs=ScaleRangeKwargs(
1447        ...         axes=(AxisId('y'), AxisId('x')),
1448        ...         max_percentile=99.8, min_percentile=5.0,
1449        ...     )),
1450        ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1451        ... ]
1452
1453    """
1454
1455    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1456    if TYPE_CHECKING:
1457        id: Literal["scale_range"] = "scale_range"
1458    else:
1459        id: Literal["scale_range"]
1460    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
1461
1462
1463class ScaleMeanVarianceKwargs(ProcessingKwargs):
1464    """key word arguments for `ScaleMeanVarianceKwargs`"""
1465
1466    reference_tensor: TensorId
1467    """Name of tensor to match."""
1468
1469    axes: Annotated[
1470        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1471    ] = None
1472    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1473    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1474    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1475    To normalize samples independently, leave out the 'batch' axis.
1476    Default: Scale all axes jointly."""
1477
1478    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1479    """Epsilon for numeric stability:
1480    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1481
1482
1483class ScaleMeanVarianceDescr(ProcessingDescrBase):
1484    """Scale a tensor's data distribution to match another tensor's mean/std.
1485    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1486    """
1487
1488    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1489    if TYPE_CHECKING:
1490        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1491    else:
1492        id: Literal["scale_mean_variance"]
1493    kwargs: ScaleMeanVarianceKwargs
1494
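# Illustrative usage sketch (the tensor id "raw" is a hypothetical example):
# match an output tensor's distribution to that of an input tensor.
# >>> postprocessing = [ScaleMeanVarianceDescr(
# ...     kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("raw"))
# ... )]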
1495
1496PreprocessingDescr = Annotated[
1497    Union[
1498        BinarizeDescr,
1499        ClipDescr,
1500        EnsureDtypeDescr,
1501        FixedZeroMeanUnitVarianceDescr,
1502        ScaleLinearDescr,
1503        ScaleRangeDescr,
1504        SigmoidDescr,
1505        SoftmaxDescr,
1506        ZeroMeanUnitVarianceDescr,
1507    ],
1508    Discriminator("id"),
1509]
1510PostprocessingDescr = Annotated[
1511    Union[
1512        BinarizeDescr,
1513        ClipDescr,
1514        EnsureDtypeDescr,
1515        FixedZeroMeanUnitVarianceDescr,
1516        ScaleLinearDescr,
1517        ScaleMeanVarianceDescr,
1518        ScaleRangeDescr,
1519        SigmoidDescr,
1520        SoftmaxDescr,
1521        ZeroMeanUnitVarianceDescr,
1522    ],
1523    Discriminator("id"),
1524]
1525
1526IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1527
1528
1529class TensorDescrBase(Node, Generic[IO_AxisT]):
1530    id: TensorId
1531    """Tensor id. No duplicates are allowed."""
1532
1533    description: Annotated[str, MaxLen(128)] = ""
1534    """free text description"""
1535
1536    axes: NotEmpty[Sequence[IO_AxisT]]
1537    """tensor axes"""
1538
1539    @property
1540    def shape(self):
1541        return tuple(a.size for a in self.axes)
1542
1543    @field_validator("axes", mode="after", check_fields=False)
1544    @classmethod
1545    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1546        batch_axes = [a for a in axes if a.type == "batch"]
1547        if len(batch_axes) > 1:
1548            raise ValueError(
1549                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1550            )
1551
1552        seen_ids: Set[AxisId] = set()
1553        duplicate_axes_ids: Set[AxisId] = set()
1554        for a in axes:
1555            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1556
1557        if duplicate_axes_ids:
1558            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1559
1560        return axes
1561
1562    test_tensor: FAIR[Optional[FileDescr_]] = None
1563    """An example tensor to use for testing.
1564    Using the model with the test input tensors is expected to yield the test output tensors.
1565    Each test tensor has to be an ndarray in the
1566    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1567    The file extension must be '.npy'."""
1568
1569    sample_tensor: FAIR[Optional[FileDescr_]] = None
1570    """A sample tensor to illustrate a possible input/output for the model,
1571    The sample image primarily serves to inform a human user about an example use case
1572    and is typically stored as .hdf5, .png or .tiff.
1573    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1574    (numpy's `.npy` format is not supported).
1575    The image dimensionality has to match the number of axes specified in this tensor description.
1576    """
1577
1578    @model_validator(mode="after")
1579    def _validate_sample_tensor(self) -> Self:
1580        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1581            return self
1582
1583        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1584        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
1585            reader.read(),
1586            extension=PurePosixPath(reader.original_file_name).suffix,
1587        )
1588        n_dims = len(tensor.squeeze().shape)
1589        n_dims_min = n_dims_max = len(self.axes)
1590
1591        for a in self.axes:
1592            if isinstance(a, BatchAxis):
1593                n_dims_min -= 1
1594            elif isinstance(a.size, int):
1595                if a.size == 1:
1596                    n_dims_min -= 1
1597            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1598                if a.size.min == 1:
1599                    n_dims_min -= 1
1600            elif isinstance(a.size, SizeReference):
1601                if a.size.offset < 2:
1602                    # size reference may result in singleton axis
1603                    n_dims_min -= 1
1604            else:
1605                assert_never(a.size)
1606
1607        n_dims_min = max(0, n_dims_min)
1608        if n_dims < n_dims_min or n_dims > n_dims_max:
1609            raise ValueError(
1610                f"Expected sample tensor to have {n_dims_min} to"
1611                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1612            )
1613
1614        return self
1615
1616    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1617        IntervalOrRatioDataDescr()
1618    )
1619    """Description of the tensor's data values, optionally per channel.
1620    If specified per channel, the data `type` needs to match across channels."""
1621
1622    @property
1623    def dtype(
1624        self,
1625    ) -> Literal[
1626        "float32",
1627        "float64",
1628        "uint8",
1629        "int8",
1630        "uint16",
1631        "int16",
1632        "uint32",
1633        "int32",
1634        "uint64",
1635        "int64",
1636        "bool",
1637    ]:
1638        """dtype as specified under `data.type` or `data[i].type`"""
1639        if isinstance(self.data, collections.abc.Sequence):
1640            return self.data[0].type
1641        else:
1642            return self.data.type
1643
1644    @field_validator("data", mode="after")
1645    @classmethod
1646    def _check_data_type_across_channels(
1647        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1648    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1649        if not isinstance(value, list):
1650            return value
1651
1652        dtypes = {t.type for t in value}
1653        if len(dtypes) > 1:
1654            raise ValueError(
1655                "Tensor data descriptions per channel need to agree in their data"
1656                + f" `type`, but found {dtypes}."
1657            )
1658
1659        return value
1660
1661    @model_validator(mode="after")
1662    def _check_data_matches_channelaxis(self) -> Self:
1663        if not isinstance(self.data, (list, tuple)):
1664            return self
1665
1666        for a in self.axes:
1667            if isinstance(a, ChannelAxis):
1668                size = a.size
1669                assert isinstance(size, int)
1670                break
1671        else:
1672            return self
1673
1674        if len(self.data) != size:
1675            raise ValueError(
1676                f"Got tensor data descriptions for {len(self.data)} channels, but"
1677                + f" '{a.id}' axis has size {size}."
1678            )
1679
1680        return self
1681
1682    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1683        if len(array.shape) != len(self.axes):
1684            raise ValueError(
1685                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1686                + f" incompatible with {len(self.axes)} axes."
1687            )
1688        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1689
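# A minimal usage sketch for `get_axis_sizes_for_array` (the descriptor `descr`
# is hypothetical and assumed to be an already validated tensor description
# with axes batch, channel, y, x, in that order):
def _example_axis_sizes(descr: "TensorDescrBase[InputAxis]") -> "Dict[AxisId, int]":
    arr: "NDArray[Any]" = np.zeros((1, 3, 512, 512), dtype=np.float32)
    # maps each axis id to the matching array dimension, e.g.
    # {AxisId("batch"): 1, AxisId("channel"): 3, AxisId("y"): 512, AxisId("x"): 512}
    return descr.get_axis_sizes_for_array(arr)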
1690
1691class InputTensorDescr(TensorDescrBase[InputAxis]):
1692    id: TensorId = TensorId("input")
1693    """Input tensor id.
1694    No duplicates are allowed across all inputs and outputs."""
1695
1696    optional: bool = False
1697    """indicates that this tensor may be `None`"""
1698
1699    preprocessing: List[PreprocessingDescr] = Field(
1700        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1701    )
1702
1703    """Description of how this input should be preprocessed.
1704
1705    notes:
1706    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1707      to ensure an input tensor's data type matches the input tensor's data description.
1708    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1709      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1710      changing the data type.
1711    """
1712
1713    @model_validator(mode="after")
1714    def _validate_preprocessing_kwargs(self) -> Self:
1715        axes_ids = [a.id for a in self.axes]
1716        for p in self.preprocessing:
1717            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1718            if kwargs_axes is None:
1719                continue
1720
1721            if not isinstance(kwargs_axes, collections.abc.Sequence):
1722                raise ValueError(
1723                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1724                )
1725
1726            if any(a not in axes_ids for a in kwargs_axes):
1727                raise ValueError(
1728                    "`preprocessing.i.kwargs.axes` needs to be a subset of the axes ids"
1729                )
1730
1731        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1732            dtype = self.data.type
1733        else:
1734            dtype = self.data[0].type
1735
1736        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1737        if not self.preprocessing or not isinstance(
1738            self.preprocessing[0], EnsureDtypeDescr
1739        ):
1740            self.preprocessing.insert(
1741                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1742            )
1743
1744        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1745        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1746            self.preprocessing.append(
1747                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1748            )
1749
1750        return self
1751
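# A small sketch of the invariant established by `_validate_preprocessing_kwargs`
# above: after validation, `preprocessing` starts with an `ensure_dtype` step and
# ends with `ensure_dtype` or `binarize` (`ipt` is a hypothetical, already
# validated `InputTensorDescr`):
def _example_preprocessing_invariant(ipt: InputTensorDescr) -> None:
    assert isinstance(ipt.preprocessing[0], EnsureDtypeDescr)
    assert isinstance(ipt.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr))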
1752
1753def convert_axes(
1754    axes: str,
1755    *,
1756    shape: Union[
1757        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1758    ],
1759    tensor_type: Literal["input", "output"],
1760    halo: Optional[Sequence[int]],
1761    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1762):
1763    ret: List[AnyAxis] = []
1764    for i, a in enumerate(axes):
1765        axis_type = _AXIS_TYPE_MAP.get(a, a)
1766        if axis_type == "batch":
1767            ret.append(BatchAxis())
1768            continue
1769
1770        scale = 1.0
1771        if isinstance(shape, _ParameterizedInputShape_v0_4):
1772            if shape.step[i] == 0:
1773                size = shape.min[i]
1774            else:
1775                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1776        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1777            ref_t = str(shape.reference_tensor)
1778            if ref_t.count(".") == 1:
1779                t_id, orig_a_id = ref_t.split(".")
1780            else:
1781                t_id = ref_t
1782                orig_a_id = a
1783
1784            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1785            if not (orig_scale := shape.scale[i]):
1786                # old way to insert a new axis dimension
1787                size = int(2 * shape.offset[i])
1788            else:
1789                scale = 1 / orig_scale
1790                if axis_type in ("channel", "index"):
1791                    # these axes no longer have a scale
1792                    offset_from_scale = orig_scale * size_refs.get(
1793                        _TensorName_v0_4(t_id), {}
1794                    ).get(orig_a_id, 0)
1795                else:
1796                    offset_from_scale = 0
1797                size = SizeReference(
1798                    tensor_id=TensorId(t_id),
1799                    axis_id=AxisId(a_id),
1800                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1801                )
1802        else:
1803            size = shape[i]
1804
1805        if axis_type == "time":
1806            if tensor_type == "input":
1807                ret.append(TimeInputAxis(size=size, scale=scale))
1808            else:
1809                assert not isinstance(size, ParameterizedSize)
1810                if halo is None:
1811                    ret.append(TimeOutputAxis(size=size, scale=scale))
1812                else:
1813                    assert not isinstance(size, int)
1814                    ret.append(
1815                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1816                    )
1817
1818        elif axis_type == "index":
1819            if tensor_type == "input":
1820                ret.append(IndexInputAxis(size=size))
1821            else:
1822                if isinstance(size, ParameterizedSize):
1823                    size = DataDependentSize(min=size.min)
1824
1825                ret.append(IndexOutputAxis(size=size))
1826        elif axis_type == "channel":
1827            assert not isinstance(size, ParameterizedSize)
1828            if isinstance(size, SizeReference):
1829                warnings.warn(
1830                    "Conversion of channel size from an implicit output shape may be"
1831                    + " wrong"
1832                )
1833                ret.append(
1834                    ChannelAxis(
1835                        channel_names=[
1836                            Identifier(f"channel{i}") for i in range(size.offset)
1837                        ]
1838                    )
1839                )
1840            else:
1841                ret.append(
1842                    ChannelAxis(
1843                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1844                    )
1845                )
1846        elif axis_type == "space":
1847            if tensor_type == "input":
1848                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1849            else:
1850                assert not isinstance(size, ParameterizedSize)
1851                if halo is None or halo[i] == 0:
1852                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1853                elif isinstance(size, int):
1854                    raise NotImplementedError(
1855                        f"output axis with halo and fixed size (here {size}) not allowed"
1856                    )
1857                else:
1858                    ret.append(
1859                        SpaceOutputAxisWithHalo(
1860                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1861                        )
1862                    )
1863
1864    return ret
1865
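# A sketch of a plain v0.4 -> v0.5 axes conversion with a fixed shape: the axes
# string "bcyx" with shape (1, 3, 512, 512) yields a batch axis, a channel axis
# with three generated channel names and two spatial input axes of size 512
# (the argument values are illustrative only):
def _example_convert_axes() -> "List[AnyAxis]":
    return convert_axes(
        "bcyx",
        shape=[1, 3, 512, 512],
        tensor_type="input",
        halo=None,
        size_refs={},
    )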
1866
1867def _axes_letters_to_ids(
1868    axes: Optional[str],
1869) -> Optional[List[AxisId]]:
1870    if axes is None:
1871        return None
1872
1873    return [AxisId(a) for a in axes]
1874
1875
1876def _get_complement_v04_axis(
1877    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1878) -> Optional[AxisId]:
1879    if axes is None:
1880        return None
1881
1882    non_complement_axes = set(axes) | {"b"}
1883    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1884    if len(complement_axes) > 1:
1885        raise ValueError(
1886            f"Expected none or a single complement axis, but axes '{axes}' "
1887            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1888        )
1889
1890    return None if not complement_axes else AxisId(complement_axes[0])
1891
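# Example derived from the helper above: for tensor axes "bcyx" and kwargs axes
# ("y", "x"), the only remaining non-batch axis is "c", so it is returned as the
# complement axis:
def _example_complement_axis() -> "Optional[AxisId]":
    return _get_complement_v04_axis("bcyx", ["y", "x"])  # -> AxisId("c")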
1892
1893def _convert_proc(
1894    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1895    tensor_axes: Sequence[str],
1896) -> Union[PreprocessingDescr, PostprocessingDescr]:
1897    if isinstance(p, _BinarizeDescr_v0_4):
1898        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1899    elif isinstance(p, _ClipDescr_v0_4):
1900        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1901    elif isinstance(p, _SigmoidDescr_v0_4):
1902        return SigmoidDescr()
1903    elif isinstance(p, _ScaleLinearDescr_v0_4):
1904        axes = _axes_letters_to_ids(p.kwargs.axes)
1905        if p.kwargs.axes is None:
1906            axis = None
1907        else:
1908            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1909
1910        if axis is None:
1911            assert not isinstance(p.kwargs.gain, list)
1912            assert not isinstance(p.kwargs.offset, list)
1913            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1914        else:
1915            kwargs = ScaleLinearAlongAxisKwargs(
1916                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1917            )
1918        return ScaleLinearDescr(kwargs=kwargs)
1919    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1920        return ScaleMeanVarianceDescr(
1921            kwargs=ScaleMeanVarianceKwargs(
1922                axes=_axes_letters_to_ids(p.kwargs.axes),
1923                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1924                eps=p.kwargs.eps,
1925            )
1926        )
1927    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1928        if p.kwargs.mode == "fixed":
1929            mean = p.kwargs.mean
1930            std = p.kwargs.std
1931            assert mean is not None
1932            assert std is not None
1933
1934            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1935
1936            if axis is None:
1937                if isinstance(mean, list):
1938                    raise ValueError("Expected single float value for mean, not <list>")
1939                if isinstance(std, list):
1940                    raise ValueError("Expected single float value for std, not <list>")
1941                return FixedZeroMeanUnitVarianceDescr(
1942                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
1943                        mean=mean,
1944                        std=std,
1945                    )
1946                )
1947            else:
1948                if not isinstance(mean, list):
1949                    mean = [float(mean)]
1950                if not isinstance(std, list):
1951                    std = [float(std)]
1952
1953                return FixedZeroMeanUnitVarianceDescr(
1954                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1955                        axis=axis, mean=mean, std=std
1956                    )
1957                )
1958
1959        else:
1960            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1961            if p.kwargs.mode == "per_dataset":
1962                axes = [AxisId("batch")] + axes
1963            if not axes:
1964                axes = None
1965            return ZeroMeanUnitVarianceDescr(
1966                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1967            )
1968
1969    elif isinstance(p, _ScaleRangeDescr_v0_4):
1970        return ScaleRangeDescr(
1971            kwargs=ScaleRangeKwargs(
1972                axes=_axes_letters_to_ids(p.kwargs.axes),
1973                min_percentile=p.kwargs.min_percentile,
1974                max_percentile=p.kwargs.max_percentile,
1975                eps=p.kwargs.eps,
1976            )
1977        )
1978    else:
1979        assert_never(p)
1980
1981
1982class _InputTensorConv(
1983    Converter[
1984        _InputTensorDescr_v0_4,
1985        InputTensorDescr,
1986        FileSource_,
1987        Optional[FileSource_],
1988        Mapping[_TensorName_v0_4, Mapping[str, int]],
1989    ]
1990):
1991    def _convert(
1992        self,
1993        src: _InputTensorDescr_v0_4,
1994        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1995        test_tensor: FileSource_,
1996        sample_tensor: Optional[FileSource_],
1997        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1998    ) -> "InputTensorDescr | dict[str, Any]":
1999        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
2000            src.axes,
2001            shape=src.shape,
2002            tensor_type="input",
2003            halo=None,
2004            size_refs=size_refs,
2005        )
2006        prep: List[PreprocessingDescr] = []
2007        for p in src.preprocessing:
2008            cp = _convert_proc(p, src.axes)
2009            assert not isinstance(cp, ScaleMeanVarianceDescr)
2010            prep.append(cp)
2011
2012        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
2013
2014        return tgt(
2015            axes=axes,
2016            id=TensorId(str(src.name)),
2017            test_tensor=FileDescr(source=test_tensor),
2018            sample_tensor=(
2019                None if sample_tensor is None else FileDescr(source=sample_tensor)
2020            ),
2021            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
2022            preprocessing=prep,
2023        )
2024
2025
2026_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
2027
2028
2029class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2030    id: TensorId = TensorId("output")
2031    """Output tensor id.
2032    No duplicates are allowed across all inputs and outputs."""
2033
2034    postprocessing: List[PostprocessingDescr] = Field(
2035        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2036    )
2037    """Description of how this output should be postprocessed.
2038
2039    note: `postprocessing` always ends with an 'ensure_dtype' operation.
2040          If not given explicitly, such a step casting to this tensor's `data.type` is appended.
2041    """
2042
2043    @model_validator(mode="after")
2044    def _validate_postprocessing_kwargs(self) -> Self:
2045        axes_ids = [a.id for a in self.axes]
2046        for p in self.postprocessing:
2047            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2048            if kwargs_axes is None:
2049                continue
2050
2051            if not isinstance(kwargs_axes, collections.abc.Sequence):
2052                raise ValueError(
2053                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2054                )
2055
2056            if any(a not in axes_ids for a in kwargs_axes):
2057                raise ValueError("`kwargs.axes` needs to be a subset of the axes ids")
2058
2059        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2060            dtype = self.data.type
2061        else:
2062            dtype = self.data[0].type
2063
2064        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2065        if not self.postprocessing or not isinstance(
2066            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2067        ):
2068            self.postprocessing.append(
2069                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2070            )
2071        return self
2072
2073
2074class _OutputTensorConv(
2075    Converter[
2076        _OutputTensorDescr_v0_4,
2077        OutputTensorDescr,
2078        FileSource_,
2079        Optional[FileSource_],
2080        Mapping[_TensorName_v0_4, Mapping[str, int]],
2081    ]
2082):
2083    def _convert(
2084        self,
2085        src: _OutputTensorDescr_v0_4,
2086        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2087        test_tensor: FileSource_,
2088        sample_tensor: Optional[FileSource_],
2089        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2090    ) -> "OutputTensorDescr | dict[str, Any]":
2091        # TODO: split convert_axes into convert_output_axes and convert_input_axes
2092        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
2093            src.axes,
2094            shape=src.shape,
2095            tensor_type="output",
2096            halo=src.halo,
2097            size_refs=size_refs,
2098        )
2099        data_descr: Dict[str, Any] = dict(type=src.data_type)
2100        if data_descr["type"] == "bool":
2101            data_descr["values"] = [False, True]
2102
2103        return tgt(
2104            axes=axes,
2105            id=TensorId(str(src.name)),
2106            test_tensor=FileDescr(source=test_tensor),
2107            sample_tensor=(
2108                None if sample_tensor is None else FileDescr(source=sample_tensor)
2109            ),
2110            data=data_descr,  # pyright: ignore[reportArgumentType]
2111            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2112        )
2113
2114
2115_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2116
2117
2118TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2119
2120
2121def validate_tensors(
2122    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2123    tensor_origin: Literal[
2124        "test_tensor"
2125    ],  # for more precise error messages, e.g. 'test_tensor'
2126):
2127    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2128
2129    def e_msg(d: TensorDescr):
2130        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2131
2132    for descr, array in tensors.values():
2133        if array is None:
2134            axis_sizes = {a.id: None for a in descr.axes}
2135        else:
2136            try:
2137                axis_sizes = descr.get_axis_sizes_for_array(array)
2138            except ValueError as e:
2139                raise ValueError(f"{e_msg(descr)} {e}")
2140
2141        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2142
2143    for descr, array in tensors.values():
2144        if array is None:
2145            continue
2146
2147        if descr.dtype in ("float32", "float64"):
2148            invalid_test_tensor_dtype = array.dtype.name not in (
2149                "float32",
2150                "float64",
2151                "uint8",
2152                "int8",
2153                "uint16",
2154                "int16",
2155                "uint32",
2156                "int32",
2157                "uint64",
2158                "int64",
2159            )
2160        else:
2161            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2162
2163        if invalid_test_tensor_dtype:
2164            raise ValueError(
2165                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2166                + f" match described dtype '{descr.dtype}'"
2167            )
2168
2169        if array.min() > -1e-4 and array.max() < 1e-4:
2170            raise ValueError(
2171                "Tensor values are too small for reliable testing."
2172                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2173            )
2174
2175        for a in descr.axes:
2176            actual_size = all_tensor_axes[descr.id][a.id][1]
2177            if actual_size is None:
2178                continue
2179
2180            if a.size is None:
2181                continue
2182
2183            if isinstance(a.size, int):
2184                if actual_size != a.size:
2185                    raise ValueError(
2186                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2187                        + f"has incompatible size {actual_size}, expected {a.size}"
2188                    )
2189            elif isinstance(a.size, ParameterizedSize):
2190                _ = a.size.validate_size(actual_size)
2191            elif isinstance(a.size, DataDependentSize):
2192                _ = a.size.validate_size(actual_size)
2193            elif isinstance(a.size, SizeReference):
2194                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2195                if ref_tensor_axes is None:
2196                    raise ValueError(
2197                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2198                        + f" reference '{a.size.tensor_id}'"
2199                    )
2200
2201                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2202                if ref_axis is None or ref_size is None:
2203                    raise ValueError(
2204                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2205                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2206                    )
2207
2208                if a.unit != ref_axis.unit:
2209                    raise ValueError(
2210                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2211                        + " axis and reference axis to have the same `unit`, but"
2212                        + f" {a.unit}!={ref_axis.unit}"
2213                    )
2214
2215                if actual_size != (
2216                    expected_size := (
2217                        ref_size * ref_axis.scale / a.scale + a.size.offset
2218                    )
2219                ):
2220                    raise ValueError(
2221                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2222                        + f" {actual_size} invalid for referenced size {ref_size};"
2223                        + f" expected {expected_size}"
2224                    )
2225            else:
2226                assert_never(a.size)
2227
2228
2229FileDescr_dependencies = Annotated[
2230    FileDescr_,
2231    WithSuffix((".yaml", ".yml"), case_sensitive=True),
2232    Field(examples=[dict(source="environment.yaml")]),
2233]
2234
2235
2236class _ArchitectureCallableDescr(Node):
2237    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2238    """Identifier of the callable that returns a torch.nn.Module instance."""
2239
2240    kwargs: Dict[str, YamlValue] = Field(
2241        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2242    )
2243    """Keyword arguments for the `callable`."""
2244
2245
2246class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2247    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2248    """Architecture source file"""
2249
2250    @model_serializer(mode="wrap", when_used="unless-none")
2251    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2252        return package_file_descr_serializer(self, nxt, info)
2253
2254
2255class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2256    import_from: str
2257    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2258
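# A hypothetical example of describing an architecture that can be imported from
# an installed package, i.e. `from my_models.unet import UNet` (module name,
# callable and kwargs are made up for illustration):
def _example_architecture_from_library() -> ArchitectureFromLibraryDescr:
    return ArchitectureFromLibraryDescr(
        import_from="my_models.unet",
        callable=Identifier("UNet"),
        kwargs={"in_channels": 1, "out_channels": 2},
    )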
2259
2260class _ArchFileConv(
2261    Converter[
2262        _CallableFromFile_v0_4,
2263        ArchitectureFromFileDescr,
2264        Optional[Sha256],
2265        Dict[str, Any],
2266    ]
2267):
2268    def _convert(
2269        self,
2270        src: _CallableFromFile_v0_4,
2271        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2272        sha256: Optional[Sha256],
2273        kwargs: Dict[str, Any],
2274    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2275        if src.startswith("http") and src.count(":") == 2:
2276            http, source, callable_ = src.split(":")
2277            source = ":".join((http, source))
2278        elif not src.startswith("http") and src.count(":") == 1:
2279            source, callable_ = src.split(":")
2280        else:
2281            source = str(src)
2282            callable_ = str(src)
2283        return tgt(
2284            callable=Identifier(callable_),
2285            source=cast(FileSource_, source),
2286            sha256=sha256,
2287            kwargs=kwargs,
2288        )
2289
2290
2291_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2292
2293
2294class _ArchLibConv(
2295    Converter[
2296        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2297    ]
2298):
2299    def _convert(
2300        self,
2301        src: _CallableFromDepencency_v0_4,
2302        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2303        kwargs: Dict[str, Any],
2304    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2305        *mods, callable_ = src.split(".")
2306        import_from = ".".join(mods)
2307        return tgt(
2308            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2309        )
2310
2311
2312_arch_lib_conv = _ArchLibConv(
2313    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2314)
2315
2316
2317class WeightsEntryDescrBase(FileDescr):
2318    type: ClassVar[WeightsFormat]
2319    weights_format_name: ClassVar[str]  # human readable
2320
2321    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2322    """Source of the weights file."""
2323
2324    authors: Optional[List[Author]] = None
2325    """Authors:
2326    Either the person(s) that trained this model, resulting in the original weights file
2327        (if this is the initial weights entry, i.e. it does not have a `parent`),
2328    or the person(s) who converted the weights to this weights format
2329        (if this is a child weights entry, i.e. it has a `parent` field).
2330    """
2331
2332    parent: Annotated[
2333        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2334    ] = None
2335    """The source weights these weights were converted from.
2336    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2337    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2338    All weight entries except one (the initial set of weights resulting from training the model)
2339    need to have this field."""
2340
2341    comment: str = ""
2342    """A comment about this weights entry, for example how these weights were created."""
2343
2344    @model_validator(mode="after")
2345    def _validate(self) -> Self:
2346        if self.type == self.parent:
2347            raise ValueError("Weights entry cannot be its own parent.")
2348
2349        return self
2350
2351    @model_serializer(mode="wrap", when_used="unless-none")
2352    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2353        return package_file_descr_serializer(self, nxt, info)
2354
2355
2356class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2357    type = "keras_hdf5"
2358    weights_format_name: ClassVar[str] = "Keras HDF5"
2359    tensorflow_version: Version
2360    """TensorFlow version used to create these weights."""
2361
2362
2363FileDescr_external_data = Annotated[
2364    FileDescr_,
2365    WithSuffix(".data", case_sensitive=True),
2366    Field(examples=[dict(source="weights.onnx.data")]),
2367]
2368
2369
2370class OnnxWeightsDescr(WeightsEntryDescrBase):
2371    type = "onnx"
2372    weights_format_name: ClassVar[str] = "ONNX"
2373    opset_version: Annotated[int, Ge(7)]
2374    """ONNX opset version"""
2375
2376    external_data: Optional[FileDescr_external_data] = None
2377    """Source of the external ONNX data file holding the weights.
2378    (If present **source** holds the ONNX architecture without weights)."""
2379    (If present, **source** holds the ONNX architecture without weights.)"""
2380    @model_validator(mode="after")
2381    def _validate_external_data_unique_file_name(self) -> Self:
2382        if self.external_data is not None and (
2383            extract_file_name(self.source)
2384            == extract_file_name(self.external_data.source)
2385        ):
2386            raise ValueError(
2387                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
2388                + " must be different from ONNX `source` file name."
2389            )
2390
2391        return self
2392
2393
2394class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2395    type = "pytorch_state_dict"
2396    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2397    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2398    pytorch_version: Version
2399    """Version of the PyTorch library used.
2400    If `dependencies` is specified, it has to include pytorch, and any version pinning has to be compatible.
2401    """
2402    dependencies: Optional[FileDescr_dependencies] = None
2403    """Custom dependencies beyond pytorch, described in a Conda environment file.
2404    Allows specifying custom dependencies; see the conda docs:
2405    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2406    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2407
2408    The conda environment file should include pytorch, and any version pinning has to be compatible with
2409    **pytorch_version**.
2410    """
2411
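# A hedged sketch of a `pytorch_state_dict` weights entry; the file name, the
# architecture module and the version are hypothetical, and it is assumed to run
# in a validation context with `perform_io_checks=False`, so "weights.pt" is not
# resolved:
def _example_pytorch_state_dict_weights() -> PytorchStateDictWeightsDescr:
    return PytorchStateDictWeightsDescr(
        source="weights.pt",
        architecture=ArchitectureFromLibraryDescr(
            import_from="my_models.unet", callable=Identifier("UNet")
        ),
        pytorch_version=Version("2.3.1"),
    )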
2412
2413class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2414    type = "tensorflow_js"
2415    weights_format_name: ClassVar[str] = "Tensorflow.js"
2416    tensorflow_version: Version
2417    """Version of the TensorFlow library used."""
2418
2419    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2420    """The multi-file weights.
2421    All required files/folders should be bundled in a zip archive."""
2422
2423
2424class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2425    type = "tensorflow_saved_model_bundle"
2426    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2427    tensorflow_version: Version
2428    """Version of the TensorFlow library used."""
2429
2430    dependencies: Optional[FileDescr_dependencies] = None
2431    """Custom dependencies beyond tensorflow.
2432    Should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**."""
2433
2434    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2435    """The multi-file weights.
2436    All required files/folders should be bundled in a zip archive."""
2437
2438
2439class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2440    type = "torchscript"
2441    weights_format_name: ClassVar[str] = "TorchScript"
2442    pytorch_version: Version
2443    """Version of the PyTorch library used."""
2444
2445
2446class WeightsDescr(Node):
2447    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2448    onnx: Optional[OnnxWeightsDescr] = None
2449    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2450    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2451    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2452        None
2453    )
2454    torchscript: Optional[TorchscriptWeightsDescr] = None
2455
2456    @model_validator(mode="after")
2457    def check_entries(self) -> Self:
2458        entries = {wtype for wtype, entry in self if entry is not None}
2459
2460        if not entries:
2461            raise ValueError("Missing weights entry")
2462
2463        entries_wo_parent = {
2464            wtype
2465            for wtype, entry in self
2466            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2467        }
2468        if len(entries_wo_parent) != 1:
2469            issue_warning(
2470                "Exactly one weights entry is expected to not specify the `parent`"
2471                + " field (got {value}). That entry is considered the original set of"
2472                + " model weights. Other weight formats are created through conversion"
2473                + " of the original or already converted weights. They have to reference"
2474                + " the weights format they were converted from as their `parent`.",
2475                value=len(entries_wo_parent),
2476                field="weights",
2477            )
2478
2479        for wtype, entry in self:
2480            if entry is None:
2481                continue
2482
2483            assert hasattr(entry, "type")
2484            assert hasattr(entry, "parent")
2485            assert wtype == entry.type
2486            if (
2487                entry.parent is not None and entry.parent not in entries
2488            ):  # self reference checked for `parent` field
2489                raise ValueError(
2490                    f"`weights.{wtype}.parent={entry.parent}` not in specified weight"
2491                    + f" formats: {entries}"
2492                )
2493
2494        return self
2495
2496    def __getitem__(
2497        self,
2498        key: Literal[
2499            "keras_hdf5",
2500            "onnx",
2501            "pytorch_state_dict",
2502            "tensorflow_js",
2503            "tensorflow_saved_model_bundle",
2504            "torchscript",
2505        ],
2506    ):
2507        if key == "keras_hdf5":
2508            ret = self.keras_hdf5
2509        elif key == "onnx":
2510            ret = self.onnx
2511        elif key == "pytorch_state_dict":
2512            ret = self.pytorch_state_dict
2513        elif key == "tensorflow_js":
2514            ret = self.tensorflow_js
2515        elif key == "tensorflow_saved_model_bundle":
2516            ret = self.tensorflow_saved_model_bundle
2517        elif key == "torchscript":
2518            ret = self.torchscript
2519        else:
2520            raise KeyError(key)
2521
2522        if ret is None:
2523            raise KeyError(key)
2524
2525        return ret
2526
2527    @property
2528    def available_formats(self):
2529        return {
2530            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2531            **({} if self.onnx is None else {"onnx": self.onnx}),
2532            **(
2533                {}
2534                if self.pytorch_state_dict is None
2535                else {"pytorch_state_dict": self.pytorch_state_dict}
2536            ),
2537            **(
2538                {}
2539                if self.tensorflow_js is None
2540                else {"tensorflow_js": self.tensorflow_js}
2541            ),
2542            **(
2543                {}
2544                if self.tensorflow_saved_model_bundle is None
2545                else {
2546                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2547                }
2548            ),
2549            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2550        }
2551
2552    @property
2553    def missing_formats(self):
2554        return {
2555            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2556        }
2557
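# A small helper sketch showing how `available_formats` and `__getitem__` can be
# used to pick a weights entry (the preference order is arbitrary/illustrative):
def _example_pick_weights_format(weights: WeightsDescr) -> "WeightsFormat":
    available = weights.available_formats
    for preferred in ("pytorch_state_dict", "torchscript", "onnx"):
        if preferred in available:
            return preferred
    # `check_entries` guarantees at least one entry for a validated description
    return next(iter(available))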
2558
2559class ModelId(ResourceId):
2560    pass
2561
2562
2563class LinkedModel(LinkedResourceBase):
2564    """Reference to a bioimage.io model."""
2565
2566    id: ModelId
2567    """A valid model `id` from the bioimage.io collection."""
2568
2569
2570class _DataDepSize(NamedTuple):
2571    min: StrictInt
2572    max: Optional[StrictInt]
2573
2574
2575class _AxisSizes(NamedTuple):
2576    """The lengths of all axes of model inputs and outputs."""
2577
2578    inputs: Dict[Tuple[TensorId, AxisId], int]
2579    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2580
2581
2582class _TensorSizes(NamedTuple):
2583    """_AxisSizes as nested dicts"""
2584
2585    inputs: Dict[TensorId, Dict[AxisId, int]]
2586    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2587
2588
2589class ReproducibilityTolerance(Node, extra="allow"):
2590    """Describes what small numerical differences -- if any -- may be tolerated
2591    in the generated output when executing in different environments.
2592
2593    A tensor element *output* is considered mismatched to the **test_tensor** if
2594    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2595    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2596
2597    Motivation:
2598        For testing we can request the respective deep learning frameworks to be as
2599        reproducible as possible by setting seeds and choosing deterministic algorithms,
2600        but differences in operating systems, available hardware and installed drivers
2601        may still lead to numerical differences.
2602    """
2603
2604    relative_tolerance: RelativeTolerance = 1e-3
2605    """Maximum relative tolerance of reproduced test tensor."""
2606
2607    absolute_tolerance: AbsoluteTolerance = 1e-4
2608    """Maximum absolute tolerance of reproduced test tensor."""
2609
2610    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2611    """Maximum number of mismatched elements/pixels per million to tolerate."""
2612
2613    output_ids: Sequence[TensorId] = ()
2614    """Limits the output tensor IDs these reproducibility details apply to."""
2615
2616    weights_formats: Sequence[WeightsFormat] = ()
2617    """Limits the weights formats these details apply to."""
2618
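# A sketch of the mismatch criterion from the docstring above, expressed with
# numpy (`output` and `test_tensor` are hypothetical arrays of equal shape):
def _example_count_mismatches(
    output: "NDArray[Any]", test_tensor: "NDArray[Any]", tol: ReproducibilityTolerance
) -> int:
    # abs(output - test) > absolute_tolerance + relative_tolerance * abs(test)
    mismatched = np.abs(output - test_tensor) > (
        tol.absolute_tolerance + tol.relative_tolerance * np.abs(test_tensor)
    )
    return int(mismatched.sum())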
2619
2620class BioimageioConfig(Node, extra="allow"):
2621    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2622    """Tolerances to allow when reproducing the model's test outputs
2623    from the model's test inputs.
2624    Only the first entry matching tensor id and weights format is considered.
2625    """
2626
2627
2628class Config(Node, extra="allow"):
2629    bioimageio: BioimageioConfig = Field(
2630        default_factory=BioimageioConfig.model_construct
2631    )
2632
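# A hedged example of a `config` section limiting reproducibility tolerances to
# a particular output tensor and weights format (the tensor id is hypothetical):
def _example_config() -> Config:
    return Config(
        bioimageio=BioimageioConfig(
            reproducibility_tolerance=[
                ReproducibilityTolerance(
                    output_ids=[TensorId("mask")],
                    weights_formats=["onnx"],
                )
            ]
        )
    )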
2633
2634class ModelDescr(GenericModelDescrBase):
2635    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2636    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2637    """
2638
2639    implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6"
2640    if TYPE_CHECKING:
2641        format_version: Literal["0.5.6"] = "0.5.6"
2642    else:
2643        format_version: Literal["0.5.6"]
2644        """Version of the bioimage.io model description specification used.
2645        When creating a new model always use the latest micro/patch version described here.
2646        The `format_version` is important for any consumer software to understand how to parse the fields.
2647        """
2648
2649    implemented_type: ClassVar[Literal["model"]] = "model"
2650    if TYPE_CHECKING:
2651        type: Literal["model"] = "model"
2652    else:
2653        type: Literal["model"]
2654        """Specialized resource type 'model'"""
2655
2656    id: Optional[ModelId] = None
2657    """bioimage.io-wide unique resource identifier
2658    assigned by bioimage.io; version **un**specific."""
2659
2660    authors: FAIR[List[Author]] = Field(
2661        default_factory=cast(Callable[[], List[Author]], list)
2662    )
2663    """The authors are the creators of the model RDF and the primary points of contact."""
2664
2665    documentation: FAIR[Optional[FileSource_documentation]] = None
2666    """URL or relative path to a markdown file with additional documentation.
2667    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2668    The documentation should include a '#[#] Validation' (sub)section
2669    with details on how to quantitatively validate the model on unseen data."""
2670
2671    @field_validator("documentation", mode="after")
2672    @classmethod
2673    def _validate_documentation(
2674        cls, value: Optional[FileSource_documentation]
2675    ) -> Optional[FileSource_documentation]:
2676        if not get_validation_context().perform_io_checks or value is None:
2677            return value
2678
2679        doc_reader = get_reader(value)
2680        doc_content = doc_reader.read().decode(encoding="utf-8")
2681        if not re.search("#.*[vV]alidation", doc_content):
2682            issue_warning(
2683                "No '# Validation' (sub)section found in {value}.",
2684                value=value,
2685                field="documentation",
2686            )
2687
2688        return value
2689
2690    inputs: NotEmpty[Sequence[InputTensorDescr]]
2691    """Describes the input tensors expected by this model."""
2692
2693    @field_validator("inputs", mode="after")
2694    @classmethod
2695    def _validate_input_axes(
2696        cls, inputs: Sequence[InputTensorDescr]
2697    ) -> Sequence[InputTensorDescr]:
2698        input_size_refs = cls._get_axes_with_independent_size(inputs)
2699
2700        for i, ipt in enumerate(inputs):
2701            valid_independent_refs: Dict[
2702                Tuple[TensorId, AxisId],
2703                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2704            ] = {
2705                **{
2706                    (ipt.id, a.id): (ipt, a, a.size)
2707                    for a in ipt.axes
2708                    if not isinstance(a, BatchAxis)
2709                    and isinstance(a.size, (int, ParameterizedSize))
2710                },
2711                **input_size_refs,
2712            }
2713            for a, ax in enumerate(ipt.axes):
2714                cls._validate_axis(
2715                    "inputs",
2716                    i=i,
2717                    tensor_id=ipt.id,
2718                    a=a,
2719                    axis=ax,
2720                    valid_independent_refs=valid_independent_refs,
2721                )
2722        return inputs
2723
2724    @staticmethod
2725    def _validate_axis(
2726        field_name: str,
2727        i: int,
2728        tensor_id: TensorId,
2729        a: int,
2730        axis: AnyAxis,
2731        valid_independent_refs: Dict[
2732            Tuple[TensorId, AxisId],
2733            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2734        ],
2735    ):
2736        if isinstance(axis, BatchAxis) or isinstance(
2737            axis.size, (int, ParameterizedSize, DataDependentSize)
2738        ):
2739            return
2740        elif not isinstance(axis.size, SizeReference):
2741            assert_never(axis.size)
2742
2743        # validate axis.size SizeReference
2744        ref = (axis.size.tensor_id, axis.size.axis_id)
2745        if ref not in valid_independent_refs:
2746            raise ValueError(
2747                "Invalid tensor axis reference at"
2748                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2749            )
2750        if ref == (tensor_id, axis.id):
2751            raise ValueError(
2752                "Self-referencing not allowed for"
2753                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2754            )
2755        if axis.type == "channel":
2756            if valid_independent_refs[ref][1].type != "channel":
2757                raise ValueError(
2758                    "A channel axis' size may only reference another fixed size"
2759                    + " channel axis."
2760                )
2761            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2762                ref_size = valid_independent_refs[ref][2]
2763                assert isinstance(ref_size, int), (
2764                    "channel axis ref (another channel axis) has to specify fixed"
2765                    + " size"
2766                )
2767                generated_channel_names = [
2768                    Identifier(axis.channel_names.format(i=i))
2769                    for i in range(1, ref_size + 1)
2770                ]
2771                axis.channel_names = generated_channel_names
2772
2773        if (ax_unit := getattr(axis, "unit", None)) != (
2774            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2775        ):
2776            raise ValueError(
2777                "The units of an axis and its reference axis need to match, but"
2778                + f" '{ax_unit}' != '{ref_unit}'."
2779            )
2780        ref_axis = valid_independent_refs[ref][1]
2781        if isinstance(ref_axis, BatchAxis):
2782            raise ValueError(
2783                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2784                + " (a batch axis is not allowed as reference)."
2785            )
2786
2787        if isinstance(axis, WithHalo):
2788            min_size = axis.size.get_size(axis, ref_axis, n=0)
2789            if (min_size - 2 * axis.halo) < 1:
2790                raise ValueError(
2791                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2792                    + f" {axis.halo}."
2793                )
2794
2795            input_halo = axis.halo * axis.scale / ref_axis.scale
2796            if input_halo != int(input_halo) or input_halo % 2 == 1:
2797                raise ValueError(
2798                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2799                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2800                    + f" must be an even integer for {tensor_id}.{axis.id}."
2801                )
2802
2803    @model_validator(mode="after")
2804    def _validate_test_tensors(self) -> Self:
2805        if not get_validation_context().perform_io_checks:
2806            return self
2807
2808        test_output_arrays = [
2809            None if descr.test_tensor is None else load_array(descr.test_tensor)
2810            for descr in self.outputs
2811        ]
2812        test_input_arrays = [
2813            None if descr.test_tensor is None else load_array(descr.test_tensor)
2814            for descr in self.inputs
2815        ]
2816
2817        tensors = {
2818            descr.id: (descr, array)
2819            for descr, array in zip(
2820                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2821            )
2822        }
2823        validate_tensors(tensors, tensor_origin="test_tensor")
2824
2825        output_arrays = {
2826            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2827        }
2828        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2829            if not rep_tol.absolute_tolerance:
2830                continue
2831
2832            if rep_tol.output_ids:
2833                out_arrays = {
2834                    oid: a
2835                    for oid, a in output_arrays.items()
2836                    if oid in rep_tol.output_ids
2837                }
2838            else:
2839                out_arrays = output_arrays
2840
2841            for out_id, array in out_arrays.items():
2842                if array is None:
2843                    continue
2844
2845                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2846                    raise ValueError(
2847                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2848                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2849                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2850                    )
2851
2852        return self
2853
2854    @model_validator(mode="after")
2855    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2856        ipt_refs = {t.id for t in self.inputs}
2857        out_refs = {t.id for t in self.outputs}
2858        for ipt in self.inputs:
2859            for p in ipt.preprocessing:
2860                ref = p.kwargs.get("reference_tensor")
2861                if ref is None:
2862                    continue
2863                if ref not in ipt_refs:
2864                    raise ValueError(
2865                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2866                        + f" references are: {ipt_refs}."
2867                    )
2868
2869        for out in self.outputs:
2870            for p in out.postprocessing:
2871                ref = p.kwargs.get("reference_tensor")
2872                if ref is None:
2873                    continue
2874
2875                if ref not in ipt_refs and ref not in out_refs:
2876                    raise ValueError(
2877                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2878                        + f" are: {ipt_refs | out_refs}."
2879                    )
2880
2881        return self
2882
2883    # TODO: use validate funcs in validate_test_tensors
2884    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2885
2886    name: Annotated[
2887        str,
2888        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2889        MinLen(5),
2890        MaxLen(128),
2891        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2892    ]
2893    """A human-readable name of this model.
2894    It should be no longer than 64 characters
2895    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2896    We recommend choosing a name that refers to the model's task and image modality.
2897    """
2898
2899    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2900    """Describes the output tensors."""
2901
2902    @field_validator("outputs", mode="after")
2903    @classmethod
2904    def _validate_tensor_ids(
2905        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2906    ) -> Sequence[OutputTensorDescr]:
2907        tensor_ids = [
2908            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2909        ]
2910        duplicate_tensor_ids: List[str] = []
2911        seen: Set[str] = set()
2912        for t in tensor_ids:
2913            if t in seen:
2914                duplicate_tensor_ids.append(t)
2915
2916            seen.add(t)
2917
2918        if duplicate_tensor_ids:
2919            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2920
2921        return outputs
2922
2923    @staticmethod
2924    def _get_axes_with_parameterized_size(
2925        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2926    ):
2927        return {
2928            f"{t.id}.{a.id}": (t, a, a.size)
2929            for t in io
2930            for a in t.axes
2931            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2932        }
2933
2934    @staticmethod
2935    def _get_axes_with_independent_size(
2936        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2937    ):
2938        return {
2939            (t.id, a.id): (t, a, a.size)
2940            for t in io
2941            for a in t.axes
2942            if not isinstance(a, BatchAxis)
2943            and isinstance(a.size, (int, ParameterizedSize))
2944        }
2945
2946    @field_validator("outputs", mode="after")
2947    @classmethod
2948    def _validate_output_axes(
2949        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2950    ) -> List[OutputTensorDescr]:
2951        input_size_refs = cls._get_axes_with_independent_size(
2952            info.data.get("inputs", [])
2953        )
2954        output_size_refs = cls._get_axes_with_independent_size(outputs)
2955
2956        for i, out in enumerate(outputs):
2957            valid_independent_refs: Dict[
2958                Tuple[TensorId, AxisId],
2959                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2960            ] = {
2961                **{
2962                    (out.id, a.id): (out, a, a.size)
2963                    for a in out.axes
2964                    if not isinstance(a, BatchAxis)
2965                    and isinstance(a.size, (int, ParameterizedSize))
2966                },
2967                **input_size_refs,
2968                **output_size_refs,
2969            }
2970            for a, ax in enumerate(out.axes):
2971                cls._validate_axis(
2972                    "outputs",
2973                    i,
2974                    out.id,
2975                    a,
2976                    ax,
2977                    valid_independent_refs=valid_independent_refs,
2978                )
2979
2980        return outputs
2981
2982    packaged_by: List[Author] = Field(
2983        default_factory=cast(Callable[[], List[Author]], list)
2984    )
2985    """The persons that have packaged and uploaded this model.
2986    Only required if those persons differ from the `authors`."""
2987
2988    parent: Optional[LinkedModel] = None
2989    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2990
2991    @model_validator(mode="after")
2992    def _validate_parent_is_not_self(self) -> Self:
2993        if self.parent is not None and self.parent.id == self.id:
2994            raise ValueError("A model description may not reference itself as parent.")
2995
2996        return self
2997
2998    run_mode: Annotated[
2999        Optional[RunMode],
3000        warn(None, "Run mode '{value}' has limited support across consumer software."),
3001    ] = None
3002    """Custom run mode for this model: for more complex prediction procedures like test time
3003    data augmentation that currently cannot be expressed in the specification.
3004    No standard run modes are defined yet."""
3005
3006    timestamp: Datetime = Field(default_factory=Datetime.now)
3007    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
3008    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
3009    (In Python a datetime object is valid, too)."""
3010
3011    training_data: Annotated[
3012        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
3013        Field(union_mode="left_to_right"),
3014    ] = None
3015    """The dataset used to train this model"""
3016
3017    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
3018    """The weights for this model.
3019    Weights can be given for different formats, but should otherwise be equivalent.
3020    The available weight formats determine which consumers can use this model."""
3021
3022    config: Config = Field(default_factory=Config.model_construct)
3023
3024    @model_validator(mode="after")
3025    def _add_default_cover(self) -> Self:
3026        if not get_validation_context().perform_io_checks or self.covers:
3027            return self
3028
3029        try:
3030            generated_covers = generate_covers(
3031                [
3032                    (t, load_array(t.test_tensor))
3033                    for t in self.inputs
3034                    if t.test_tensor is not None
3035                ],
3036                [
3037                    (t, load_array(t.test_tensor))
3038                    for t in self.outputs
3039                    if t.test_tensor is not None
3040                ],
3041            )
3042        except Exception as e:
3043            issue_warning(
3044                "Failed to generate cover image(s): {e}",
3045                value=self.covers,
3046                msg_context=dict(e=e),
3047                field="covers",
3048            )
3049        else:
3050            self.covers.extend(generated_covers)
3051
3052        return self
3053
3054    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3055        return self._get_test_arrays(self.inputs)
3056
3057    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3058        return self._get_test_arrays(self.outputs)
3059
3060    @staticmethod
3061    def _get_test_arrays(
3062        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3063    ):
3064        ts: List[FileDescr] = []
3065        for d in io_descr:
3066            if d.test_tensor is None:
3067                raise ValueError(
3068                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3069                )
3070            ts.append(d.test_tensor)
3071
3072        data = [load_array(t) for t in ts]
3073        assert all(isinstance(d, np.ndarray) for d in data)
3074        return data
3075
3076    @staticmethod
3077    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3078        batch_size = 1
3079        tensor_with_batchsize: Optional[TensorId] = None
3080        for tid in tensor_sizes:
3081            for aid, s in tensor_sizes[tid].items():
3082                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3083                    continue
3084
3085                if batch_size != 1:
3086                    assert tensor_with_batchsize is not None
3087                    raise ValueError(
3088                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3089                    )
3090
3091                batch_size = s
3092                tensor_with_batchsize = tid
3093
3094        return batch_size
3095
3096    def get_output_tensor_sizes(
3097        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3098    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3099        """Returns the tensor output sizes for given **input_sizes**.
3100        The returned output sizes are exact only if **input_sizes** describes a valid input shape.
3101        Otherwise they may be larger than the actual (valid) output sizes."""
3102        batch_size = self.get_batch_size(input_sizes)
3103        ns = self.get_ns(input_sizes)
3104
3105        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3106        return tensor_sizes.outputs
3107
3108    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3109        """get parameter `n` for each parameterized axis
3110        such that the valid input size is >= the given input size"""
3111        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3112        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3113        for tid in input_sizes:
3114            for aid, s in input_sizes[tid].items():
3115                size_descr = axes[tid][aid].size
3116                if isinstance(size_descr, ParameterizedSize):
3117                    ret[(tid, aid)] = size_descr.get_n(s)
3118                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3119                    pass
3120                else:
3121                    assert_never(size_descr)
3122
3123        return ret
3124
3125    def get_tensor_sizes(
3126        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3127    ) -> _TensorSizes:
3128        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3129        return _TensorSizes(
3130            {
3131                t: {
3132                    aa: axis_sizes.inputs[(tt, aa)]
3133                    for tt, aa in axis_sizes.inputs
3134                    if tt == t
3135                }
3136                for t in {tt for tt, _ in axis_sizes.inputs}
3137            },
3138            {
3139                t: {
3140                    aa: axis_sizes.outputs[(tt, aa)]
3141                    for tt, aa in axis_sizes.outputs
3142                    if tt == t
3143                }
3144                for t in {tt for tt, _ in axis_sizes.outputs}
3145            },
3146        )
3147
3148    def get_axis_sizes(
3149        self,
3150        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3151        batch_size: Optional[int] = None,
3152        *,
3153        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3154    ) -> _AxisSizes:
3155        """Determine input and output block shape for scale factors **ns**
3156        of parameterized input sizes.
3157
3158        Args:
3159            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3160                that is parameterized as `size = min + n * step`.
3161            batch_size: The desired size of the batch dimension.
3162                If given, **batch_size** overwrites any batch size present in
3163                **max_input_shape**. Default 1.
3164            max_input_shape: Limits the derived block shapes.
3165                For each axis whose input size, parameterized by `n`, is larger
3166                than **max_input_shape**, `n` is reduced to the smallest value for
3167                which the size still reaches **max_input_shape**.
3168                Use this for small input samples or large values of **ns**,
3169                or simply whenever you know the full input shape.
3170
3171        Returns:
3172            Resolved axis sizes for model inputs and outputs.
3173        """
3174        max_input_shape = max_input_shape or {}
3175        if batch_size is None:
3176            for (_t_id, a_id), s in max_input_shape.items():
3177                if a_id == BATCH_AXIS_ID:
3178                    batch_size = s
3179                    break
3180            else:
3181                batch_size = 1
3182
3183        all_axes = {
3184            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3185        }
3186
3187        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3188        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3189
3190        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3191            if isinstance(a, BatchAxis):
3192                if (t_descr.id, a.id) in ns:
3193                    logger.warning(
3194                        "Ignoring unexpected size increment factor (n) for batch axis"
3195                        + " of tensor '{}'.",
3196                        t_descr.id,
3197                    )
3198                return batch_size
3199            elif isinstance(a.size, int):
3200                if (t_descr.id, a.id) in ns:
3201                    logger.warning(
3202                        "Ignoring unexpected size increment factor (n) for fixed size"
3203                        + " axis '{}' of tensor '{}'.",
3204                        a.id,
3205                        t_descr.id,
3206                    )
3207                return a.size
3208            elif isinstance(a.size, ParameterizedSize):
3209                if (t_descr.id, a.id) not in ns:
3210                    raise ValueError(
3211                        "Size increment factor (n) missing for parametrized axis"
3212                        + f" '{a.id}' of tensor '{t_descr.id}'."
3213                    )
3214                n = ns[(t_descr.id, a.id)]
3215                s_max = max_input_shape.get((t_descr.id, a.id))
3216                if s_max is not None:
3217                    n = min(n, a.size.get_n(s_max))
3218
3219                return a.size.get_size(n)
3220
3221            elif isinstance(a.size, SizeReference):
3222                if (t_descr.id, a.id) in ns:
3223                    logger.warning(
3224                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3225                        + " of tensor '{}' with size reference.",
3226                        a.id,
3227                        t_descr.id,
3228                    )
3229                assert not isinstance(a, BatchAxis)
3230                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3231                assert not isinstance(ref_axis, BatchAxis)
3232                ref_key = (a.size.tensor_id, a.size.axis_id)
3233                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3234                assert ref_size is not None, ref_key
3235                assert not isinstance(ref_size, _DataDepSize), ref_key
3236                return a.size.get_size(
3237                    axis=a,
3238                    ref_axis=ref_axis,
3239                    ref_size=ref_size,
3240                )
3241            elif isinstance(a.size, DataDependentSize):
3242                if (t_descr.id, a.id) in ns:
3243                    logger.warning(
3244                        "Ignoring unexpected increment factor (n) for data dependent"
3245                        + " size axis '{}' of tensor '{}'.",
3246                        a.id,
3247                        t_descr.id,
3248                    )
3249                return _DataDepSize(a.size.min, a.size.max)
3250            else:
3251                assert_never(a.size)
3252
3253        # first resolve all but the `SizeReference` input sizes
3254        for t_descr in self.inputs:
3255            for a in t_descr.axes:
3256                if not isinstance(a.size, SizeReference):
3257                    s = get_axis_size(a)
3258                    assert not isinstance(s, _DataDepSize)
3259                    inputs[t_descr.id, a.id] = s
3260
3261        # resolve all other input axis sizes
3262        for t_descr in self.inputs:
3263            for a in t_descr.axes:
3264                if isinstance(a.size, SizeReference):
3265                    s = get_axis_size(a)
3266                    assert not isinstance(s, _DataDepSize)
3267                    inputs[t_descr.id, a.id] = s
3268
3269        # resolve all output axis sizes
3270        for t_descr in self.outputs:
3271            for a in t_descr.axes:
3272                assert not isinstance(a.size, ParameterizedSize)
3273                s = get_axis_size(a)
3274                outputs[t_descr.id, a.id] = s
3275
3276        return _AxisSizes(inputs=inputs, outputs=outputs)
3277
3278    @model_validator(mode="before")
3279    @classmethod
3280    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3281        cls.convert_from_old_format_wo_validation(data)
3282        return data
3283
3284    @classmethod
3285    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3286        """Convert metadata following an older format version to this class's format
3287        without validating the result.
3288        """
3289        if (
3290            data.get("type") == "model"
3291            and isinstance(fv := data.get("format_version"), str)
3292            and fv.count(".") == 2
3293        ):
3294            fv_parts = fv.split(".")
3295            if any(not p.isdigit() for p in fv_parts):
3296                return
3297
3298            fv_tuple = tuple(map(int, fv_parts))
3299
3300            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3301            if fv_tuple[:2] in ((0, 3), (0, 4)):
3302                m04 = _ModelDescr_v0_4.load(data)
3303                if isinstance(m04, InvalidDescr):
3304                    try:
3305                        updated = _model_conv.convert_as_dict(
3306                            m04  # pyright: ignore[reportArgumentType]
3307                        )
3308                    except Exception as e:
3309                        logger.error(
3310                            "Failed to convert from invalid model 0.4 description."
3311                            + f"\nerror: {e}"
3312                            + "\nProceeding with model 0.5 validation without conversion."
3313                        )
3314                        updated = None
3315                else:
3316                    updated = _model_conv.convert_as_dict(m04)
3317
3318                if updated is not None:
3319                    data.clear()
3320                    data.update(updated)
3321
3322            elif fv_tuple[:2] == (0, 5):
3323                # bump patch version
3324                data["format_version"] = cls.implemented_format_version
3325
3326
3327class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3328    def _convert(
3329        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3330    ) -> "ModelDescr | dict[str, Any]":
3331        name = "".join(
3332            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3333            for c in src.name
3334        )
3335
3336        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3337            conv = (
3338                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3339            )
3340            return None if auths is None else [conv(a) for a in auths]
3341
3342        if TYPE_CHECKING:
3343            arch_file_conv = _arch_file_conv.convert
3344            arch_lib_conv = _arch_lib_conv.convert
3345        else:
3346            arch_file_conv = _arch_file_conv.convert_as_dict
3347            arch_lib_conv = _arch_lib_conv.convert_as_dict
3348
3349        input_size_refs = {
3350            ipt.name: {
3351                a: s
3352                for a, s in zip(
3353                    ipt.axes,
3354                    (
3355                        ipt.shape.min
3356                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3357                        else ipt.shape
3358                    ),
3359                )
3360            }
3361            for ipt in src.inputs
3362            if ipt.shape
3363        }
3364        output_size_refs = {
3365            **{
3366                out.name: {a: s for a, s in zip(out.axes, out.shape)}
3367                for out in src.outputs
3368                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3369            },
3370            **input_size_refs,
3371        }
3372
3373        return tgt(
3374            attachments=(
3375                []
3376                if src.attachments is None
3377                else [FileDescr(source=f) for f in src.attachments.files]
3378            ),
3379            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
3380            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
3381            config=src.config,  # pyright: ignore[reportArgumentType]
3382            covers=src.covers,
3383            description=src.description,
3384            documentation=src.documentation,
3385            format_version="0.5.6",
3386            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3387            icon=src.icon,
3388            id=None if src.id is None else ModelId(src.id),
3389            id_emoji=src.id_emoji,
3390            license=src.license,  # type: ignore
3391            links=src.links,
3392            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
3393            name=name,
3394            tags=src.tags,
3395            type=src.type,
3396            uploader=src.uploader,
3397            version=src.version,
3398            inputs=[  # pyright: ignore[reportArgumentType]
3399                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3400                for ipt, tt, st in zip(
3401                    src.inputs,
3402                    src.test_inputs,
3403                    src.sample_inputs or [None] * len(src.test_inputs),
3404                )
3405            ],
3406            outputs=[  # pyright: ignore[reportArgumentType]
3407                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3408                for out, tt, st in zip(
3409                    src.outputs,
3410                    src.test_outputs,
3411                    src.sample_outputs or [None] * len(src.test_outputs),
3412                )
3413            ],
3414            parent=(
3415                None
3416                if src.parent is None
3417                else LinkedModel(
3418                    id=ModelId(
3419                        str(src.parent.id)
3420                        + (
3421                            ""
3422                            if src.parent.version_number is None
3423                            else f"/{src.parent.version_number}"
3424                        )
3425                    )
3426                )
3427            ),
3428            training_data=(
3429                None
3430                if src.training_data is None
3431                else (
3432                    LinkedDataset(
3433                        id=DatasetId(
3434                            str(src.training_data.id)
3435                            + (
3436                                ""
3437                                if src.training_data.version_number is None
3438                                else f"/{src.training_data.version_number}"
3439                            )
3440                        )
3441                    )
3442                    if isinstance(src.training_data, LinkedDataset02)
3443                    else src.training_data
3444                )
3445            ),
3446            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
3447            run_mode=src.run_mode,
3448            timestamp=src.timestamp,
3449            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3450                keras_hdf5=(w := src.weights.keras_hdf5)
3451                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3452                    authors=conv_authors(w.authors),
3453                    source=w.source,
3454                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3455                    parent=w.parent,
3456                ),
3457                onnx=(w := src.weights.onnx)
3458                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3459                    source=w.source,
3460                    authors=conv_authors(w.authors),
3461                    parent=w.parent,
3462                    opset_version=w.opset_version or 15,
3463                ),
3464                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3465                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3466                    source=w.source,
3467                    authors=conv_authors(w.authors),
3468                    parent=w.parent,
3469                    architecture=(
3470                        arch_file_conv(
3471                            w.architecture,
3472                            w.architecture_sha256,
3473                            w.kwargs,
3474                        )
3475                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3476                        else arch_lib_conv(w.architecture, w.kwargs)
3477                    ),
3478                    pytorch_version=w.pytorch_version or Version("1.10"),
3479                    dependencies=(
3480                        None
3481                        if w.dependencies is None
3482                        else (FileDescr if TYPE_CHECKING else dict)(
3483                            source=cast(
3484                                FileSource,
3485                                str(deps := w.dependencies)[
3486                                    (
3487                                        len("conda:")
3488                                        if str(deps).startswith("conda:")
3489                                        else 0
3490                                    ) :
3491                                ],
3492                            )
3493                        )
3494                    ),
3495                ),
3496                tensorflow_js=(w := src.weights.tensorflow_js)
3497                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3498                    source=w.source,
3499                    authors=conv_authors(w.authors),
3500                    parent=w.parent,
3501                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3502                ),
3503                tensorflow_saved_model_bundle=(
3504                    w := src.weights.tensorflow_saved_model_bundle
3505                )
3506                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3507                    authors=conv_authors(w.authors),
3508                    parent=w.parent,
3509                    source=w.source,
3510                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3511                    dependencies=(
3512                        None
3513                        if w.dependencies is None
3514                        else (FileDescr if TYPE_CHECKING else dict)(
3515                            source=cast(
3516                                FileSource,
3517                                (
3518                                    str(w.dependencies)[len("conda:") :]
3519                                    if str(w.dependencies).startswith("conda:")
3520                                    else str(w.dependencies)
3521                                ),
3522                            )
3523                        )
3524                    ),
3525                ),
3526                torchscript=(w := src.weights.torchscript)
3527                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3528                    source=w.source,
3529                    authors=conv_authors(w.authors),
3530                    parent=w.parent,
3531                    pytorch_version=w.pytorch_version or Version("1.10"),
3532                ),
3533            ),
3534        )
3535
3536
3537_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3538
3539
3540# create better cover images for 3d data and non-image outputs
3541def generate_covers(
3542    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3543    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3544) -> List[Path]:
3545    def squeeze(
3546        data: NDArray[Any], axes: Sequence[AnyAxis]
3547    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3548        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3549        if data.ndim != len(axes):
3550            raise ValueError(
3551                f"tensor shape {data.shape} does not match described axes"
3552                + f" {[a.id for a in axes]}"
3553            )
3554
3555        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3556        return data.squeeze(), axes
3557
3558    def normalize(
3559        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3560    ) -> NDArray[np.float32]:
3561        data = data.astype("float32")
3562        data -= data.min(axis=axis, keepdims=True)
3563        data /= data.max(axis=axis, keepdims=True) + eps
3564        return data
3565
3566    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3567        original_shape = data.shape
3568        original_axes = list(axes)
3569        data, axes = squeeze(data, axes)
3570
3571        # take slice from any batch or index axis if needed
3572        # and convert the first channel axis and take a slice from any additional channel axes
3573        slices: Tuple[slice, ...] = ()
3574        ndim = data.ndim
3575        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3576        has_c_axis = False
3577        for i, a in enumerate(axes):
3578            s = data.shape[i]
3579            assert s > 1
3580            if (
3581                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3582                and ndim > ndim_need
3583            ):
3584                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3585                ndim -= 1
3586            elif isinstance(a, ChannelAxis):
3587                if has_c_axis:
3588                    # second channel axis
3589                    data = data[slices + (slice(0, 1),)]
3590                    ndim -= 1
3591                else:
3592                    has_c_axis = True
3593                    if s == 2:
3594                        # visualize two channels with cyan and magenta
3595                        data = np.concatenate(
3596                            [
3597                                data[slices + (slice(1, 2),)],
3598                                data[slices + (slice(0, 1),)],
3599                                (
3600                                    data[slices + (slice(0, 1),)]
3601                                    + data[slices + (slice(1, 2),)]
3602                                )
3603                                / 2,  # TODO: take maximum instead?
3604                            ],
3605                            axis=i,
3606                        )
3607                    elif data.shape[i] == 3:
3608                        pass  # visualize 3 channels as RGB
3609                    else:
3610                        # visualize first 3 channels as RGB
3611                        data = data[slices + (slice(3),)]
3612
3613                    assert data.shape[i] == 3
3614
3615            slices += (slice(None),)
3616
3617        data, axes = squeeze(data, axes)
3618        assert len(axes) == ndim
3619        # take slice from z axis if needed
3620        slices = ()
3621        if ndim > ndim_need:
3622            for i, a in enumerate(axes):
3623                s = data.shape[i]
3624                if a.id == AxisId("z"):
3625                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3626                    data, axes = squeeze(data, axes)
3627                    ndim -= 1
3628                    break
3629
3630            slices += (slice(None),)
3631
3632        # take slice from any space or time axis
3633        slices = ()
3634
3635        for i, a in enumerate(axes):
3636            if ndim <= ndim_need:
3637                break
3638
3639            s = data.shape[i]
3640            assert s > 1
3641            if isinstance(
3642                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3643            ):
3644                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3645                ndim -= 1
3646
3647            slices += (slice(None),)
3648
3649        del slices
3650        data, axes = squeeze(data, axes)
3651        assert len(axes) == ndim
3652
3653        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3654            raise ValueError(
3655                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
3656            )
3657
3658        if not has_c_axis:
3659            assert ndim == 2
3660            data = np.repeat(data[:, :, None], 3, axis=2)
3661            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3662            ndim += 1
3663
3664        assert ndim == 3
3665
3666        # transpose axis order such that longest axis comes first...
3667        axis_order: List[int] = list(np.argsort(list(data.shape)))
3668        axis_order.reverse()
3669        # ... and channel axis is last
3670        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3671        axis_order.append(axis_order.pop(c))
3672        axes = [axes[ao] for ao in axis_order]
3673        data = data.transpose(axis_order)
3674
3675        # h, w = data.shape[:2]
3676        # if h / w  in (1.0 or 2.0):
3677        #     pass
3678        # elif h / w < 2:
3679        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3680
3681        norm_along = (
3682            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3683        )
3684        # normalize the data and map to 8 bit
3685        data = normalize(data, norm_along)
3686        data = (data * 255).astype("uint8")
3687
3688        return data
3689
3690    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3691        assert im0.dtype == im1.dtype == np.uint8
3692        assert im0.shape == im1.shape
3693        assert im0.ndim == 3
3694        N, M, C = im0.shape
3695        assert C == 3
3696        out = np.ones((N, M, C), dtype="uint8")
3697        for c in range(C):
3698            outc = np.tril(im0[..., c])
3699            mask = outc == 0
3700            outc[mask] = np.triu(im1[..., c])[mask]
3701            out[..., c] = outc
3702
3703        return out
3704
3705    if not inputs:
3706        raise ValueError("Missing test input tensor for cover generation.")
3707
3708    if not outputs:
3709        raise ValueError("Missing test output tensor for cover generation.")
3710
3711    ipt_descr, ipt = inputs[0]
3712    out_descr, out = outputs[0]
3713
3714    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3715    out_img = to_2d_image(out, out_descr.axes)
3716
3717    cover_folder = Path(mkdtemp())
3718    if ipt_img.shape == out_img.shape:
3719        covers = [cover_folder / "cover.png"]
3720        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3721    else:
3722        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3723        imwrite(covers[0], ipt_img)
3724        imwrite(covers[1], out_img)
3725
3726    return covers
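
The size-resolution helpers defined above (`get_batch_size`, `get_ns`, `get_tensor_sizes`, `get_axis_sizes`) resolve concrete axis sizes from a sample's input sizes. A minimal sketch, assuming `model` is an already loaded and validated `ModelDescr` with a single input tensor "raw" whose spatial axes are parameterized (the tensor and axis ids are illustrative, not part of the spec):

from bioimageio.spec.model.v0_5 import AxisId, TensorId

input_sizes = {
    TensorId("raw"): {
        AxisId("batch"): 2,
        AxisId("channel"): 1,
        AxisId("y"): 512,
        AxisId("x"): 512,
    }
}
batch_size = model.get_batch_size(input_sizes)    # -> 2
ns = model.get_ns(input_sizes)                    # n per parameterized (tensor, axis) pair
tensor_sizes = model.get_tensor_sizes(ns, batch_size=batch_size)
print(tensor_sizes.inputs)   # resolved input block sizes per tensor
print(tensor_sizes.outputs)  # output sizes (may contain _DataDepSize entries)
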
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible with the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible with the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
230class TensorId(LowerCaseIdentifier):
231    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
232        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
233    ]


root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

246class AxisId(LowerCaseIdentifier):
247    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
248        Annotated[
249            LowerCaseIdentifierAnno,
250            MaxLen(16),
251            AfterValidator(_normalize_axis_id),
252        ]
253    ]


root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

the pydantic root model to validate the string

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_range', 'sigmoid', 'softmax']
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'softmax', 'zero_mean_unit_variance']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>

Annotates an integer to calculate a concrete axis size from a ParameterizedSize.

class ParameterizedSize(bioimageio.spec._internal.node.Node):
299class ParameterizedSize(Node):
300    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
301
302    - **min** and **step** are given by the model description.
303    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
304    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
305      This allows adjusting the axis size more generically.
306    """
307
308    N: ClassVar[Type[int]] = ParameterizedSize_N
309    """Non-negative integer to parameterize this axis"""
310
311    min: Annotated[int, Gt(0)]
312    step: Annotated[int, Gt(0)]
313
314    def validate_size(self, size: int) -> int:
315        if size < self.min:
316            raise ValueError(f"size {size} < {self.min}")
317        if (size - self.min) % self.step != 0:
318            raise ValueError(
319                f"axis of size {size} is not parameterized by `min + n*step` ="
320                + f" `{self.min} + n*{self.step}`"
321            )
322
323        return size
324
325    def get_size(self, n: ParameterizedSize_N) -> int:
326        return self.min + self.step * n
327
328    def get_n(self, s: int) -> ParameterizedSize_N:
329        """return the smallest n parameterizing a size greater than or equal to `s`"""
330        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

  • min and step are given by the model description.
  • All blocksize parameters n = 0,1,2,... yield a valid size.
  • A greater blocksize parameter n = 0,1,2,... results in a greater size. This allows adjusting the axis size more generically.
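
An illustrative doctest-style sketch of how the parameterization resolves (the min/step values are made up, not taken from any particular model):

>>> p = ParameterizedSize(min=64, step=16)
>>> p.get_size(3)      # 64 + 3*16
112
>>> p.get_n(100)       # smallest n such that min + n*step >= 100
3
>>> p.validate_size(96)
96
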
N: ClassVar[Type[int]] = <class 'int'>

Non-negative integer to parameterize this axis

min: Annotated[int, Gt(gt=0)] = PydanticUndefined
step: Annotated[int, Gt(gt=0)] = PydanticUndefined
def validate_size(self, size: int) -> int:
314    def validate_size(self, size: int) -> int:
315        if size < self.min:
316            raise ValueError(f"size {size} < {self.min}")
317        if (size - self.min) % self.step != 0:
318            raise ValueError(
319                f"axis of size {size} is not parameterized by `min + n*step` ="
320                + f" `{self.min} + n*{self.step}`"
321            )
322
323        return size
def get_size(self, n: int) -> int:
325    def get_size(self, n: ParameterizedSize_N) -> int:
326        return self.min + self.step * n
def get_n(self, s: int) -> int:
328    def get_n(self, s: int) -> ParameterizedSize_N:
329        """return the smallest n parameterizing a size greater than or equal to `s`"""
330        return ceil((s - self.min) / self.step)

return the smallest n parameterizing a size greater than or equal to s

class DataDependentSize(bioimageio.spec._internal.node.Node):
333class DataDependentSize(Node):
334    min: Annotated[int, Gt(0)] = 1
335    max: Annotated[Optional[int], Gt(1)] = None
336
337    @model_validator(mode="after")
338    def _validate_max_gt_min(self):
339        if self.max is not None and self.min >= self.max:
340            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
341
342        return self
343
344    def validate_size(self, size: int) -> int:
345        if size < self.min:
346            raise ValueError(f"size {size} < {self.min}")
347
348        if self.max is not None and size > self.max:
349            raise ValueError(f"size {size} > {self.max}")
350
351        return size
min: Annotated[int, Gt(gt=0)] = 1
max: Annotated[Optional[int], Gt(gt=1)] = None
def validate_size(self, size: int) -> int:
344    def validate_size(self, size: int) -> int:
345        if size < self.min:
346            raise ValueError(f"size {size} < {self.min}")
347
348        if self.max is not None and size > self.max:
349            raise ValueError(f"size {size} > {self.max}")
350
351        return size
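
For illustration, a sketch of how the min/max bounds are enforced (the values are made up):

>>> d = DataDependentSize(max=512)
>>> d.validate_size(1)
1
>>> d.validate_size(600)
Traceback (most recent call last):
...
ValueError: size 600 > 512
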
class SizeReference(bioimageio.spec._internal.node.Node):
354class SizeReference(Node):
355    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
356
357    `axis.size = reference.size * reference.scale / axis.scale + offset`
358
359    Note:
360    1. The axis and the referenced axis need to have the same unit (or no unit).
361    2. Batch axes may not be referenced.
362    3. Fractions are rounded down.
363    4. If the reference axis is `concatenable` the referencing axis is assumed to be
364        `concatenable` as well with the same block order.
365
366    Example:
367    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².
368    Let's assume that we want to express the image height h in relation to its width w
369    instead of only accepting input images of exactly 100*49 pixels
370    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
371
372    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
373    >>> h = SpaceInputAxis(
374    ...     id=AxisId("h"),
375    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
376    ...     unit="millimeter",
377    ...     scale=4,
378    ... )
379    >>> print(h.size.get_size(h, w))
380    49
381
382    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
383    """
384
385    tensor_id: TensorId
386    """tensor id of the reference axis"""
387
388    axis_id: AxisId
389    """axis id of the reference axis"""
390
391    offset: StrictInt = 0
392
393    def get_size(
394        self,
395        axis: Union[
396            ChannelAxis,
397            IndexInputAxis,
398            IndexOutputAxis,
399            TimeInputAxis,
400            SpaceInputAxis,
401            TimeOutputAxis,
402            TimeOutputAxisWithHalo,
403            SpaceOutputAxis,
404            SpaceOutputAxisWithHalo,
405        ],
406        ref_axis: Union[
407            ChannelAxis,
408            IndexInputAxis,
409            IndexOutputAxis,
410            TimeInputAxis,
411            SpaceInputAxis,
412            TimeOutputAxis,
413            TimeOutputAxisWithHalo,
414            SpaceOutputAxis,
415            SpaceOutputAxisWithHalo,
416        ],
417        n: ParameterizedSize_N = 0,
418        ref_size: Optional[int] = None,
419    ):
420        """Compute the concrete size for a given axis and its reference axis.
421
422        Args:
423            axis: The axis this `SizeReference` is the size of.
424            ref_axis: The reference axis to compute the size from.
425            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
426                and no fixed **ref_size** is given,
427                **n** is used to compute the size of the parameterized **ref_axis**.
428            ref_size: Overwrite the reference size instead of deriving it from
429                **ref_axis**
430                (**ref_axis.scale** is still used; any given **n** is ignored).
431        """
432        assert axis.size == self, (
433            "Given `axis.size` is not defined by this `SizeReference`"
434        )
435
436        assert ref_axis.id == self.axis_id, (
437            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
438        )
439
440        assert axis.unit == ref_axis.unit, (
441            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
442            f" but {axis.unit}!={ref_axis.unit}"
443        )
444        if ref_size is None:
445            if isinstance(ref_axis.size, (int, float)):
446                ref_size = ref_axis.size
447            elif isinstance(ref_axis.size, ParameterizedSize):
448                ref_size = ref_axis.size.get_size(n)
449            elif isinstance(ref_axis.size, DataDependentSize):
450                raise ValueError(
451                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
452                )
453            elif isinstance(ref_axis.size, SizeReference):
454                raise ValueError(
455                    "Reference axis referenced in `SizeReference` may not be sized by a"
456                    + " `SizeReference` itself."
457                )
458            else:
459                assert_never(ref_axis.size)
460
461        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
462
463    @staticmethod
464    def _get_unit(
465        axis: Union[
466            ChannelAxis,
467            IndexInputAxis,
468            IndexOutputAxis,
469            TimeInputAxis,
470            SpaceInputAxis,
471            TimeOutputAxis,
472            TimeOutputAxisWithHalo,
473            SpaceOutputAxis,
474            SpaceOutputAxisWithHalo,
475        ],
476    ):
477        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 100*49 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId = PydanticUndefined

tensor id of the reference axis

axis_id: AxisId = PydanticUndefined

axis id of the reference axis

offset: Annotated[int, Strict(strict=True)] = 0
393    def get_size(
394        self,
395        axis: Union[
396            ChannelAxis,
397            IndexInputAxis,
398            IndexOutputAxis,
399            TimeInputAxis,
400            SpaceInputAxis,
401            TimeOutputAxis,
402            TimeOutputAxisWithHalo,
403            SpaceOutputAxis,
404            SpaceOutputAxisWithHalo,
405        ],
406        ref_axis: Union[
407            ChannelAxis,
408            IndexInputAxis,
409            IndexOutputAxis,
410            TimeInputAxis,
411            SpaceInputAxis,
412            TimeOutputAxis,
413            TimeOutputAxisWithHalo,
414            SpaceOutputAxis,
415            SpaceOutputAxisWithHalo,
416        ],
417        n: ParameterizedSize_N = 0,
418        ref_size: Optional[int] = None,
419    ):
420        """Compute the concrete size for a given axis and its reference axis.
421
422        Args:
423            axis: The axis this `SizeReference` is the size of.
424            ref_axis: The reference axis to compute the size from.
425            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
426                and no fixed **ref_size** is given,
427                **n** is used to compute the size of the parameterized **ref_axis**.
428            ref_size: Overwrite the reference size instead of deriving it from
429                **ref_axis**
430                (**ref_axis.scale** is still used; any given **n** is ignored).
431        """
432        assert axis.size == self, (
433            "Given `axis.size` is not defined by this `SizeReference`"
434        )
435
436        assert ref_axis.id == self.axis_id, (
437            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
438        )
439
440        assert axis.unit == ref_axis.unit, (
441            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
442            f" but {axis.unit}!={ref_axis.unit}"
443        )
444        if ref_size is None:
445            if isinstance(ref_axis.size, (int, float)):
446                ref_size = ref_axis.size
447            elif isinstance(ref_axis.size, ParameterizedSize):
448                ref_size = ref_axis.size.get_size(n)
449            elif isinstance(ref_axis.size, DataDependentSize):
450                raise ValueError(
451                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
452                )
453            elif isinstance(ref_axis.size, SizeReference):
454                raise ValueError(
455                    "Reference axis referenced in `SizeReference` may not be sized by a"
456                    + " `SizeReference` itself."
457                )
458            else:
459                assert_never(ref_axis.size)
460
461        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
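
A further illustrative sketch (the tensor/axis ids and sizes are hypothetical): the reference size may come from a parameterized reference axis via n, or be given explicitly via ref_size.

>>> w = SpaceInputAxis(id=AxisId("w"), size=ParameterizedSize(min=64, step=16))
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w")),
... )
>>> h.size.get_size(h, w, n=2)           # reference size resolves to 64 + 2*16
96
>>> h.size.get_size(h, w, ref_size=128)  # explicit reference size; n is ignored
128
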
480class AxisBase(NodeWithExplicitlySetFields):
481    id: AxisId
482    """An axis id unique across all axes of one tensor."""
483
484    description: Annotated[str, MaxLen(128)] = ""
485    """A short description of this axis beyond its type and id."""
id: AxisId = PydanticUndefined

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)] = ''

A short description of this axis beyond its type and id.

class WithHalo(bioimageio.spec._internal.node.Node):
488class WithHalo(Node):
489    halo: Annotated[int, Ge(1)]
490    """The halo should be cropped from the output tensor to avoid boundary effects.
491    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
492    To document a halo that is already cropped by the model use `size.offset` instead."""
493
494    size: Annotated[
495        SizeReference,
496        Field(
497            examples=[
498                10,
499                SizeReference(
500                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
501                ).model_dump(mode="json"),
502            ]
503        ),
504    ]
505    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)] = PydanticUndefined

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])] = PydanticUndefined

reference to another axis with an optional offset (see SizeReference)
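
A small sketch of an output axis carrying a halo (the ids and values are illustrative, not from any particular model):

>>> y_out = SpaceOutputAxisWithHalo(
...     id=AxisId("y"),
...     size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
...     halo=32,
... )
>>> y_out.halo  # cropped from both sides: size_after_crop = size - 2*32
32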

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
511class BatchAxis(AxisBase):
512    implemented_type: ClassVar[Literal["batch"]] = "batch"
513    if TYPE_CHECKING:
514        type: Literal["batch"] = "batch"
515    else:
516        type: Literal["batch"]
517
518    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
519    size: Optional[Literal[1]] = None
520    """The batch size may be fixed to 1,
521    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
522
523    @property
524    def scale(self):
525        return 1.0
526
527    @property
528    def concatenable(self):
529        return True
530
531    @property
532    def unit(self):
533        return None
implemented_type: ClassVar[Literal['batch']] = 'batch'
id: Annotated[AxisId, Predicate(_is_batch)] = 'batch'

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]] = None

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
523    @property
524    def scale(self):
525        return 1.0
concatenable
527    @property
528    def concatenable(self):
529        return True
unit
531    @property
532    def unit(self):
533        return None
type: Literal['batch'] = PydanticUndefined
Inherited Members
AxisBase
description
class ChannelAxis(AxisBase):
536class ChannelAxis(AxisBase):
537    implemented_type: ClassVar[Literal["channel"]] = "channel"
538    if TYPE_CHECKING:
539        type: Literal["channel"] = "channel"
540    else:
541        type: Literal["channel"]
542
543    id: NonBatchAxisId = AxisId("channel")
544
545    channel_names: NotEmpty[List[Identifier]]
546
547    @property
548    def size(self) -> int:
549        return len(self.channel_names)
550
551    @property
552    def concatenable(self):
553        return False
554
555    @property
556    def scale(self) -> float:
557        return 1.0
558
559    @property
560    def unit(self):
561        return None
implemented_type: ClassVar[Literal['channel']] = 'channel'
id: Annotated[AxisId, Predicate(_is_not_batch)] = 'channel'

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)] = PydanticUndefined
size: int
547    @property
548    def size(self) -> int:
549        return len(self.channel_names)
concatenable
551    @property
552    def concatenable(self):
553        return False
scale: float
555    @property
556    def scale(self) -> float:
557        return 1.0
unit
559    @property
560    def unit(self):
561        return None
type: Literal['channel'] = PydanticUndefined
Inherited Members
AxisBase
description
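
For example, the size of a channel axis is derived from its channel names (a minimal sketch with made-up names):

>>> rgb = ChannelAxis(channel_names=list(map(Identifier, ["red", "green", "blue"])))
>>> rgb.size
3
>>> rgb.concatenable
False
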
class IndexAxisBase(AxisBase):
564class IndexAxisBase(AxisBase):
565    implemented_type: ClassVar[Literal["index"]] = "index"
566    if TYPE_CHECKING:
567        type: Literal["index"] = "index"
568    else:
569        type: Literal["index"]
570
571    id: NonBatchAxisId = AxisId("index")
572
573    @property
574    def scale(self) -> float:
575        return 1.0
576
577    @property
578    def unit(self):
579        return None
implemented_type: ClassVar[Literal['index']] = 'index'
id: Annotated[AxisId, Predicate(_is_not_batch)] = 'index'

An axis id unique across all axes of one tensor.

scale: float
573    @property
574    def scale(self) -> float:
575        return 1.0
unit
577    @property
578    def unit(self):
579        return None
type: Literal['index'] = PydanticUndefined
Inherited Members
AxisBase
description
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
602class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
603    concatenable: bool = False
604    """If a model has a `concatenable` input axis, it can be processed blockwise,
605    splitting a longer sample axis into blocks matching its input tensor description.
606    Output axes are concatenable if they have a `SizeReference` to a concatenable
607    input axis.
608    """
concatenable: bool = False

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

type: Literal['index'] = PydanticUndefined
class IndexOutputAxis(IndexAxisBase):
611class IndexOutputAxis(IndexAxisBase):
612    size: Annotated[
613        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
614        Field(
615            examples=[
616                10,
617                SizeReference(
618                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
619                ).model_dump(mode="json"),
620            ]
621        ),
622    ]
623    """The size/length of this axis can be specified as
624    - fixed integer
625    - reference to another axis with an optional offset (`SizeReference`)
626    - data dependent size using `DataDependentSize` (size is only known after model inference)
627    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])] = PydanticUndefined

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
type: Literal['index'] = PydanticUndefined
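
For illustration, an index output axis whose length is only known after inference could be described like this (the id and bounds are hypothetical):

>>> detections = IndexOutputAxis(
...     id=AxisId("detection"),
...     size=DataDependentSize(max=1000),
... )
>>> detections.size.min, detections.size.max
(1, 1000)
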
class TimeAxisBase(AxisBase):
630class TimeAxisBase(AxisBase):
631    implemented_type: ClassVar[Literal["time"]] = "time"
632    if TYPE_CHECKING:
633        type: Literal["time"] = "time"
634    else:
635        type: Literal["time"]
636
637    id: NonBatchAxisId = AxisId("time")
638    unit: Optional[TimeUnit] = None
639    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['time']] = 'time'
id: Annotated[AxisId, Predicate(_is_not_batch)] = 'time'

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']] = None
scale: Annotated[float, Gt(gt=0)] = 1.0
type: Literal['time'] = PydanticUndefined
Inherited Members
AxisBase
description
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
642class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
643    concatenable: bool = False
644    """If a model has a `concatenable` input axis, it can be processed blockwise,
645    splitting a longer sample axis into blocks matching its input tensor description.
646    Output axes are concatenable if they have a `SizeReference` to a concatenable
647    input axis.
648    """
concatenable: bool = False

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

type: Literal['time'] = PydanticUndefined
class SpaceAxisBase(AxisBase):
651class SpaceAxisBase(AxisBase):
652    implemented_type: ClassVar[Literal["space"]] = "space"
653    if TYPE_CHECKING:
654        type: Literal["space"] = "space"
655    else:
656        type: Literal["space"]
657
658    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
659    unit: Optional[SpaceUnit] = None
660    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['space']] = 'space'
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])] = 'x'

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']] = None
scale: Annotated[float, Gt(gt=0)] = 1.0
type: Literal['space'] = PydanticUndefined
Inherited Members
AxisBase
description
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
663class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
664    concatenable: bool = False
665    """If a model has a `concatenable` input axis, it can be processed blockwise,
666    splitting a longer sample axis into blocks matching its input tensor description.
667    Output axes are concatenable if they have a `SizeReference` to a concatenable
668    input axis.
669    """
concatenable: bool = False

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

type: Literal['space'] = PydanticUndefined
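
As an illustrative, hedged sketch (the axis size parameters below are made up), a blockwise-processable space input axis could be declared like this:

    >>> from bioimageio.spec.model.v0_5 import AxisId, ParameterizedSize, SpaceInputAxis
    >>> x_axis = SpaceInputAxis(
    ...     id=AxisId("x"),
    ...     size=ParameterizedSize(min=64, step=16),  # valid sizes: 64, 80, 96, ...
    ...     concatenable=True,  # a longer sample axis may be split into matching blocks
    ... )
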
INPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>)

intended for isinstance comparisons in py<3.10

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
705class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
706    pass
type: Literal['time'] = PydanticUndefined
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
709class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
710    pass
type: Literal['time'] = PydanticUndefined
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
729class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
730    pass
type: Literal['space'] = PydanticUndefined
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
733class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
734    pass
type: Literal['space'] = PydanticUndefined
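
For illustration, a hedged sketch (the tensor id 'input' and the halo value are made up) of an output space axis whose size references an input axis and that declares a halo to be cropped to avoid boundary effects:

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, SizeReference, SpaceOutputAxisWithHalo, TensorId
    ... )
    >>> y_out = SpaceOutputAxisWithHalo(
    ...     id=AxisId("y"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
    ...     halo=16,  # crop 16 pixels from each side of this axis
    ... )
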
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
OUTPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
ANY_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>, <class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
791class NominalOrOrdinalDataDescr(Node):
792    values: TVs
793    """A fixed set of nominal or an ascending sequence of ordinal values.
794    In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'.
795    String `values` are interpreted as labels for tensor values 0, ..., N.
796    Note: as YAML 1.2 does not natively support a "set" datatype,
797    nominal values should be given as a sequence (aka list/array) as well.
798    """
799
800    type: Annotated[
801        NominalOrOrdinalDType,
802        Field(
803            examples=[
804                "float32",
805                "uint8",
806                "uint16",
807                "int64",
808                "bool",
809            ],
810        ),
811    ] = "uint8"
812
813    @model_validator(mode="after")
814    def _validate_values_match_type(
815        self,
816    ) -> Self:
817        incompatible: List[Any] = []
818        for v in self.values:
819            if self.type == "bool":
820                if not isinstance(v, bool):
821                    incompatible.append(v)
822            elif self.type in DTYPE_LIMITS:
823                if (
824                    isinstance(v, (int, float))
825                    and (
826                        v < DTYPE_LIMITS[self.type].min
827                        or v > DTYPE_LIMITS[self.type].max
828                    )
829                    or (isinstance(v, str) and "uint" not in self.type)
830                    or (isinstance(v, float) and "int" in self.type)
831                ):
832                    incompatible.append(v)
833            else:
834                incompatible.append(v)
835
836            if len(incompatible) == 5:
837                incompatible.append("...")
838                break
839
840        if incompatible:
841            raise ValueError(
842                f"data type '{self.type}' incompatible with values {incompatible}"
843            )
844
845        return self
846
847    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
848
849    @property
850    def range(self):
851        if isinstance(self.values[0], str):
852            return 0, len(self.values) - 1
853        else:
854            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]] = PydanticUndefined

A fixed set of nominal or an ascending sequence of ordinal values. In this case data.type is required to be an unsigned integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])] = 'uint8'
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType] = None
range
849    @property
850    def range(self):
851        if isinstance(self.values[0], str):
852            return 0, len(self.values) - 1
853        else:
854            return min(self.values), max(self.values)
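
For illustration, a hedged sketch of a nominal data description with made-up string labels:

    >>> from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr
    >>> labels = NominalOrOrdinalDataDescr(
    ...     values=["background", "cell", "membrane"],  # labels for tensor values 0, 1, 2
    ...     type="uint8",
    ... )
    >>> labels.range
    (0, 2)
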
IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
871class IntervalOrRatioDataDescr(Node):
872    type: Annotated[  # TODO: rename to dtype
873        IntervalOrRatioDType,
874        Field(
875            examples=["float32", "float64", "uint8", "uint16"],
876        ),
877    ] = "float32"
878    range: Tuple[Optional[float], Optional[float]] = (
879        None,
880        None,
881    )
882    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
883    `None` corresponds to min/max of what can be expressed by **type**."""
884    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
885    scale: float = 1.0
886    """Scale for data on an interval (or ratio) scale."""
887    offset: Optional[float] = None
888    """Offset for data on a ratio scale."""
889
890    @model_validator(mode="before")
891    def _replace_inf(cls, data: Any):
892        if is_dict(data):
893            if "range" in data and is_sequence(data["range"]):
894                forbidden = (
895                    "inf",
896                    "-inf",
897                    ".inf",
898                    "-.inf",
899                    float("inf"),
900                    float("-inf"),
901                )
902                if any(v in forbidden for v in data["range"]):
903                    issue_warning("replaced 'inf' value", value=data["range"])
904
905                data["range"] = tuple(
906                    (None if v in forbidden else v) for v in data["range"]
907                )
908
909        return data
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])] = 'float32'
range: Tuple[Optional[float], Optional[float]] = (None, None)

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit] = 'arbitrary unit'
scale: float = 1.0

Scale for data on an interval (or ratio) scale.

offset: Optional[float] = None

Offset for data on a ratio scale.
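
For illustration, a hedged sketch of an interval/ratio data description restricted to the made-up range [0, 1]:

    >>> from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr
    >>> data = IntervalOrRatioDataDescr(
    ...     type="float32",
    ...     range=(0.0, 1.0),  # a `None` bound means: limited only by what `type` can express
    ... )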

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
915class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
916    """processing base class"""

processing base class

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
919class BinarizeKwargs(ProcessingKwargs):
920    """key word arguments for `BinarizeDescr`"""
921
922    threshold: float
923    """The fixed threshold"""

key word arguments for BinarizeDescr

threshold: float = PydanticUndefined

The fixed threshold

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
926class BinarizeAlongAxisKwargs(ProcessingKwargs):
927    """key word arguments for `BinarizeDescr`"""
928
929    threshold: NotEmpty[List[float]]
930    """The fixed threshold values along `axis`"""
931
932    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
933    """The `threshold` axis"""

key word arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)] = PydanticUndefined

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])] = PydanticUndefined

The threshold axis

class BinarizeDescr(ProcessingDescrBase):
936class BinarizeDescr(ProcessingDescrBase):
937    """Binarize the tensor with a fixed threshold.
938
939    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
940    will be set to one, values below the threshold to zero.
941
942    Examples:
943    - in YAML
944        ```yaml
945        postprocessing:
946          - id: binarize
947            kwargs:
948              axis: 'channel'
949              threshold: [0.25, 0.5, 0.75]
950        ```
951    - in Python:
952        >>> postprocessing = [BinarizeDescr(
953        ...   kwargs=BinarizeAlongAxisKwargs(
954        ...       axis=AxisId('channel'),
955        ...       threshold=[0.25, 0.5, 0.75],
956        ...   )
957        ... )]
958    """
959
960    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
961    if TYPE_CHECKING:
962        id: Literal["binarize"] = "binarize"
963    else:
964        id: Literal["binarize"]
965    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
implemented_id: ClassVar[Literal['binarize']] = 'binarize'
kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] = PydanticUndefined
id: Literal['binarize'] = PydanticUndefined
class ClipDescr(ProcessingDescrBase):
968class ClipDescr(ProcessingDescrBase):
969    """Set tensor values below min to min and above max to max.
970
971    See `ScaleRangeDescr` for examples.
972    """
973
974    implemented_id: ClassVar[Literal["clip"]] = "clip"
975    if TYPE_CHECKING:
976        id: Literal["clip"] = "clip"
977    else:
978        id: Literal["clip"]
979
980    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

implemented_id: ClassVar[Literal['clip']] = 'clip'
kwargs: bioimageio.spec.model.v0_4.ClipKwargs = PydanticUndefined
id: Literal['clip'] = PydanticUndefined
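
Since the clip examples are only given under ScaleRangeDescr, here is a minimal hedged sketch on its own (assuming ClipKwargs, defined in bioimageio.spec.model.v0_4, is accessible via this module as in the examples further below):

    >>> from bioimageio.spec.model.v0_5 import ClipDescr, ClipKwargs
    >>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
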
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
983class EnsureDtypeKwargs(ProcessingKwargs):
984    """key word arguments for `EnsureDtypeDescr`"""
985
986    dtype: Literal[
987        "float32",
988        "float64",
989        "uint8",
990        "int8",
991        "uint16",
992        "int16",
993        "uint32",
994        "int32",
995        "uint64",
996        "int64",
997        "bool",
998    ]

key word arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'] = PydanticUndefined
class EnsureDtypeDescr(ProcessingDescrBase):
1001class EnsureDtypeDescr(ProcessingDescrBase):
1002    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1003
1004    This can for example be used to ensure the inner neural network model gets a
1005    different input tensor data type than the fully described bioimage.io model does.
1006
1007    Examples:
1008        The described bioimage.io model (incl. preprocessing) accepts any
1009        float32-compatible tensor, normalizes it with percentiles and clipping and then
1010        casts it to uint8, which is what the neural network in this example expects.
1011        - in YAML
1012            ```yaml
1013            inputs:
1014            - data:
1015                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1016              preprocessing:
1017              - id: scale_range
1018                  kwargs:
1019                  axes: ['y', 'x']
1020                  max_percentile: 99.8
1021                  min_percentile: 5.0
1022              - id: clip
1023                  kwargs:
1024                  min: 0.0
1025                  max: 1.0
1026              - id: ensure_dtype  # the neural network of the model requires uint8
1027                  kwargs:
1028                  dtype: uint8
1029            ```
1030        - in Python:
1031            >>> preprocessing = [
1032            ...     ScaleRangeDescr(
1033            ...         kwargs=ScaleRangeKwargs(
1034            ...           axes= (AxisId('y'), AxisId('x')),
1035            ...           max_percentile= 99.8,
1036            ...           min_percentile= 5.0,
1037            ...         )
1038            ...     ),
1039            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1040            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1041            ... ]
1042    """
1043
1044    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1045    if TYPE_CHECKING:
1046        id: Literal["ensure_dtype"] = "ensure_dtype"
1047    else:
1048        id: Literal["ensure_dtype"]
1049
1050    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
    kwargs:
      dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
implemented_id: ClassVar[Literal['ensure_dtype']] = 'ensure_dtype'
kwargs: EnsureDtypeKwargs = PydanticUndefined
id: Literal['ensure_dtype'] = PydanticUndefined
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1053class ScaleLinearKwargs(ProcessingKwargs):
1054    """Key word arguments for `ScaleLinearDescr`"""
1055
1056    gain: float = 1.0
1057    """multiplicative factor"""
1058
1059    offset: float = 0.0
1060    """additive term"""
1061
1062    @model_validator(mode="after")
1063    def _validate(self) -> Self:
1064        if self.gain == 1.0 and self.offset == 0.0:
1065            raise ValueError(
1066                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1067                + " != 0.0."
1068            )
1069
1070        return self

Key word arguments for ScaleLinearDescr

gain: float = 1.0

multiplicative factor

offset: float = 0.0

additive term

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1073class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1074    """Key word arguments for `ScaleLinearDescr`"""
1075
1076    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1077    """The axis of gain and offset values."""
1078
1079    gain: Union[float, NotEmpty[List[float]]] = 1.0
1080    """multiplicative factor"""
1081
1082    offset: Union[float, NotEmpty[List[float]]] = 0.0
1083    """additive term"""
1084
1085    @model_validator(mode="after")
1086    def _validate(self) -> Self:
1087        if isinstance(self.gain, list):
1088            if isinstance(self.offset, list):
1089                if len(self.gain) != len(self.offset):
1090                    raise ValueError(
1091                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1092                    )
1093            else:
1094                self.offset = [float(self.offset)] * len(self.gain)
1095        elif isinstance(self.offset, list):
1096            self.gain = [float(self.gain)] * len(self.offset)
1097        else:
1098            raise ValueError(
1099                "Do not specify an `axis` for scalar gain and offset values."
1100            )
1101
1102        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1103            raise ValueError(
1104                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1105                + " != 0.0."
1106            )
1107
1108        return self

Key word arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])] = PydanticUndefined

The axis of gain and offset values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]] = 1.0

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]] = 0.0

additive term

class ScaleLinearDescr(ProcessingDescrBase):
1111class ScaleLinearDescr(ProcessingDescrBase):
1112    """Fixed linear scaling.
1113
1114    Examples:
1115      1. Scale with scalar gain and offset
1116        - in YAML
1117        ```yaml
1118        preprocessing:
1119          - id: scale_linear
1120            kwargs:
1121              gain: 2.0
1122              offset: 3.0
1123        ```
1124        - in Python:
1125        >>> preprocessing = [
1126        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1127        ... ]
1128
1129      2. Independent scaling along an axis
1130        - in YAML
1131        ```yaml
1132        preprocessing:
1133          - id: scale_linear
1134            kwargs:
1135              axis: 'channel'
1136              gain: [1.0, 2.0, 3.0]
1137        ```
1138        - in Python:
1139        >>> preprocessing = [
1140        ...     ScaleLinearDescr(
1141        ...         kwargs=ScaleLinearAlongAxisKwargs(
1142        ...             axis=AxisId("channel"),
1143        ...             gain=[1.0, 2.0, 3.0],
1144        ...         )
1145        ...     )
1146        ... ]
1147
1148    """
1149
1150    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1151    if TYPE_CHECKING:
1152        id: Literal["scale_linear"] = "scale_linear"
1153    else:
1154        id: Literal["scale_linear"]
1155    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
implemented_id: ClassVar[Literal['scale_linear']] = 'scale_linear'
kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] = PydanticUndefined
id: Literal['scale_linear'] = PydanticUndefined
class SigmoidDescr(ProcessingDescrBase):
1158class SigmoidDescr(ProcessingDescrBase):
1159    """The logistic sigmoid function, a.k.a. expit function.
1160
1161    Examples:
1162    - in YAML
1163        ```yaml
1164        postprocessing:
1165          - id: sigmoid
1166        ```
1167    - in Python:
1168        >>> postprocessing = [SigmoidDescr()]
1169    """
1170
1171    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1172    if TYPE_CHECKING:
1173        id: Literal["sigmoid"] = "sigmoid"
1174    else:
1175        id: Literal["sigmoid"]
1176
1177    @property
1178    def kwargs(self) -> ProcessingKwargs:
1179        """empty kwargs"""
1180        return ProcessingKwargs()

The logistic sigmoid function, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
implemented_id: ClassVar[Literal['sigmoid']] = 'sigmoid'
1177    @property
1178    def kwargs(self) -> ProcessingKwargs:
1179        """empty kwargs"""
1180        return ProcessingKwargs()

empty kwargs

id: Literal['sigmoid'] = PydanticUndefined
class SoftmaxKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1183class SoftmaxKwargs(ProcessingKwargs):
1184    """key word arguments for `SoftmaxDescr`"""
1185
1186    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1187    """The axis to apply the softmax function along.
1188    Note:
1189        Defaults to 'channel' axis
1190        (which may not exist, in which case
1191        a different axis id has to be specified).
1192    """

key word arguments for SoftmaxDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])] = 'channel'

The axis to apply the softmax function along.

Note:

Defaults to 'channel' axis (which may not exist, in which case a different axis id has to be specified).

class SoftmaxDescr(ProcessingDescrBase):
1195class SoftmaxDescr(ProcessingDescrBase):
1196    """The softmax function.
1197
1198    Examples:
1199    - in YAML
1200        ```yaml
1201        postprocessing:
1202          - id: softmax
1203            kwargs:
1204              axis: channel
1205        ```
1206    - in Python:
1207        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1208    """
1209
1210    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1211    if TYPE_CHECKING:
1212        id: Literal["softmax"] = "softmax"
1213    else:
1214        id: Literal["softmax"]
1215
1216    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

The softmax function.

Examples:

  • in YAML
postprocessing:
  - id: softmax
    kwargs:
      axis: channel
  • in Python:
    >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    
implemented_id: ClassVar[Literal['softmax']] = 'softmax'
kwargs: SoftmaxKwargs = PydanticUndefined
id: Literal['softmax'] = PydanticUndefined
class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1219class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1220    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1221
1222    mean: float
1223    """The mean value to normalize with."""
1224
1225    std: Annotated[float, Ge(1e-6)]
1226    """The standard deviation value to normalize with."""

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: float = PydanticUndefined

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)] = PydanticUndefined

The standard deviation value to normalize with.

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1229class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1230    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1231
1232    mean: NotEmpty[List[float]]
1233    """The mean value(s) to normalize with."""
1234
1235    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1236    """The standard deviation value(s) to normalize with.
1237    Size must match `mean` values."""
1238
1239    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1240    """The axis of the mean/std values to normalize each entry along that dimension
1241    separately."""
1242
1243    @model_validator(mode="after")
1244    def _mean_and_std_match(self) -> Self:
1245        if len(self.mean) != len(self.std):
1246            raise ValueError(
1247                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1248                + " must match."
1249            )
1250
1251        return self

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)] = PydanticUndefined

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)] = PydanticUndefined

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])] = PydanticUndefined

The axis of the mean/std values to normalize each entry along that dimension separately.

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1254class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1255    """Subtract a given mean and divide by the standard deviation.
1256
1257    Normalize with fixed, precomputed values for
1258    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1259    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1260    axes.
1261
1262    Examples:
1263    1. scalar value for whole tensor
1264        - in YAML
1265        ```yaml
1266        preprocessing:
1267          - id: fixed_zero_mean_unit_variance
1268            kwargs:
1269              mean: 103.5
1270              std: 13.7
1271        ```
1272        - in Python
1273        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1274        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1275        ... )]
1276
1277    2. independently along an axis
1278        - in YAML
1279        ```yaml
1280        preprocessing:
1281          - id: fixed_zero_mean_unit_variance
1282            kwargs:
1283              axis: channel
1284              mean: [101.5, 102.5, 103.5]
1285              std: [11.7, 12.7, 13.7]
1286        ```
1287        - in Python
1288        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1289        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1290        ...     axis=AxisId("channel"),
1291        ...     mean=[101.5, 102.5, 103.5],
1292        ...     std=[11.7, 12.7, 13.7],
1293        ...   )
1294        ... )]
1295    """
1296
1297    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1298        "fixed_zero_mean_unit_variance"
1299    )
1300    if TYPE_CHECKING:
1301        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1302    else:
1303        id: Literal["fixed_zero_mean_unit_variance"]
1304
1305    kwargs: Union[
1306        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1307    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std. Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          mean: 103.5
          std: 13.7

    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
      ... )]

  2. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['fixed_zero_mean_unit_variance']] = 'fixed_zero_mean_unit_variance'
id: Literal['fixed_zero_mean_unit_variance'] = PydanticUndefined
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1310class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1311    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1312
1313    axes: Annotated[
1314        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1315    ] = None
1316    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1317    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1318    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1319    To normalize each sample independently leave out the 'batch' axis.
1320    Default: Scale all axes jointly."""
1321
1322    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1323    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

key word arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])] = None

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)] = 1e-06

epsilon for numeric stability: out = (tensor - mean) / (std + eps).
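
For illustration, a hedged sketch (the axis ids are made up) of normalizing per channel by reducing only over 'batch', 'x' and 'y', as described above:

    >>> from bioimageio.spec.model.v0_5 import AxisId, ZeroMeanUnitVarianceKwargs
    >>> kwargs = ZeroMeanUnitVarianceKwargs(
    ...     axes=(AxisId("batch"), AxisId("x"), AxisId("y"))  # reduce over these; keep 'channel' separate
    ... )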

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1326class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1327    """Subtract mean and divide by variance.
1328
1329    Examples:
1330        Subtract tensor mean and variance
1331        - in YAML
1332        ```yaml
1333        preprocessing:
1334          - id: zero_mean_unit_variance
1335        ```
1336        - in Python
1337        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1338    """
1339
1340    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1341        "zero_mean_unit_variance"
1342    )
1343    if TYPE_CHECKING:
1344        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1345    else:
1346        id: Literal["zero_mean_unit_variance"]
1347
1348    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1349        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1350    )

Subtract mean and divide by variance.

Examples:

Subtract tensor mean and variance

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
implemented_id: ClassVar[Literal['zero_mean_unit_variance']] = 'zero_mean_unit_variance'
kwargs: ZeroMeanUnitVarianceKwargs = PydanticUndefined
id: Literal['zero_mean_unit_variance'] = PydanticUndefined
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1353class ScaleRangeKwargs(ProcessingKwargs):
1354    """key word arguments for `ScaleRangeDescr`
1355
1356    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1357    this processing step normalizes data to the [0, 1] intervall.
1358    For other percentiles the normalized values will partially be outside the [0, 1]
1359    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1360    normalized values to a range.
1361    """
1362
1363    axes: Annotated[
1364        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1365    ] = None
1366    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1367    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1368    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1369    To normalize samples independently, leave out the "batch" axis.
1370    Default: Scale all axes jointly."""
1371
1372    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1373    """The lower percentile used to determine the value to align with zero."""
1374
1375    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1376    """The upper percentile used to determine the value to align with one.
1377    Has to be bigger than `min_percentile`.
1378    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1379    accepting percentiles specified in the range 0.0 to 1.0."""
1380
1381    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1382    """Epsilon for numeric stability.
1383    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1384    with `v_lower,v_upper` values at the respective percentiles."""
1385
1386    reference_tensor: Optional[TensorId] = None
1387    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1388    For any tensor in `inputs` only input tensor references are allowed."""
1389
1390    @field_validator("max_percentile", mode="after")
1391    @classmethod
1392    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1393        if (min_p := info.data["min_percentile"]) >= value:
1394            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1395
1396        return value

key word arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use ScaleRangeDescr followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])] = None

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)] = 0.0

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)] = 100.0

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)] = 1e-06

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId] = None

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1390    @field_validator("max_percentile", mode="after")
1391    @classmethod
1392    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1393        if (min_p := info.data["min_percentile"]) >= value:
1394            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1395
1396        return value
class ScaleRangeDescr(ProcessingDescrBase):
1399class ScaleRangeDescr(ProcessingDescrBase):
1400    """Scale with percentiles.
1401
1402    Examples:
1403    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1404        - in YAML
1405        ```yaml
1406        preprocessing:
1407          - id: scale_range
1408            kwargs:
1409              axes: ['y', 'x']
1410              max_percentile: 99.8
1411              min_percentile: 5.0
1412        ```
1413        - in Python
1414        >>> preprocessing = [
1415        ...     ScaleRangeDescr(
1416        ...         kwargs=ScaleRangeKwargs(
1417        ...           axes= (AxisId('y'), AxisId('x')),
1418        ...           max_percentile= 99.8,
1419        ...           min_percentile= 5.0,
1420        ...         )
1421        ...     ),
1422        ...     ClipDescr(
1423        ...         kwargs=ClipKwargs(
1424        ...             min=0.0,
1425        ...             max=1.0,
1426        ...         )
1427        ...     ),
1428        ... ]
1429
1430      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1431        - in YAML
1432        ```yaml
1433        preprocessing:
1434          - id: scale_range
1435            kwargs:
1436              axes: ['y', 'x']
1437              max_percentile: 99.8
1438              min_percentile: 5.0
1439                  - id: scale_range
1440           - id: clip
1441             kwargs:
1442              min: 0.0
1443              max: 1.0
1444        ```
1445        - in Python
1446        >>> preprocessing = [ScaleRangeDescr(
1447        ...   kwargs=ScaleRangeKwargs(
1448        ...       axes= (AxisId('y'), AxisId('x')),
1449        ...       max_percentile= 99.8,
1450        ...       min_percentile= 5.0,
1451        ...   )
1452        ... )]
1453
1454    """
1455
1456    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1457    if TYPE_CHECKING:
1458        id: Literal["scale_range"] = "scale_range"
1459    else:
1460        id: Literal["scale_range"]
1461    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0

    • in Python
      >>> preprocessing = [
      ...     ScaleRangeDescr(
      ...         kwargs=ScaleRangeKwargs(
      ...           axes= (AxisId('y'), AxisId('x')),
      ...           max_percentile= 99.8,
      ...           min_percentile= 5.0,
      ...         )
      ...     ),
      ...     ClipDescr(
      ...         kwargs=ClipKwargs(
      ...             min=0.0,
      ...             max=1.0,
      ...         )
      ...     ),
      ... ]

  2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
      - id: clip
        kwargs:
          min: 0.0
          max: 1.0
    
    • in Python
      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['scale_range']] = 'scale_range'
kwargs: ScaleRangeKwargs = PydanticUndefined
id: Literal['scale_range'] = PydanticUndefined
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1464class ScaleMeanVarianceKwargs(ProcessingKwargs):
1465    """key word arguments for `ScaleMeanVarianceKwargs`"""
1466
1467    reference_tensor: TensorId
1468    """Name of tensor to match."""
1469
1470    axes: Annotated[
1471        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1472    ] = None
1473    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1474    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1475    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1476    To normalize samples independently, leave out the 'batch' axis.
1477    Default: Scale all axes jointly."""
1478
1479    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1480    """Epsilon for numeric stability:
1481    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

key word arguments for ScaleMeanVarianceDescr

reference_tensor: TensorId = PydanticUndefined

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])] = None

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)] = 1e-06

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1484class ScaleMeanVarianceDescr(ProcessingDescrBase):
1485    """Scale a tensor's data distribution to match another tensor's mean/std.
1486    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1487    """
1488
1489    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1490    if TYPE_CHECKING:
1491        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1492    else:
1493        id: Literal["scale_mean_variance"]
1494    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

implemented_id: ClassVar[Literal['scale_mean_variance']] = 'scale_mean_variance'
kwargs: ScaleMeanVarianceKwargs = PydanticUndefined
id: Literal['scale_mean_variance'] = PydanticUndefined
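
No example is given in the docstring, so here is a hedged sketch (the tensor and axis ids are made up) of matching an output tensor's distribution to a reference tensor 'raw':

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, ScaleMeanVarianceDescr, ScaleMeanVarianceKwargs, TensorId
    ... )
    >>> postprocessing = [ScaleMeanVarianceDescr(
    ...     kwargs=ScaleMeanVarianceKwargs(
    ...         reference_tensor=TensorId("raw"),
    ...         axes=(AxisId("x"), AxisId("y")),  # compute mean/std jointly over 'x' and 'y'
    ...     )
    ... )]
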
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleMeanVarianceDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1530class TensorDescrBase(Node, Generic[IO_AxisT]):
1531    id: TensorId
1532    """Tensor id. No duplicates are allowed."""
1533
1534    description: Annotated[str, MaxLen(128)] = ""
1535    """free text description"""
1536
1537    axes: NotEmpty[Sequence[IO_AxisT]]
1538    """tensor axes"""
1539
1540    @property
1541    def shape(self):
1542        return tuple(a.size for a in self.axes)
1543
1544    @field_validator("axes", mode="after", check_fields=False)
1545    @classmethod
1546    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1547        batch_axes = [a for a in axes if a.type == "batch"]
1548        if len(batch_axes) > 1:
1549            raise ValueError(
1550                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1551            )
1552
1553        seen_ids: Set[AxisId] = set()
1554        duplicate_axes_ids: Set[AxisId] = set()
1555        for a in axes:
1556            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1557
1558        if duplicate_axes_ids:
1559            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1560
1561        return axes
1562
1563    test_tensor: FAIR[Optional[FileDescr_]] = None
1564    """An example tensor to use for testing.
1565    Using the model with the test input tensors is expected to yield the test output tensors.
1566    Each test tensor has be a an ndarray in the
1567    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1568    The file extension must be '.npy'."""
1569
1570    sample_tensor: FAIR[Optional[FileDescr_]] = None
1571    """A sample tensor to illustrate a possible input/output for the model,
1572    The sample image primarily serves to inform a human user about an example use case
1573    and is typically stored as .hdf5, .png or .tiff.
1574    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1575    (numpy's `.npy` format is not supported).
1576    The image dimensionality has to match the number of axes specified in this tensor description.
1577    """
1578
1579    @model_validator(mode="after")
1580    def _validate_sample_tensor(self) -> Self:
1581        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1582            return self
1583
1584        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1585        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
1586            reader.read(),
1587            extension=PurePosixPath(reader.original_file_name).suffix,
1588        )
1589        n_dims = len(tensor.squeeze().shape)
1590        n_dims_min = n_dims_max = len(self.axes)
1591
1592        for a in self.axes:
1593            if isinstance(a, BatchAxis):
1594                n_dims_min -= 1
1595            elif isinstance(a.size, int):
1596                if a.size == 1:
1597                    n_dims_min -= 1
1598            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1599                if a.size.min == 1:
1600                    n_dims_min -= 1
1601            elif isinstance(a.size, SizeReference):
1602                if a.size.offset < 2:
1603                    # size reference may result in singleton axis
1604                    n_dims_min -= 1
1605            else:
1606                assert_never(a.size)
1607
1608        n_dims_min = max(0, n_dims_min)
1609        if n_dims < n_dims_min or n_dims > n_dims_max:
1610            raise ValueError(
1611                f"Expected sample tensor to have {n_dims_min} to"
1612                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1613            )
1614
1615        return self
1616
1617    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1618        IntervalOrRatioDataDescr()
1619    )
1620    """Description of the tensor's data values, optionally per channel.
1621    If specified per channel, the data `type` needs to match across channels."""
1622
1623    @property
1624    def dtype(
1625        self,
1626    ) -> Literal[
1627        "float32",
1628        "float64",
1629        "uint8",
1630        "int8",
1631        "uint16",
1632        "int16",
1633        "uint32",
1634        "int32",
1635        "uint64",
1636        "int64",
1637        "bool",
1638    ]:
1639        """dtype as specified under `data.type` or `data[i].type`"""
1640        if isinstance(self.data, collections.abc.Sequence):
1641            return self.data[0].type
1642        else:
1643            return self.data.type
1644
1645    @field_validator("data", mode="after")
1646    @classmethod
1647    def _check_data_type_across_channels(
1648        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1649    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1650        if not isinstance(value, list):
1651            return value
1652
1653        dtypes = {t.type for t in value}
1654        if len(dtypes) > 1:
1655            raise ValueError(
1656                "Tensor data descriptions per channel need to agree in their data"
1657                + f" `type`, but found {dtypes}."
1658            )
1659
1660        return value
1661
1662    @model_validator(mode="after")
1663    def _check_data_matches_channelaxis(self) -> Self:
1664        if not isinstance(self.data, (list, tuple)):
1665            return self
1666
1667        for a in self.axes:
1668            if isinstance(a, ChannelAxis):
1669                size = a.size
1670                assert isinstance(size, int)
1671                break
1672        else:
1673            return self
1674
1675        if len(self.data) != size:
1676            raise ValueError(
1677                f"Got tensor data descriptions for {len(self.data)} channels, but"
1678                + f" '{a.id}' axis has size {size}."
1679            )
1680
1681        return self
1682
1683    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1684        if len(array.shape) != len(self.axes):
1685            raise ValueError(
1686                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1687                + f" incompatible with {len(self.axes)} axes."
1688            )
1689        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId = PydanticUndefined

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)] = ''

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)] = PydanticUndefined

tensor axes

shape
1540    @property
1541    def shape(self):
1542        return tuple(a.size for a in self.axes)
test_tensor: Annotated[Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none')]], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=35, msg=None, context=None)] = None

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Annotated[Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none')]], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=35, msg=None, context=None)] = None

A sample tensor to illustrate a possible input/output for the model. The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]] = IntervalOrRatioDataDescr(type='float32', range=(None, None), unit='arbitrary unit', scale=1.0, offset=None)

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1623    @property
1624    def dtype(
1625        self,
1626    ) -> Literal[
1627        "float32",
1628        "float64",
1629        "uint8",
1630        "int8",
1631        "uint16",
1632        "int16",
1633        "uint32",
1634        "int32",
1635        "uint64",
1636        "int64",
1637        "bool",
1638    ]:
1639        """dtype as specified under `data.type` or `data[i].type`"""
1640        if isinstance(self.data, collections.abc.Sequence):
1641            return self.data[0].type
1642        else:
1643            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1683    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1684        if len(array.shape) != len(self.axes):
1685            raise ValueError(
1686                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1687                + f" incompatible with {len(self.axes)} axes."
1688            )
1689        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
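
A hedged usage sketch of get_axis_sizes_for_array (the tensor description and array shape below are made up; test_tensor is omitted, which this version of the description allows):

    >>> import numpy as np
    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, BatchAxis, InputTensorDescr, SpaceInputAxis, TensorId
    ... )
    >>> descr = InputTensorDescr(
    ...     id=TensorId("raw"),
    ...     axes=[
    ...         BatchAxis(),
    ...         SpaceInputAxis(id=AxisId("y"), size=64),
    ...         SpaceInputAxis(id=AxisId("x"), size=64),
    ...     ],
    ... )
    >>> sizes = descr.get_axis_sizes_for_array(np.zeros((1, 64, 64), dtype="float32"))
    >>> # sizes maps each axis id to the matching array dimension, e.g. the 'x' axis to 64
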
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1692class InputTensorDescr(TensorDescrBase[InputAxis]):
1693    id: TensorId = TensorId("input")
1694    """Input tensor id.
1695    No duplicates are allowed across all inputs and outputs."""
1696
1697    optional: bool = False
1698    """indicates that this tensor may be `None`"""
1699
1700    preprocessing: List[PreprocessingDescr] = Field(
1701        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1702    )
1703
1704    """Description of how this input should be preprocessed.
1705
1706    notes:
1707    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1708      to ensure an input tensor's data type matches the input tensor's data description.
1709    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1710      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1711      changing the data type.
1712    """
1713
1714    @model_validator(mode="after")
1715    def _validate_preprocessing_kwargs(self) -> Self:
1716        axes_ids = [a.id for a in self.axes]
1717        for p in self.preprocessing:
1718            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1719            if kwargs_axes is None:
1720                continue
1721
1722            if not isinstance(kwargs_axes, collections.abc.Sequence):
1723                raise ValueError(
1724                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1725                )
1726
1727            if any(a not in axes_ids for a in kwargs_axes):
1728                raise ValueError(
1729                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1730                )
1731
1732        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1733            dtype = self.data.type
1734        else:
1735            dtype = self.data[0].type
1736
1737        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1738        if not self.preprocessing or not isinstance(
1739            self.preprocessing[0], EnsureDtypeDescr
1740        ):
1741            self.preprocessing.insert(
1742                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1743            )
1744
1745        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1746        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1747            self.preprocessing.append(
1748                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1749            )
1750
1751        return self
id: TensorId = 'input'

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool = False

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]] = PydanticUndefined

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
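
Both notes can be observed on the hypothetical descr from the sketch above: it was created without any preprocessing, so a single 'ensure_dtype' step casting to the described data type (float32 by default) is present afterwards.

    assert descr.preprocessing[0].id == "ensure_dtype"
    assert descr.preprocessing[0].kwargs.dtype == "float32"
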
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1754def convert_axes(
1755    axes: str,
1756    *,
1757    shape: Union[
1758        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1759    ],
1760    tensor_type: Literal["input", "output"],
1761    halo: Optional[Sequence[int]],
1762    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1763):
1764    ret: List[AnyAxis] = []
1765    for i, a in enumerate(axes):
1766        axis_type = _AXIS_TYPE_MAP.get(a, a)
1767        if axis_type == "batch":
1768            ret.append(BatchAxis())
1769            continue
1770
1771        scale = 1.0
1772        if isinstance(shape, _ParameterizedInputShape_v0_4):
1773            if shape.step[i] == 0:
1774                size = shape.min[i]
1775            else:
1776                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1777        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1778            ref_t = str(shape.reference_tensor)
1779            if ref_t.count(".") == 1:
1780                t_id, orig_a_id = ref_t.split(".")
1781            else:
1782                t_id = ref_t
1783                orig_a_id = a
1784
1785            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1786            if not (orig_scale := shape.scale[i]):
1787                # old way to insert a new axis dimension
1788                size = int(2 * shape.offset[i])
1789            else:
1790                scale = 1 / orig_scale
1791                if axis_type in ("channel", "index"):
1792                    # these axes no longer have a scale
1793                    offset_from_scale = orig_scale * size_refs.get(
1794                        _TensorName_v0_4(t_id), {}
1795                    ).get(orig_a_id, 0)
1796                else:
1797                    offset_from_scale = 0
1798                size = SizeReference(
1799                    tensor_id=TensorId(t_id),
1800                    axis_id=AxisId(a_id),
1801                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1802                )
1803        else:
1804            size = shape[i]
1805
1806        if axis_type == "time":
1807            if tensor_type == "input":
1808                ret.append(TimeInputAxis(size=size, scale=scale))
1809            else:
1810                assert not isinstance(size, ParameterizedSize)
1811                if halo is None:
1812                    ret.append(TimeOutputAxis(size=size, scale=scale))
1813                else:
1814                    assert not isinstance(size, int)
1815                    ret.append(
1816                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1817                    )
1818
1819        elif axis_type == "index":
1820            if tensor_type == "input":
1821                ret.append(IndexInputAxis(size=size))
1822            else:
1823                if isinstance(size, ParameterizedSize):
1824                    size = DataDependentSize(min=size.min)
1825
1826                ret.append(IndexOutputAxis(size=size))
1827        elif axis_type == "channel":
1828            assert not isinstance(size, ParameterizedSize)
1829            if isinstance(size, SizeReference):
1830                warnings.warn(
1831                    "Conversion of channel size from an implicit output shape may be"
1832                    + " wrong"
1833                )
1834                ret.append(
1835                    ChannelAxis(
1836                        channel_names=[
1837                            Identifier(f"channel{i}") for i in range(size.offset)
1838                        ]
1839                    )
1840                )
1841            else:
1842                ret.append(
1843                    ChannelAxis(
1844                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1845                    )
1846                )
1847        elif axis_type == "space":
1848            if tensor_type == "input":
1849                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1850            else:
1851                assert not isinstance(size, ParameterizedSize)
1852                if halo is None or halo[i] == 0:
1853                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1854                elif isinstance(size, int):
1855                    raise NotImplementedError(
1856                        f"output axis with halo and fixed size (here {size}) not allowed"
1857                    )
1858                else:
1859                    ret.append(
1860                        SpaceOutputAxisWithHalo(
1861                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1862                        )
1863                    )
1864
1865    return ret
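
A usage sketch of this converter (the axes string and shape are made up; with a fixed shape no size references are needed):

    from bioimageio.spec.model.v0_5 import convert_axes

    axes = convert_axes(
        "bcyx",                    # v0.4 style axes string
        shape=[1, 3, 256, 256],    # fixed shape, one entry per axis
        tensor_type="input",
        halo=None,
        size_refs={},
    )
    # -> [BatchAxis(), ChannelAxis(... 3 channels ...), SpaceInputAxis('y', 256), SpaceInputAxis('x', 256)]
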
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
2030class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2031    id: TensorId = TensorId("output")
2032    """Output tensor id.
2033    No duplicates are allowed across all inputs and outputs."""
2034
2035    postprocessing: List[PostprocessingDescr] = Field(
2036        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2037    )
2038    """Description of how this output should be postprocessed.
2039
2040    note: `postprocessing` always ends with an 'ensure_dtype' operation.
2041          If not given this is added to cast to this tensor's `data.type`.
2042    """
2043
2044    @model_validator(mode="after")
2045    def _validate_postprocessing_kwargs(self) -> Self:
2046        axes_ids = [a.id for a in self.axes]
2047        for p in self.postprocessing:
2048            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2049            if kwargs_axes is None:
2050                continue
2051
2052            if not isinstance(kwargs_axes, collections.abc.Sequence):
2053                raise ValueError(
2054                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2055                )
2056
2057            if any(a not in axes_ids for a in kwargs_axes):
2058                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2059
2060        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2061            dtype = self.data.type
2062        else:
2063            dtype = self.data[0].type
2064
2065        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2066        if not self.postprocessing or not isinstance(
2067            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2068        ):
2069            self.postprocessing.append(
2070                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2071            )
2072        return self
id: TensorId = 'output'

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleMeanVarianceDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]] = PydanticUndefined

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.
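
Analogous to the input case, the appended 'ensure_dtype' step can be seen on a minimal, hypothetical output tensor description (ids and sizes are made up):

    from bioimageio.spec.model.v0_5 import (
        AxisId,
        BatchAxis,
        OutputTensorDescr,
        SpaceOutputAxis,
        TensorId,
    )

    out = OutputTensorDescr(
        id=TensorId("prob"),
        axes=[
            BatchAxis(),
            SpaceOutputAxis(id=AxisId("y"), size=64),
            SpaceOutputAxis(id=AxisId("x"), size=64),
        ],
    )
    assert out.postprocessing[-1].id == "ensure_dtype"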

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], Optional[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]], tensor_origin: Literal['test_tensor']):
2122def validate_tensors(
2123    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2124    tensor_origin: Literal[
2125        "test_tensor"
2126    ],  # for more precise error messages, e.g. 'test_tensor'
2127):
2128    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2129
2130    def e_msg(d: TensorDescr):
2131        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2132
2133    for descr, array in tensors.values():
2134        if array is None:
2135            axis_sizes = {a.id: None for a in descr.axes}
2136        else:
2137            try:
2138                axis_sizes = descr.get_axis_sizes_for_array(array)
2139            except ValueError as e:
2140                raise ValueError(f"{e_msg(descr)} {e}")
2141
2142        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2143
2144    for descr, array in tensors.values():
2145        if array is None:
2146            continue
2147
2148        if descr.dtype in ("float32", "float64"):
2149            invalid_test_tensor_dtype = array.dtype.name not in (
2150                "float32",
2151                "float64",
2152                "uint8",
2153                "int8",
2154                "uint16",
2155                "int16",
2156                "uint32",
2157                "int32",
2158                "uint64",
2159                "int64",
2160            )
2161        else:
2162            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2163
2164        if invalid_test_tensor_dtype:
2165            raise ValueError(
2166                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2167                + f" match described dtype '{descr.dtype}'"
2168            )
2169
2170        if array.min() > -1e-4 and array.max() < 1e-4:
2171            raise ValueError(
2172                "Output values are too small for reliable testing."
2173                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2174            )
2175
2176        for a in descr.axes:
2177            actual_size = all_tensor_axes[descr.id][a.id][1]
2178            if actual_size is None:
2179                continue
2180
2181            if a.size is None:
2182                continue
2183
2184            if isinstance(a.size, int):
2185                if actual_size != a.size:
2186                    raise ValueError(
2187                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2188                        + f"has incompatible size {actual_size}, expected {a.size}"
2189                    )
2190            elif isinstance(a.size, ParameterizedSize):
2191                _ = a.size.validate_size(actual_size)
2192            elif isinstance(a.size, DataDependentSize):
2193                _ = a.size.validate_size(actual_size)
2194            elif isinstance(a.size, SizeReference):
2195                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2196                if ref_tensor_axes is None:
2197                    raise ValueError(
2198                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2199                        + f" reference '{a.size.tensor_id}'"
2200                    )
2201
2202                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2203                if ref_axis is None or ref_size is None:
2204                    raise ValueError(
2205                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2206                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2207                    )
2208
2209                if a.unit != ref_axis.unit:
2210                    raise ValueError(
2211                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2212                        + " axis and reference axis to have the same `unit`, but"
2213                        + f" {a.unit}!={ref_axis.unit}"
2214                    )
2215
2216                if actual_size != (
2217                    expected_size := (
2218                        ref_size * ref_axis.scale / a.scale + a.size.offset
2219                    )
2220                ):
2221                    raise ValueError(
2222                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2223                        + f" {actual_size} invalid for referenced size {ref_size};"
2224                        + f" expected {expected_size}"
2225                    )
2226            else:
2227                assert_never(a.size)
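
A usage sketch, reusing the hypothetical descr and out descriptions from the sketches above; validate_tensors raises a ValueError on any mismatch and returns None otherwise:

    import numpy as np

    from bioimageio.spec.model.v0_5 import validate_tensors

    validate_tensors(
        {
            descr.id: (descr, np.ones((1, 2, 64, 64), dtype="float32")),
            out.id: (out, np.ones((1, 64, 64), dtype="float32")),
        },
        tensor_origin="test_tensor",
    )
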
FileDescr_dependencies = typing.Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2247class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2248    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2249    """Architecture source file"""
2250
2251    @model_serializer(mode="wrap", when_used="unless-none")
2252    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2253        return package_file_descr_serializer(self, nxt, info)

A file description

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>)] = PydanticUndefined

Architecture source file

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2256class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2257    import_from: str
2258    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str = PydanticUndefined

Where to import the callable from, i.e. from <import_from> import <callable>
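
A sketch of such a description; the callable field comes from the shared _ArchitectureCallableDescr base (documented elsewhere in this module), and the package and class names are placeholders:

    from bioimageio.spec.model.v0_5 import ArchitectureFromLibraryDescr, Identifier

    arch = ArchitectureFromLibraryDescr(
        import_from="my_package.models",   # hypothetical package
        callable=Identifier("MyNet"),      # i.e. `from my_package.models import MyNet`
    )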

class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2318class WeightsEntryDescrBase(FileDescr):
2319    type: ClassVar[WeightsFormat]
2320    weights_format_name: ClassVar[str]  # human readable
2321
2322    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2323    """Source of the weights file."""
2324
2325    authors: Optional[List[Author]] = None
2326    """Authors
2327    Either the person(s) that have trained this model resulting in the original weights file.
2328        (If this is the initial weights entry, i.e. it does not have a `parent`)
2329    Or the person(s) who have converted the weights to this weights format.
2330        (If this is a child weight, i.e. it has a `parent` field)
2331    """
2332
2333    parent: Annotated[
2334        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2335    ] = None
2336    """The source weights these weights were converted from.
2337    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2338    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2339    All weight entries except one (the initial set of weights resulting from training the model),
2340    need to have this field."""
2341
2342    comment: str = ""
2343    """A comment about this weights entry, for example how these weights were created."""
2344
2345    @model_validator(mode="after")
2346    def _validate(self) -> Self:
2347        if self.type == self.parent:
2348            raise ValueError("Weights entry can't be it's own parent.")
2349
2350        return self
2351
2352    @model_serializer(mode="wrap", when_used="unless-none")
2353    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2354        return package_file_descr_serializer(self, nxt, info)

A file description

type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>)] = PydanticUndefined

Source of the weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]] = None

Authors: either the person(s) who trained this model, resulting in the original weights file (if this is the initial weights entry, i.e. it does not have a parent), or the person(s) who converted the weights to this weights format (if this is a child weights entry, i.e. it has a parent field).

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])] = None

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.

comment: str = ''

A comment about this weights entry, for example how these weights were created.

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2357class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2358    type = "keras_hdf5"
2359    weights_format_name: ClassVar[str] = "Keras HDF5"
2360    tensorflow_version: Version
2361    """TensorFlow version used to create these weights."""

A file description

type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'
tensorflow_version: bioimageio.spec._internal.version_type.Version = PydanticUndefined

TensorFlow version used to create these weights.

FileDescr_external_data = typing.Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.data', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'weights.onnx.data'}])]
class OnnxWeightsDescr(WeightsEntryDescrBase):
2371class OnnxWeightsDescr(WeightsEntryDescrBase):
2372    type = "onnx"
2373    weights_format_name: ClassVar[str] = "ONNX"
2374    opset_version: Annotated[int, Ge(7)]
2375    """ONNX opset version"""
2376
2377    external_data: Optional[FileDescr_external_data] = None
2378    """Source of the external ONNX data file holding the weights.
2379    (If present **source** holds the ONNX architecture without weights)."""
2380
2381    @model_validator(mode="after")
2382    def _validate_external_data_unique_file_name(self) -> Self:
2383        if self.external_data is not None and (
2384            extract_file_name(self.source)
2385            == extract_file_name(self.external_data.source)
2386        ):
2387            raise ValueError(
2388                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
2389                + " must be different from ONNX `source` file name."
2390            )
2391
2392        return self

A file description

type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)] = PydanticUndefined

ONNX opset version

external_data: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.data', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'weights.onnx.data'}])]] = None

Source of the external ONNX data file holding the weights. (If present source holds the ONNX architecture without weights).
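
A sketch (file names are placeholders and do not exist here, so IO checks are disabled); note that the external data file name must differ from the source file name and carry a '.data' suffix:

    from bioimageio.spec import ValidationContext
    from bioimageio.spec.model.v0_5 import OnnxWeightsDescr

    with ValidationContext(perform_io_checks=False):
        onnx_weights = OnnxWeightsDescr(
            source="weights.onnx",
            opset_version=18,
            external_data={"source": "weights.onnx.data"},
        )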

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2395class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2396    type = "pytorch_state_dict"
2397    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2398    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2399    pytorch_version: Version
2400    """Version of the PyTorch library used.
2401    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
2402    """
2403    dependencies: Optional[FileDescr_dependencies] = None
2404    """Custom depencies beyond pytorch described in a Conda environment file.
2405    Allows to specify custom dependencies, see conda docs:
2406    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2407    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2408
2409    The conda environment file should include pytorch and any version pinning has to be compatible with
2410    **pytorch_version**.
2411    """

A file description

type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'
architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] = PydanticUndefined
pytorch_version: bioimageio.spec._internal.version_type.Version = PydanticUndefined

Version of the PyTorch library used. If architecture.dependencies is specified it has to include pytorch and any version pinning has to be compatible.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]] = None

Custom dependencies beyond pytorch described in a Conda environment file. This allows specifying custom dependencies; see the conda docs:

  • Exporting an environment file across platforms (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
  • Creating an environment file manually (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2414class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2415    type = "tensorflow_js"
2416    weights_format_name: ClassVar[str] = "Tensorflow.js"
2417    tensorflow_version: Version
2418    """Version of the TensorFlow library used."""
2419
2420    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2421    """The multi-file weights.
2422    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'
tensorflow_version: bioimageio.spec._internal.version_type.Version = PydanticUndefined

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>)] = PydanticUndefined

The multi-file weights. All required files/folders should be a zip archive.

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2425class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2426    type = "tensorflow_saved_model_bundle"
2427    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2428    tensorflow_version: Version
2429    """Version of the TensorFlow library used."""
2430
2431    dependencies: Optional[FileDescr_dependencies] = None
2432    """Custom dependencies beyond tensorflow.
2433    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2434
2435    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2436    """The multi-file weights.
2437    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'
tensorflow_version: bioimageio.spec._internal.version_type.Version = PydanticUndefined

Version of the TensorFlow library used.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]] = None

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>)] = PydanticUndefined

The multi-file weights. All required files/folders should be a zip archive.

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2440class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2441    type = "torchscript"
2442    weights_format_name: ClassVar[str] = "TorchScript"
2443    pytorch_version: Version
2444    """Version of the PyTorch library used."""

A file description

type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'
pytorch_version: bioimageio.spec._internal.version_type.Version = PydanticUndefined

Version of the PyTorch library used.

class WeightsDescr(bioimageio.spec._internal.node.Node):
2447class WeightsDescr(Node):
2448    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2449    onnx: Optional[OnnxWeightsDescr] = None
2450    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2451    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2452    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2453        None
2454    )
2455    torchscript: Optional[TorchscriptWeightsDescr] = None
2456
2457    @model_validator(mode="after")
2458    def check_entries(self) -> Self:
2459        entries = {wtype for wtype, entry in self if entry is not None}
2460
2461        if not entries:
2462            raise ValueError("Missing weights entry")
2463
2464        entries_wo_parent = {
2465            wtype
2466            for wtype, entry in self
2467            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2468        }
2469        if len(entries_wo_parent) != 1:
2470            issue_warning(
2471                "Exactly one weights entry may not specify the `parent` field (got"
2472                + " {value}). That entry is considered the original set of model weights."
2473                + " Other weight formats are created through conversion of the orignal or"
2474                + " already converted weights. They have to reference the weights format"
2475                + " they were converted from as their `parent`.",
2476                value=len(entries_wo_parent),
2477                field="weights",
2478            )
2479
2480        for wtype, entry in self:
2481            if entry is None:
2482                continue
2483
2484            assert hasattr(entry, "type")
2485            assert hasattr(entry, "parent")
2486            assert wtype == entry.type
2487            if (
2488                entry.parent is not None and entry.parent not in entries
2489            ):  # self reference checked for `parent` field
2490                raise ValueError(
2491                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2492                    + f" formats: {entries}"
2493                )
2494
2495        return self
2496
2497    def __getitem__(
2498        self,
2499        key: Literal[
2500            "keras_hdf5",
2501            "onnx",
2502            "pytorch_state_dict",
2503            "tensorflow_js",
2504            "tensorflow_saved_model_bundle",
2505            "torchscript",
2506        ],
2507    ):
2508        if key == "keras_hdf5":
2509            ret = self.keras_hdf5
2510        elif key == "onnx":
2511            ret = self.onnx
2512        elif key == "pytorch_state_dict":
2513            ret = self.pytorch_state_dict
2514        elif key == "tensorflow_js":
2515            ret = self.tensorflow_js
2516        elif key == "tensorflow_saved_model_bundle":
2517            ret = self.tensorflow_saved_model_bundle
2518        elif key == "torchscript":
2519            ret = self.torchscript
2520        else:
2521            raise KeyError(key)
2522
2523        if ret is None:
2524            raise KeyError(key)
2525
2526        return ret
2527
2528    @property
2529    def available_formats(self):
2530        return {
2531            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2532            **({} if self.onnx is None else {"onnx": self.onnx}),
2533            **(
2534                {}
2535                if self.pytorch_state_dict is None
2536                else {"pytorch_state_dict": self.pytorch_state_dict}
2537            ),
2538            **(
2539                {}
2540                if self.tensorflow_js is None
2541                else {"tensorflow_js": self.tensorflow_js}
2542            ),
2543            **(
2544                {}
2545                if self.tensorflow_saved_model_bundle is None
2546                else {
2547                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2548                }
2549            ),
2550            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2551        }
2552
2553    @property
2554    def missing_formats(self):
2555        return {
2556            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2557        }
keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
onnx: Optional[OnnxWeightsDescr] = None
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = None
torchscript: Optional[TorchscriptWeightsDescr] = None
@model_validator(mode='after')
def check_entries(self) -> Self:
2457    @model_validator(mode="after")
2458    def check_entries(self) -> Self:
2459        entries = {wtype for wtype, entry in self if entry is not None}
2460
2461        if not entries:
2462            raise ValueError("Missing weights entry")
2463
2464        entries_wo_parent = {
2465            wtype
2466            for wtype, entry in self
2467            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2468        }
2469        if len(entries_wo_parent) != 1:
2470            issue_warning(
2471                "Exactly one weights entry may not specify the `parent` field (got"
2472                + " {value}). That entry is considered the original set of model weights."
2473                + " Other weight formats are created through conversion of the orignal or"
2474                + " already converted weights. They have to reference the weights format"
2475                + " they were converted from as their `parent`.",
2476                value=len(entries_wo_parent),
2477                field="weights",
2478            )
2479
2480        for wtype, entry in self:
2481            if entry is None:
2482                continue
2483
2484            assert hasattr(entry, "type")
2485            assert hasattr(entry, "parent")
2486            assert wtype == entry.type
2487            if (
2488                entry.parent is not None and entry.parent not in entries
2489            ):  # self reference checked for `parent` field
2490                raise ValueError(
2491                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2492                    + f" formats: {entries}"
2493                )
2494
2495        return self
available_formats
2528    @property
2529    def available_formats(self):
2530        return {
2531            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2532            **({} if self.onnx is None else {"onnx": self.onnx}),
2533            **(
2534                {}
2535                if self.pytorch_state_dict is None
2536                else {"pytorch_state_dict": self.pytorch_state_dict}
2537            ),
2538            **(
2539                {}
2540                if self.tensorflow_js is None
2541                else {"tensorflow_js": self.tensorflow_js}
2542            ),
2543            **(
2544                {}
2545                if self.tensorflow_saved_model_bundle is None
2546                else {
2547                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2548                }
2549            ),
2550            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2551        }
missing_formats
2553    @property
2554    def missing_formats(self):
2555        return {
2556            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2557        }
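
Putting the above together, a sketch of a weights description with two formats (file names and versions are placeholders; IO checks are disabled because the files do not exist here). The pytorch_state_dict entry is the original set of weights and therefore has no parent; the torchscript entry was converted from it and references it as parent:

    from bioimageio.spec import ValidationContext
    from bioimageio.spec.model.v0_5 import (
        ArchitectureFromLibraryDescr,
        Identifier,
        PytorchStateDictWeightsDescr,
        TorchscriptWeightsDescr,
        Version,
        WeightsDescr,
    )

    with ValidationContext(perform_io_checks=False):
        weights = WeightsDescr(
            pytorch_state_dict=PytorchStateDictWeightsDescr(
                source="weights.pt",
                architecture=ArchitectureFromLibraryDescr(
                    import_from="my_package.models",
                    callable=Identifier("MyNet"),
                ),
                pytorch_version=Version("2.1"),
            ),
            torchscript=TorchscriptWeightsDescr(
                source="weights_torchscript.pt",
                pytorch_version=Version("2.1"),
                parent="pytorch_state_dict",
            ),
        )

    sorted(weights.available_formats)  # ['pytorch_state_dict', 'torchscript']
    weights.missing_formats            # the remaining formats, e.g. {'onnx', 'keras_hdf5', ...}
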
class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2560class ModelId(ResourceId):
2561    pass

A bioimage.io model id (a specialized ResourceId string).

class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2564class LinkedModel(LinkedResourceBase):
2565    """Reference to a bioimage.io model."""
2566
2567    id: ModelId
2568    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId = PydanticUndefined

A valid model id from the bioimage.io collection.

class ReproducibilityTolerance(bioimageio.spec._internal.node.Node):
2590class ReproducibilityTolerance(Node, extra="allow"):
2591    """Describes what small numerical differences -- if any -- may be tolerated
2592    in the generated output when executing in different environments.
2593
2594    A tensor element *output* is considered mismatched to the **test_tensor** if
2595    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2596    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2597
2598    Motivation:
2599        For testing we can request the respective deep learning frameworks to be as
2600        reproducible as possible by setting seeds and chosing deterministic algorithms,
2601        but differences in operating systems, available hardware and installed drivers
2602        may still lead to numerical differences.
2603    """
2604
2605    relative_tolerance: RelativeTolerance = 1e-3
2606    """Maximum relative tolerance of reproduced test tensor."""
2607
2608    absolute_tolerance: AbsoluteTolerance = 1e-4
2609    """Maximum absolute tolerance of reproduced test tensor."""
2610
2611    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2612    """Maximum number of mismatched elements/pixels per million to tolerate."""
2613
2614    output_ids: Sequence[TensorId] = ()
2615    """Limits the output tensor IDs these reproducibility details apply to."""
2616
2617    weights_formats: Sequence[WeightsFormat] = ()
2618    """Limits the weights formats these details apply to."""

Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.

A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)

Motivation:

For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.

relative_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=0.01)] = 0.001

Maximum relative tolerance of reproduced test tensor.

absolute_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=None)] = 0.0001

Maximum absolute tolerance of reproduced test tensor.

mismatched_elements_per_million: Annotated[int, Interval(gt=None, ge=0, lt=None, le=1000)] = 100

Maximum number of mismatched elements/pixels per million to tolerate.

output_ids: Sequence[TensorId] = ()

Limits the output tensor IDs these reproducibility details apply to.

weights_formats: Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']] = ()

Limits the weights formats these details apply to.
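
A numpy sketch of the mismatch criterion with the default tolerances (the arrays are synthetic):

    import numpy as np

    relative_tolerance, absolute_tolerance = 1e-3, 1e-4  # defaults from above

    test_tensor = np.linspace(0.0, 1.0, 1_000_000, dtype="float32")
    output = test_tensor + 5e-5  # reproduced output with a small constant offset

    mismatched = np.abs(output - test_tensor) > (
        absolute_tolerance + relative_tolerance * np.abs(test_tensor)
    )
    mismatched_per_million = mismatched.sum() / output.size * 1e6
    # acceptable as long as this stays <= mismatched_elements_per_million (default: 100)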

class BioimageioConfig(bioimageio.spec._internal.node.Node):
2621class BioimageioConfig(Node, extra="allow"):
2622    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2623    """Tolerances to allow when reproducing the model's test outputs
2624    from the model's test inputs.
2625    Only the first entry matching tensor id and weights format is considered.
2626    """
reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()

Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.

class Config(bioimageio.spec._internal.node.Node):
2629class Config(Node, extra="allow"):
2630    bioimageio: BioimageioConfig = Field(
2631        default_factory=BioimageioConfig.model_construct
2632    )
bioimageio: BioimageioConfig = PydanticUndefined
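
For example (a sketch; the tolerance value is arbitrary), such tolerances are nested under config.bioimageio like this:

    from bioimageio.spec.model.v0_5 import (
        BioimageioConfig,
        Config,
        ReproducibilityTolerance,
    )

    config = Config(
        bioimageio=BioimageioConfig(
            reproducibility_tolerance=[
                ReproducibilityTolerance(absolute_tolerance=2e-4)
            ]
        )
    )
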
2635class ModelDescr(GenericModelDescrBase):
2636    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2637    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2638    """
2639
2640    implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6"
2641    if TYPE_CHECKING:
2642        format_version: Literal["0.5.6"] = "0.5.6"
2643    else:
2644        format_version: Literal["0.5.6"]
2645        """Version of the bioimage.io model description specification used.
2646        When creating a new model always use the latest micro/patch version described here.
2647        The `format_version` is important for any consumer software to understand how to parse the fields.
2648        """
2649
2650    implemented_type: ClassVar[Literal["model"]] = "model"
2651    if TYPE_CHECKING:
2652        type: Literal["model"] = "model"
2653    else:
2654        type: Literal["model"]
2655        """Specialized resource type 'model'"""
2656
2657    id: Optional[ModelId] = None
2658    """bioimage.io-wide unique resource identifier
2659    assigned by bioimage.io; version **un**specific."""
2660
2661    authors: FAIR[List[Author]] = Field(
2662        default_factory=cast(Callable[[], List[Author]], list)
2663    )
2664    """The authors are the creators of the model RDF and the primary points of contact."""
2665
2666    documentation: FAIR[Optional[FileSource_documentation]] = None
2667    """URL or relative path to a markdown file with additional documentation.
2668    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2669    The documentation should include a '#[#] Validation' (sub)section
2670    with details on how to quantitatively validate the model on unseen data."""
2671
2672    @field_validator("documentation", mode="after")
2673    @classmethod
2674    def _validate_documentation(
2675        cls, value: Optional[FileSource_documentation]
2676    ) -> Optional[FileSource_documentation]:
2677        if not get_validation_context().perform_io_checks or value is None:
2678            return value
2679
2680        doc_reader = get_reader(value)
2681        doc_content = doc_reader.read().decode(encoding="utf-8")
2682        if not re.search("#.*[vV]alidation", doc_content):
2683            issue_warning(
2684                "No '# Validation' (sub)section found in {value}.",
2685                value=value,
2686                field="documentation",
2687            )
2688
2689        return value
2690
2691    inputs: NotEmpty[Sequence[InputTensorDescr]]
2692    """Describes the input tensors expected by this model."""
2693
2694    @field_validator("inputs", mode="after")
2695    @classmethod
2696    def _validate_input_axes(
2697        cls, inputs: Sequence[InputTensorDescr]
2698    ) -> Sequence[InputTensorDescr]:
2699        input_size_refs = cls._get_axes_with_independent_size(inputs)
2700
2701        for i, ipt in enumerate(inputs):
2702            valid_independent_refs: Dict[
2703                Tuple[TensorId, AxisId],
2704                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2705            ] = {
2706                **{
2707                    (ipt.id, a.id): (ipt, a, a.size)
2708                    for a in ipt.axes
2709                    if not isinstance(a, BatchAxis)
2710                    and isinstance(a.size, (int, ParameterizedSize))
2711                },
2712                **input_size_refs,
2713            }
2714            for a, ax in enumerate(ipt.axes):
2715                cls._validate_axis(
2716                    "inputs",
2717                    i=i,
2718                    tensor_id=ipt.id,
2719                    a=a,
2720                    axis=ax,
2721                    valid_independent_refs=valid_independent_refs,
2722                )
2723        return inputs
2724
2725    @staticmethod
2726    def _validate_axis(
2727        field_name: str,
2728        i: int,
2729        tensor_id: TensorId,
2730        a: int,
2731        axis: AnyAxis,
2732        valid_independent_refs: Dict[
2733            Tuple[TensorId, AxisId],
2734            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2735        ],
2736    ):
2737        if isinstance(axis, BatchAxis) or isinstance(
2738            axis.size, (int, ParameterizedSize, DataDependentSize)
2739        ):
2740            return
2741        elif not isinstance(axis.size, SizeReference):
2742            assert_never(axis.size)
2743
2744        # validate axis.size SizeReference
2745        ref = (axis.size.tensor_id, axis.size.axis_id)
2746        if ref not in valid_independent_refs:
2747            raise ValueError(
2748                "Invalid tensor axis reference at"
2749                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2750            )
2751        if ref == (tensor_id, axis.id):
2752            raise ValueError(
2753                "Self-referencing not allowed for"
2754                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2755            )
2756        if axis.type == "channel":
2757            if valid_independent_refs[ref][1].type != "channel":
2758                raise ValueError(
2759                    "A channel axis' size may only reference another fixed size"
2760                    + " channel axis."
2761                )
2762            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2763                ref_size = valid_independent_refs[ref][2]
2764                assert isinstance(ref_size, int), (
2765                    "channel axis ref (another channel axis) has to specify fixed"
2766                    + " size"
2767                )
2768                generated_channel_names = [
2769                    Identifier(axis.channel_names.format(i=i))
2770                    for i in range(1, ref_size + 1)
2771                ]
2772                axis.channel_names = generated_channel_names
2773
2774        if (ax_unit := getattr(axis, "unit", None)) != (
2775            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2776        ):
2777            raise ValueError(
2778                "The units of an axis and its reference axis need to match, but"
2779                + f" '{ax_unit}' != '{ref_unit}'."
2780            )
2781        ref_axis = valid_independent_refs[ref][1]
2782        if isinstance(ref_axis, BatchAxis):
2783            raise ValueError(
2784                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2785                + " (a batch axis is not allowed as reference)."
2786            )
2787
2788        if isinstance(axis, WithHalo):
2789            min_size = axis.size.get_size(axis, ref_axis, n=0)
2790            if (min_size - 2 * axis.halo) < 1:
2791                raise ValueError(
2792                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2793                    + f" {axis.halo}."
2794                )
2795
2796            input_halo = axis.halo * axis.scale / ref_axis.scale
2797            if input_halo != int(input_halo) or input_halo % 2 == 1:
2798                raise ValueError(
2799                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2800                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2801                    + f"     {tensor_id}.{axis.id}."
2802                )
2803
2804    @model_validator(mode="after")
2805    def _validate_test_tensors(self) -> Self:
2806        if not get_validation_context().perform_io_checks:
2807            return self
2808
2809        test_output_arrays = [
2810            None if descr.test_tensor is None else load_array(descr.test_tensor)
2811            for descr in self.outputs
2812        ]
2813        test_input_arrays = [
2814            None if descr.test_tensor is None else load_array(descr.test_tensor)
2815            for descr in self.inputs
2816        ]
2817
2818        tensors = {
2819            descr.id: (descr, array)
2820            for descr, array in zip(
2821                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2822            )
2823        }
2824        validate_tensors(tensors, tensor_origin="test_tensor")
2825
2826        output_arrays = {
2827            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2828        }
2829        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2830            if not rep_tol.absolute_tolerance:
2831                continue
2832
2833            if rep_tol.output_ids:
2834                out_arrays = {
2835                    oid: a
2836                    for oid, a in output_arrays.items()
2837                    if oid in rep_tol.output_ids
2838                }
2839            else:
2840                out_arrays = output_arrays
2841
2842            for out_id, array in out_arrays.items():
2843                if array is None:
2844                    continue
2845
2846                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2847                    raise ValueError(
2848                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2849                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2850                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2851                    )
2852
2853        return self
2854
2855    @model_validator(mode="after")
2856    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2857        ipt_refs = {t.id for t in self.inputs}
2858        out_refs = {t.id for t in self.outputs}
2859        for ipt in self.inputs:
2860            for p in ipt.preprocessing:
2861                ref = p.kwargs.get("reference_tensor")
2862                if ref is None:
2863                    continue
2864                if ref not in ipt_refs:
2865                    raise ValueError(
2866                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2867                        + f" references are: {ipt_refs}."
2868                    )
2869
2870        for out in self.outputs:
2871            for p in out.postprocessing:
2872                ref = p.kwargs.get("reference_tensor")
2873                if ref is None:
2874                    continue
2875
2876                if ref not in ipt_refs and ref not in out_refs:
2877                    raise ValueError(
2878                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2879                        + f" are: {ipt_refs | out_refs}."
2880                    )
2881
2882        return self
2883
2884    # TODO: use validate funcs in validate_test_tensors
2885    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2886
2887    name: Annotated[
2888        str,
2889        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2890        MinLen(5),
2891        MaxLen(128),
2892        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2893    ]
2894    """A human-readable name of this model.
2895    It should be no longer than 64 characters
2896    and may only contain letter, number, underscore, minus, parentheses and spaces.
2897    We recommend to chose a name that refers to the model's task and image modality.
2898    """
2899
2900    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2901    """Describes the output tensors."""
2902
2903    @field_validator("outputs", mode="after")
2904    @classmethod
2905    def _validate_tensor_ids(
2906        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2907    ) -> Sequence[OutputTensorDescr]:
2908        tensor_ids = [
2909            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2910        ]
2911        duplicate_tensor_ids: List[str] = []
2912        seen: Set[str] = set()
2913        for t in tensor_ids:
2914            if t in seen:
2915                duplicate_tensor_ids.append(t)
2916
2917            seen.add(t)
2918
2919        if duplicate_tensor_ids:
2920            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2921
2922        return outputs
2923
2924    @staticmethod
2925    def _get_axes_with_parameterized_size(
2926        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2927    ):
2928        return {
2929            f"{t.id}.{a.id}": (t, a, a.size)
2930            for t in io
2931            for a in t.axes
2932            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2933        }
2934
2935    @staticmethod
2936    def _get_axes_with_independent_size(
2937        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2938    ):
2939        return {
2940            (t.id, a.id): (t, a, a.size)
2941            for t in io
2942            for a in t.axes
2943            if not isinstance(a, BatchAxis)
2944            and isinstance(a.size, (int, ParameterizedSize))
2945        }
2946
2947    @field_validator("outputs", mode="after")
2948    @classmethod
2949    def _validate_output_axes(
2950        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2951    ) -> List[OutputTensorDescr]:
2952        input_size_refs = cls._get_axes_with_independent_size(
2953            info.data.get("inputs", [])
2954        )
2955        output_size_refs = cls._get_axes_with_independent_size(outputs)
2956
2957        for i, out in enumerate(outputs):
2958            valid_independent_refs: Dict[
2959                Tuple[TensorId, AxisId],
2960                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2961            ] = {
2962                **{
2963                    (out.id, a.id): (out, a, a.size)
2964                    for a in out.axes
2965                    if not isinstance(a, BatchAxis)
2966                    and isinstance(a.size, (int, ParameterizedSize))
2967                },
2968                **input_size_refs,
2969                **output_size_refs,
2970            }
2971            for a, ax in enumerate(out.axes):
2972                cls._validate_axis(
2973                    "outputs",
2974                    i,
2975                    out.id,
2976                    a,
2977                    ax,
2978                    valid_independent_refs=valid_independent_refs,
2979                )
2980
2981        return outputs
2982
2983    packaged_by: List[Author] = Field(
2984        default_factory=cast(Callable[[], List[Author]], list)
2985    )
2986    """The persons that have packaged and uploaded this model.
2987    Only required if those persons differ from the `authors`."""
2988
2989    parent: Optional[LinkedModel] = None
2990    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2991
2992    @model_validator(mode="after")
2993    def _validate_parent_is_not_self(self) -> Self:
2994        if self.parent is not None and self.parent.id == self.id:
2995            raise ValueError("A model description may not reference itself as parent.")
2996
2997        return self
2998
2999    run_mode: Annotated[
3000        Optional[RunMode],
3001        warn(None, "Run mode '{value}' has limited support across consumer software."),
3002    ] = None
3003    """Custom run mode for this model: for more complex prediction procedures like test time
3004    data augmentation that currently cannot be expressed in the specification.
3005    No standard run modes are defined yet."""
3006
3007    timestamp: Datetime = Field(default_factory=Datetime.now)
3008    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
3009    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
3010    (In Python a datetime object is valid, too)."""
3011
3012    training_data: Annotated[
3013        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
3014        Field(union_mode="left_to_right"),
3015    ] = None
3016    """The dataset used to train this model"""
3017
3018    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
3019    """The weights for this model.
3020    Weights can be given for different formats, but should otherwise be equivalent.
3021    The available weight formats determine which consumers can use this model."""
3022
3023    config: Config = Field(default_factory=Config.model_construct)
3024
3025    @model_validator(mode="after")
3026    def _add_default_cover(self) -> Self:
3027        if not get_validation_context().perform_io_checks or self.covers:
3028            return self
3029
3030        try:
3031            generated_covers = generate_covers(
3032                [
3033                    (t, load_array(t.test_tensor))
3034                    for t in self.inputs
3035                    if t.test_tensor is not None
3036                ],
3037                [
3038                    (t, load_array(t.test_tensor))
3039                    for t in self.outputs
3040                    if t.test_tensor is not None
3041                ],
3042            )
3043        except Exception as e:
3044            issue_warning(
3045                "Failed to generate cover image(s): {e}",
3046                value=self.covers,
3047                msg_context=dict(e=e),
3048                field="covers",
3049            )
3050        else:
3051            self.covers.extend(generated_covers)
3052
3053        return self
3054
3055    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3056        return self._get_test_arrays(self.inputs)
3057
3058    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3059        return self._get_test_arrays(self.outputs)
3060
3061    @staticmethod
3062    def _get_test_arrays(
3063        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3064    ):
3065        ts: List[FileDescr] = []
3066        for d in io_descr:
3067            if d.test_tensor is None:
3068                raise ValueError(
3069                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3070                )
3071            ts.append(d.test_tensor)
3072
3073        data = [load_array(t) for t in ts]
3074        assert all(isinstance(d, np.ndarray) for d in data)
3075        return data
3076
3077    @staticmethod
3078    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3079        batch_size = 1
3080        tensor_with_batchsize: Optional[TensorId] = None
3081        for tid in tensor_sizes:
3082            for aid, s in tensor_sizes[tid].items():
3083                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3084                    continue
3085
3086                if batch_size != 1:
3087                    assert tensor_with_batchsize is not None
3088                    raise ValueError(
3089                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3090                    )
3091
3092                batch_size = s
3093                tensor_with_batchsize = tid
3094
3095        return batch_size
3096
3097    def get_output_tensor_sizes(
3098        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3099    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3100        """Returns the tensor output sizes for given **input_sizes**.
3101        The tensor output sizes are exact only if **input_sizes** specifies a valid input shape.
3102        Otherwise they may be larger than the actual (valid) output sizes."""
3103        batch_size = self.get_batch_size(input_sizes)
3104        ns = self.get_ns(input_sizes)
3105
3106        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3107        return tensor_sizes.outputs
3108
3109    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3110        """get parameter `n` for each parameterized axis
3111        such that the valid input size is >= the given input size"""
3112        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3113        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3114        for tid in input_sizes:
3115            for aid, s in input_sizes[tid].items():
3116                size_descr = axes[tid][aid].size
3117                if isinstance(size_descr, ParameterizedSize):
3118                    ret[(tid, aid)] = size_descr.get_n(s)
3119                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3120                    pass
3121                else:
3122                    assert_never(size_descr)
3123
3124        return ret
3125
3126    def get_tensor_sizes(
3127        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3128    ) -> _TensorSizes:
3129        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3130        return _TensorSizes(
3131            {
3132                t: {
3133                    aa: axis_sizes.inputs[(tt, aa)]
3134                    for tt, aa in axis_sizes.inputs
3135                    if tt == t
3136                }
3137                for t in {tt for tt, _ in axis_sizes.inputs}
3138            },
3139            {
3140                t: {
3141                    aa: axis_sizes.outputs[(tt, aa)]
3142                    for tt, aa in axis_sizes.outputs
3143                    if tt == t
3144                }
3145                for t in {tt for tt, _ in axis_sizes.outputs}
3146            },
3147        )
3148
3149    def get_axis_sizes(
3150        self,
3151        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3152        batch_size: Optional[int] = None,
3153        *,
3154        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3155    ) -> _AxisSizes:
3156        """Determine input and output block shape for scale factors **ns**
3157        of parameterized input sizes.
3158
3159        Args:
3160            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3161                that is parameterized as `size = min + n * step`.
3162            batch_size: The desired size of the batch dimension.
3163                If given, **batch_size** overrides any batch size present in
3164                **max_input_shape**. Defaults to 1.
3165            max_input_shape: Limits the derived block shapes.
3166                For each axis whose parameterized input size would exceed
3167                **max_input_shape**, `n` is reduced to the smallest value for which
3168                the axis size still reaches **max_input_shape**.
3169                Use this for small input samples or large values of **ns**,
3170                or simply whenever you know the full input shape.
3171
3172        Returns:
3173            Resolved axis sizes for model inputs and outputs.
3174        """
3175        max_input_shape = max_input_shape or {}
3176        if batch_size is None:
3177            for (_t_id, a_id), s in max_input_shape.items():
3178                if a_id == BATCH_AXIS_ID:
3179                    batch_size = s
3180                    break
3181            else:
3182                batch_size = 1
3183
3184        all_axes = {
3185            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3186        }
3187
3188        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3189        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3190
3191        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3192            if isinstance(a, BatchAxis):
3193                if (t_descr.id, a.id) in ns:
3194                    logger.warning(
3195                        "Ignoring unexpected size increment factor (n) for batch axis"
3196                        + " of tensor '{}'.",
3197                        t_descr.id,
3198                    )
3199                return batch_size
3200            elif isinstance(a.size, int):
3201                if (t_descr.id, a.id) in ns:
3202                    logger.warning(
3203                        "Ignoring unexpected size increment factor (n) for fixed size"
3204                        + " axis '{}' of tensor '{}'.",
3205                        a.id,
3206                        t_descr.id,
3207                    )
3208                return a.size
3209            elif isinstance(a.size, ParameterizedSize):
3210                if (t_descr.id, a.id) not in ns:
3211                    raise ValueError(
3212                        "Size increment factor (n) missing for parametrized axis"
3213                        + f" '{a.id}' of tensor '{t_descr.id}'."
3214                    )
3215                n = ns[(t_descr.id, a.id)]
3216                s_max = max_input_shape.get((t_descr.id, a.id))
3217                if s_max is not None:
3218                    n = min(n, a.size.get_n(s_max))
3219
3220                return a.size.get_size(n)
3221
3222            elif isinstance(a.size, SizeReference):
3223                if (t_descr.id, a.id) in ns:
3224                    logger.warning(
3225                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3226                        + " of tensor '{}' with size reference.",
3227                        a.id,
3228                        t_descr.id,
3229                    )
3230                assert not isinstance(a, BatchAxis)
3231                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3232                assert not isinstance(ref_axis, BatchAxis)
3233                ref_key = (a.size.tensor_id, a.size.axis_id)
3234                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3235                assert ref_size is not None, ref_key
3236                assert not isinstance(ref_size, _DataDepSize), ref_key
3237                return a.size.get_size(
3238                    axis=a,
3239                    ref_axis=ref_axis,
3240                    ref_size=ref_size,
3241                )
3242            elif isinstance(a.size, DataDependentSize):
3243                if (t_descr.id, a.id) in ns:
3244                    logger.warning(
3245                        "Ignoring unexpected increment factor (n) for data dependent"
3246                        + " size axis '{}' of tensor '{}'.",
3247                        a.id,
3248                        t_descr.id,
3249                    )
3250                return _DataDepSize(a.size.min, a.size.max)
3251            else:
3252                assert_never(a.size)
3253
3254        # first resolve all but the `SizeReference` input sizes
3255        for t_descr in self.inputs:
3256            for a in t_descr.axes:
3257                if not isinstance(a.size, SizeReference):
3258                    s = get_axis_size(a)
3259                    assert not isinstance(s, _DataDepSize)
3260                    inputs[t_descr.id, a.id] = s
3261
3262        # resolve all other input axis sizes
3263        for t_descr in self.inputs:
3264            for a in t_descr.axes:
3265                if isinstance(a.size, SizeReference):
3266                    s = get_axis_size(a)
3267                    assert not isinstance(s, _DataDepSize)
3268                    inputs[t_descr.id, a.id] = s
3269
3270        # resolve all output axis sizes
3271        for t_descr in self.outputs:
3272            for a in t_descr.axes:
3273                assert not isinstance(a.size, ParameterizedSize)
3274                s = get_axis_size(a)
3275                outputs[t_descr.id, a.id] = s
3276
3277        return _AxisSizes(inputs=inputs, outputs=outputs)
3278
3279    @model_validator(mode="before")
3280    @classmethod
3281    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3282        cls.convert_from_old_format_wo_validation(data)
3283        return data
3284
3285    @classmethod
3286    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3287        """Convert metadata following an older format version to this class's format
3288        without validating the result.
3289        """
3290        if (
3291            data.get("type") == "model"
3292            and isinstance(fv := data.get("format_version"), str)
3293            and fv.count(".") == 2
3294        ):
3295            fv_parts = fv.split(".")
3296            if any(not p.isdigit() for p in fv_parts):
3297                return
3298
3299            fv_tuple = tuple(map(int, fv_parts))
3300
3301            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3302            if fv_tuple[:2] in ((0, 3), (0, 4)):
3303                m04 = _ModelDescr_v0_4.load(data)
3304                if isinstance(m04, InvalidDescr):
3305                    try:
3306                        updated = _model_conv.convert_as_dict(
3307                            m04  # pyright: ignore[reportArgumentType]
3308                        )
3309                    except Exception as e:
3310                        logger.error(
3311                            "Failed to convert from invalid model 0.4 description."
3312                            + f"\nerror: {e}"
3313                            + "\nProceeding with model 0.5 validation without conversion."
3314                        )
3315                        updated = None
3316                else:
3317                    updated = _model_conv.convert_as_dict(m04)
3318
3319                if updated is not None:
3320                    data.clear()
3321                    data.update(updated)
3322
3323            elif fv_tuple[:2] == (0, 5):
3324                # bump patch version
3325                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.6']] = '0.5.6'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[ModelId] = None

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], AfterWarner(severity=35)] = PydanticUndefined

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Optional[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, pathlib.Path]], WithSuffix(suffix='.md', case_sensitive=True), Field(union_mode='left_to_right', examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md']), AfterWarner(severity=35)] = None

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)] = PydanticUndefined

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(severity=20, msg='Name longer than 64 characters.')] = PydanticUndefined

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces. We recommend choosing a name that refers to the model's task and image modality.
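As a hedged illustration of these constraints (not the library's own validation, which is performed by the pydantic annotations shown above), a candidate name could be checked like this; `check_model_name` and the example name are made up for this sketch:

    import string

    ALLOWED = set(string.ascii_letters + string.digits + "_+- ()")

    def check_model_name(name: str) -> None:
        # 5 to 128 characters are allowed; more than 64 only triggers a warning
        if not 5 <= len(name) <= 128:
            raise ValueError("name must be 5 to 128 characters long")
        if set(name) - ALLOWED:
            raise ValueError("name may only contain letters, digits and '_+- ()'")
        if len(name) > 64:
            print("warning: name longer than 64 characters")

    check_model_name("2D UNet (nuclei segmentation)")  # passes all checks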

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)] = PydanticUndefined

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author] = PydanticUndefined

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel] = None

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(severity=30, msg="Run mode '{value}' has limited support across consumer software.")] = None

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime = PydanticUndefined

Timestamp in ISO 8601 (https://en.wikipedia.org/wiki/ISO_8601) format with a few restrictions listed at https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat. (In Python a datetime object is valid, too).
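For illustration, both of the following are typically accepted values for `timestamp` (the concrete date is made up):

    from datetime import datetime, timezone

    timestamp_as_string = "2024-03-01T12:00:00+00:00"  # ISO 8601 string (illustrative)
    timestamp_as_datetime = datetime(2024, 3, 1, 12, 0, tzinfo=timezone.utc)  # datetime object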

training_data: Annotated[Union[None, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], Field(union_mode='left_to_right')] = None

The dataset used to train this model

weights: Annotated[WeightsDescr, WrapSerializer(package_weights, when_used='always')] = PydanticUndefined

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: Config = PydanticUndefined
def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
3055    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3056        return self._get_test_arrays(self.inputs)
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
3058    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3059        return self._get_test_arrays(self.outputs)
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3077    @staticmethod
3078    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3079        batch_size = 1
3080        tensor_with_batchsize: Optional[TensorId] = None
3081        for tid in tensor_sizes:
3082            for aid, s in tensor_sizes[tid].items():
3083                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3084                    continue
3085
3086                if batch_size != 1:
3087                    assert tensor_with_batchsize is not None
3088                    raise ValueError(
3089                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3090                    )
3091
3092                batch_size = s
3093                tensor_with_batchsize = tid
3094
3095        return batch_size
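
A hedged usage sketch: the tensor and axis ids below are illustrative and assume a model whose input "raw" has a batch axis. `get_batch_size` is a static method, so it can be called on the class directly:

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    tensor_sizes = {
        TensorId("raw"): {
            AxisId("batch"): 4,
            AxisId("channel"): 1,
            AxisId("y"): 512,
            AxisId("x"): 512,
        }
    }
    ModelDescr.get_batch_size(tensor_sizes)  # -> 4 (raises ValueError on conflicting batch sizes)
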
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
3097    def get_output_tensor_sizes(
3098        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3099    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3100        """Returns the tensor output sizes for given **input_sizes**.
3101        The tensor output sizes are exact only if **input_sizes** specifies a valid input shape.
3102        Otherwise they may be larger than the actual (valid) output sizes."""
3103        batch_size = self.get_batch_size(input_sizes)
3104        ns = self.get_ns(input_sizes)
3105
3106        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3107        return tensor_sizes.outputs

Returns the tensor output sizes for given input_sizes. The tensor output sizes are exact only if input_sizes specifies a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
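A hedged usage sketch, assuming `model` is an already loaded and validated ModelDescr with one input "raw" and one output "mask"; all ids and sizes are illustrative:

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    model: ModelDescr = ...  # assumed: a loaded model description
    input_sizes = {
        TensorId("raw"): {
            AxisId("batch"): 1,
            AxisId("channel"): 1,
            AxisId("y"): 256,
            AxisId("x"): 256,
        }
    }
    output_sizes = model.get_output_tensor_sizes(input_sizes)
    # e.g. {TensorId("mask"): {AxisId("batch"): 1, AxisId("channel"): 2,
    #       AxisId("y"): 256, AxisId("x"): 256}}
    # exact only if 256 is a valid size for the parameterized input axes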

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3109    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3110        """get parameter `n` for each parameterized axis
3111        such that the valid input size is >= the given input size"""
3112        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3113        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3114        for tid in input_sizes:
3115            for aid, s in input_sizes[tid].items():
3116                size_descr = axes[tid][aid].size
3117                if isinstance(size_descr, ParameterizedSize):
3118                    ret[(tid, aid)] = size_descr.get_n(s)
3119                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3120                    pass
3121                else:
3122                    assert_never(size_descr)
3123
3124        return ret

get parameter n for each parameterized axis such that the valid input size is >= the given input size
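As a hedged example, assume the "x" axis of input "raw" is described by ParameterizedSize(min=64, step=16), so its valid sizes are 64, 80, 96, 112, ...; the ids are illustrative:

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    model: ModelDescr = ...  # assumed: a loaded model description
    ns = model.get_ns({TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 100}})
    # the smallest valid size >= 100 is 112 = 64 + 3 * 16, so:
    # ns == {(TensorId("raw"), AxisId("x")): 3}
    # (batch, fixed-size and size-referencing axes are skipped)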

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
3126    def get_tensor_sizes(
3127        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3128    ) -> _TensorSizes:
3129        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3130        return _TensorSizes(
3131            {
3132                t: {
3133                    aa: axis_sizes.inputs[(tt, aa)]
3134                    for tt, aa in axis_sizes.inputs
3135                    if tt == t
3136                }
3137                for t in {tt for tt, _ in axis_sizes.inputs}
3138            },
3139            {
3140                t: {
3141                    aa: axis_sizes.outputs[(tt, aa)]
3142                    for tt, aa in axis_sizes.outputs
3143                    if tt == t
3144                }
3145                for t in {tt for tt, _ in axis_sizes.outputs}
3146            },
3147        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
3149    def get_axis_sizes(
3150        self,
3151        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3152        batch_size: Optional[int] = None,
3153        *,
3154        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3155    ) -> _AxisSizes:
3156        """Determine input and output block shape for scale factors **ns**
3157        of parameterized input sizes.
3158
3159        Args:
3160            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3161                that is parameterized as `size = min + n * step`.
3162            batch_size: The desired size of the batch dimension.
3163                If given, **batch_size** overrides any batch size present in
3164                **max_input_shape**. Defaults to 1.
3165            max_input_shape: Limits the derived block shapes.
3166                For each axis whose parameterized input size would exceed
3167                **max_input_shape**, `n` is reduced to the smallest value for which
3168                the axis size still reaches **max_input_shape**.
3169                Use this for small input samples or large values of **ns**,
3170                or simply whenever you know the full input shape.
3171
3172        Returns:
3173            Resolved axis sizes for model inputs and outputs.
3174        """
3175        max_input_shape = max_input_shape or {}
3176        if batch_size is None:
3177            for (_t_id, a_id), s in max_input_shape.items():
3178                if a_id == BATCH_AXIS_ID:
3179                    batch_size = s
3180                    break
3181            else:
3182                batch_size = 1
3183
3184        all_axes = {
3185            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3186        }
3187
3188        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3189        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3190
3191        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3192            if isinstance(a, BatchAxis):
3193                if (t_descr.id, a.id) in ns:
3194                    logger.warning(
3195                        "Ignoring unexpected size increment factor (n) for batch axis"
3196                        + " of tensor '{}'.",
3197                        t_descr.id,
3198                    )
3199                return batch_size
3200            elif isinstance(a.size, int):
3201                if (t_descr.id, a.id) in ns:
3202                    logger.warning(
3203                        "Ignoring unexpected size increment factor (n) for fixed size"
3204                        + " axis '{}' of tensor '{}'.",
3205                        a.id,
3206                        t_descr.id,
3207                    )
3208                return a.size
3209            elif isinstance(a.size, ParameterizedSize):
3210                if (t_descr.id, a.id) not in ns:
3211                    raise ValueError(
3212                        "Size increment factor (n) missing for parametrized axis"
3213                        + f" '{a.id}' of tensor '{t_descr.id}'."
3214                    )
3215                n = ns[(t_descr.id, a.id)]
3216                s_max = max_input_shape.get((t_descr.id, a.id))
3217                if s_max is not None:
3218                    n = min(n, a.size.get_n(s_max))
3219
3220                return a.size.get_size(n)
3221
3222            elif isinstance(a.size, SizeReference):
3223                if (t_descr.id, a.id) in ns:
3224                    logger.warning(
3225                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3226                        + " of tensor '{}' with size reference.",
3227                        a.id,
3228                        t_descr.id,
3229                    )
3230                assert not isinstance(a, BatchAxis)
3231                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3232                assert not isinstance(ref_axis, BatchAxis)
3233                ref_key = (a.size.tensor_id, a.size.axis_id)
3234                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3235                assert ref_size is not None, ref_key
3236                assert not isinstance(ref_size, _DataDepSize), ref_key
3237                return a.size.get_size(
3238                    axis=a,
3239                    ref_axis=ref_axis,
3240                    ref_size=ref_size,
3241                )
3242            elif isinstance(a.size, DataDependentSize):
3243                if (t_descr.id, a.id) in ns:
3244                    logger.warning(
3245                        "Ignoring unexpected increment factor (n) for data dependent"
3246                        + " size axis '{}' of tensor '{}'.",
3247                        a.id,
3248                        t_descr.id,
3249                    )
3250                return _DataDepSize(a.size.min, a.size.max)
3251            else:
3252                assert_never(a.size)
3253
3254        # first resolve all but the `SizeReference` input sizes
3255        for t_descr in self.inputs:
3256            for a in t_descr.axes:
3257                if not isinstance(a.size, SizeReference):
3258                    s = get_axis_size(a)
3259                    assert not isinstance(s, _DataDepSize)
3260                    inputs[t_descr.id, a.id] = s
3261
3262        # resolve all other input axis sizes
3263        for t_descr in self.inputs:
3264            for a in t_descr.axes:
3265                if isinstance(a.size, SizeReference):
3266                    s = get_axis_size(a)
3267                    assert not isinstance(s, _DataDepSize)
3268                    inputs[t_descr.id, a.id] = s
3269
3270        # resolve all output axis sizes
3271        for t_descr in self.outputs:
3272            for a in t_descr.axes:
3273                assert not isinstance(a.size, ParameterizedSize)
3274                s = get_axis_size(a)
3275                outputs[t_descr.id, a.id] = s
3276
3277        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overrides any batch size present in max_input_shape. Defaults to 1.
  • max_input_shape: Limits the derived block shapes. For each axis whose parameterized input size would exceed max_input_shape, n is reduced to the smallest value for which the axis size still reaches max_input_shape. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
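A hedged usage sketch continuing the example above; the ids and the assumed ParameterizedSize(min=64, step=16) on the "x" axis of input "raw" are illustrative:

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    model: ModelDescr = ...  # assumed: a loaded model description
    axis_sizes = model.get_axis_sizes(
        ns={(TensorId("raw"), AxisId("x")): 3},
        batch_size=1,
    )
    axis_sizes.inputs[(TensorId("raw"), AxisId("x"))]  # -> 112 (= 64 + 3 * 16)
    # output axes with a SizeReference resolve relative to the referenced axis;
    # data dependent output axes are reported as a _DataDepSize(min, max) range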

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3285    @classmethod
3286    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3287        """Convert metadata following an older format version to this class's format
3288        without validating the result.
3289        """
3290        if (
3291            data.get("type") == "model"
3292            and isinstance(fv := data.get("format_version"), str)
3293            and fv.count(".") == 2
3294        ):
3295            fv_parts = fv.split(".")
3296            if any(not p.isdigit() for p in fv_parts):
3297                return
3298
3299            fv_tuple = tuple(map(int, fv_parts))
3300
3301            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3302            if fv_tuple[:2] in ((0, 3), (0, 4)):
3303                m04 = _ModelDescr_v0_4.load(data)
3304                if isinstance(m04, InvalidDescr):
3305                    try:
3306                        updated = _model_conv.convert_as_dict(
3307                            m04  # pyright: ignore[reportArgumentType]
3308                        )
3309                    except Exception as e:
3310                        logger.error(
3311                            "Failed to convert from invalid model 0.4 description."
3312                            + f"\nerror: {e}"
3313                            + "\nProceeding with model 0.5 validation without conversion."
3314                        )
3315                        updated = None
3316                else:
3317                    updated = _model_conv.convert_as_dict(m04)
3318
3319                if updated is not None:
3320                    data.clear()
3321                    data.update(updated)
3322
3323            elif fv_tuple[:2] == (0, 5):
3324                # bump patch version
3325                data["format_version"] = cls.implemented_format_version

Convert metadata following an older format version to this class's format without validating the result.
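A hedged sketch of upgrading an older (e.g. 0.4.x) model RDF dict in place before validating it against the 0.5 schema; PyYAML and the file name "rdf.yaml" are used for illustration only:

    import yaml  # PyYAML, used here only to obtain a plain dict

    from bioimageio.spec.model.v0_5 import ModelDescr

    with open("rdf.yaml") as f:  # hypothetical file name
        data = yaml.safe_load(f)

    ModelDescr.convert_from_old_format_wo_validation(data)
    # `data` now follows the 0.5 layout (or is left unchanged if conversion was not possible)
    # and can subsequently be validated, e.g. with ModelDescr.model_validate(data)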

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 6)
def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3542def generate_covers(
3543    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3544    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3545) -> List[Path]:
3546    def squeeze(
3547        data: NDArray[Any], axes: Sequence[AnyAxis]
3548    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3549        """apply numpy.ndarray.squeeze while keeping track of the remaining axis descriptions"""
3550        if data.ndim != len(axes):
3551            raise ValueError(
3552                f"tensor shape {data.shape} does not match described axes"
3553                + f" {[a.id for a in axes]}"
3554            )
3555
3556        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3557        return data.squeeze(), axes
3558
3559    def normalize(
3560        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3561    ) -> NDArray[np.float32]:
3562        data = data.astype("float32")
3563        data -= data.min(axis=axis, keepdims=True)
3564        data /= data.max(axis=axis, keepdims=True) + eps
3565        return data
3566
3567    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3568        original_shape = data.shape
3569        original_axes = list(axes)
3570        data, axes = squeeze(data, axes)
3571
3572        # take a slice from any batch or index axis if needed
3573        # and map the first channel axis to RGB, taking a slice from any additional channel axes
3574        slices: Tuple[slice, ...] = ()
3575        ndim = data.ndim
3576        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3577        has_c_axis = False
3578        for i, a in enumerate(axes):
3579            s = data.shape[i]
3580            assert s > 1
3581            if (
3582                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3583                and ndim > ndim_need
3584            ):
3585                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3586                ndim -= 1
3587            elif isinstance(a, ChannelAxis):
3588                if has_c_axis:
3589                    # second channel axis
3590                    data = data[slices + (slice(0, 1),)]
3591                    ndim -= 1
3592                else:
3593                    has_c_axis = True
3594                    if s == 2:
3595                        # visualize two channels with cyan and magenta
3596                        data = np.concatenate(
3597                            [
3598                                data[slices + (slice(1, 2),)],
3599                                data[slices + (slice(0, 1),)],
3600                                (
3601                                    data[slices + (slice(0, 1),)]
3602                                    + data[slices + (slice(1, 2),)]
3603                                )
3604                                / 2,  # TODO: take maximum instead?
3605                            ],
3606                            axis=i,
3607                        )
3608                    elif data.shape[i] == 3:
3609                        pass  # visualize 3 channels as RGB
3610                    else:
3611                        # visualize first 3 channels as RGB
3612                        data = data[slices + (slice(3),)]
3613
3614                    assert data.shape[i] == 3
3615
3616            slices += (slice(None),)
3617
3618        data, axes = squeeze(data, axes)
3619        assert len(axes) == ndim
3620        # take slice from z axis if needed
3621        slices = ()
3622        if ndim > ndim_need:
3623            for i, a in enumerate(axes):
3624                s = data.shape[i]
3625                if a.id == AxisId("z"):
3626                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3627                    data, axes = squeeze(data, axes)
3628                    ndim -= 1
3629                    break
3630
3631            slices += (slice(None),)
3632
3633        # take slice from any space or time axis
3634        slices = ()
3635
3636        for i, a in enumerate(axes):
3637            if ndim <= ndim_need:
3638                break
3639
3640            s = data.shape[i]
3641            assert s > 1
3642            if isinstance(
3643                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3644            ):
3645                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3646                ndim -= 1
3647
3648            slices += (slice(None),)
3649
3650        del slices
3651        data, axes = squeeze(data, axes)
3652        assert len(axes) == ndim
3653
3654        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3655            raise ValueError(
3656                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
3657            )
3658
3659        if not has_c_axis:
3660            assert ndim == 2
3661            data = np.repeat(data[:, :, None], 3, axis=2)
3662            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3663            ndim += 1
3664
3665        assert ndim == 3
3666
3667        # transpose axis order such that longest axis comes first...
3668        axis_order: List[int] = list(np.argsort(list(data.shape)))
3669        axis_order.reverse()
3670        # ... and channel axis is last
3671        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3672        axis_order.append(axis_order.pop(c))
3673        axes = [axes[ao] for ao in axis_order]
3674        data = data.transpose(axis_order)
3675
3676        # h, w = data.shape[:2]
3677        # if h / w  in (1.0 or 2.0):
3678        #     pass
3679        # elif h / w < 2:
3680        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3681
3682        norm_along = (
3683            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3684        )
3685        # normalize the data and map to 8 bit
3686        data = normalize(data, norm_along)
3687        data = (data * 255).astype("uint8")
3688
3689        return data
3690
3691    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3692        assert im0.dtype == im1.dtype == np.uint8
3693        assert im0.shape == im1.shape
3694        assert im0.ndim == 3
3695        N, M, C = im0.shape
3696        assert C == 3
3697        out = np.ones((N, M, C), dtype="uint8")
3698        for c in range(C):
3699            outc = np.tril(im0[..., c])
3700            mask = outc == 0
3701            outc[mask] = np.triu(im1[..., c])[mask]
3702            out[..., c] = outc
3703
3704        return out
3705
3706    if not inputs:
3707        raise ValueError("Missing test input tensor for cover generation.")
3708
3709    if not outputs:
3710        raise ValueError("Missing test output tensor for cover generation.")
3711
3712    ipt_descr, ipt = inputs[0]
3713    out_descr, out = outputs[0]
3714
3715    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3716    out_img = to_2d_image(out, out_descr.axes)
3717
3718    cover_folder = Path(mkdtemp())
3719    if ipt_img.shape == out_img.shape:
3720        covers = [cover_folder / "cover.png"]
3721        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3722    else:
3723        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3724        imwrite(covers[0], ipt_img)
3725        imwrite(covers[1], out_img)
3726
3727    return covers
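
A hedged usage sketch: generate cover images for a loaded ModelDescr `model` from its test tensors. This mirrors what the default cover generation in the model validator does and assumes every tensor description provides a `test_tensor`:

    from bioimageio.spec._internal.io_utils import load_array
    from bioimageio.spec.model.v0_5 import ModelDescr, generate_covers

    model: ModelDescr = ...  # assumed: a loaded model description with test tensors
    covers = generate_covers(
        [(t, load_array(t.test_tensor)) for t in model.inputs if t.test_tensor is not None],
        [(t, load_array(t.test_tensor)) for t in model.outputs if t.test_tensor is not None],
    )
    # -> e.g. [PosixPath('/tmp/.../cover.png')] if input and output render to the same 2D image shape,
    #    otherwise separate 'input.png' and 'output.png' files in a temporary folder
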
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):