bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from itertools import chain
  10from math import ceil
  11from pathlib import Path, PurePosixPath
  12from tempfile import mkdtemp
  13from typing import (
  14    TYPE_CHECKING,
  15    Any,
  16    Callable,
  17    ClassVar,
  18    Dict,
  19    Generic,
  20    List,
  21    Literal,
  22    Mapping,
  23    NamedTuple,
  24    Optional,
  25    Sequence,
  26    Set,
  27    Tuple,
  28    Type,
  29    TypeVar,
  30    Union,
  31    cast,
  32)
  33
  34import numpy as np
  35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  36from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  37from loguru import logger
  38from numpy.typing import NDArray
  39from pydantic import (
  40    AfterValidator,
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    SerializationInfo,
  45    SerializerFunctionWrapHandler,
  46    StrictInt,
  47    Tag,
  48    ValidationInfo,
  49    WrapSerializer,
  50    field_validator,
  51    model_serializer,
  52    model_validator,
  53)
  54from typing_extensions import Annotated, Self, assert_never, get_args
  55
  56from .._internal.common_nodes import (
  57    InvalidDescr,
  58    Node,
  59    NodeWithExplicitlySetFields,
  60)
  61from .._internal.constants import DTYPE_LIMITS
  62from .._internal.field_warning import issue_warning, warn
  63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  64from .._internal.io import FileDescr as FileDescr
  65from .._internal.io import (
  66    FileSource,
  67    WithSuffix,
  68    YamlValue,
  69    get_reader,
  70    wo_special_file_name,
  71)
  72from .._internal.io_basics import Sha256 as Sha256
  73from .._internal.io_packaging import (
  74    FileDescr_,
  75    FileSource_,
  76    package_file_descr_serializer,
  77)
  78from .._internal.io_utils import load_array
  79from .._internal.node_converter import Converter
  80from .._internal.type_guards import is_dict, is_sequence
  81from .._internal.types import (
  82    FAIR,
  83    AbsoluteTolerance,
  84    LowerCaseIdentifier,
  85    LowerCaseIdentifierAnno,
  86    MismatchedElementsPerMillion,
  87    RelativeTolerance,
  88)
  89from .._internal.types import Datetime as Datetime
  90from .._internal.types import Identifier as Identifier
  91from .._internal.types import NotEmpty as NotEmpty
  92from .._internal.types import SiUnit as SiUnit
  93from .._internal.url import HttpUrl as HttpUrl
  94from .._internal.validation_context import get_validation_context
  95from .._internal.validator_annotations import RestrictCharacters
  96from .._internal.version_type import Version as Version
  97from .._internal.warning_levels import INFO
  98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
 100from ..dataset.v0_3 import DatasetDescr as DatasetDescr
 101from ..dataset.v0_3 import DatasetId as DatasetId
 102from ..dataset.v0_3 import LinkedDataset as LinkedDataset
 103from ..dataset.v0_3 import Uploader as Uploader
 104from ..generic.v0_3 import (
 105    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
 106)
 107from ..generic.v0_3 import Author as Author
 108from ..generic.v0_3 import BadgeDescr as BadgeDescr
 109from ..generic.v0_3 import CiteEntry as CiteEntry
 110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
 111from ..generic.v0_3 import Doi as Doi
 112from ..generic.v0_3 import (
 113    FileSource_documentation,
 114    GenericModelDescrBase,
 115    LinkedResourceBase,
 116    _author_conv,  # pyright: ignore[reportPrivateUsage]
 117    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
 118)
 119from ..generic.v0_3 import LicenseId as LicenseId
 120from ..generic.v0_3 import LinkedResource as LinkedResource
 121from ..generic.v0_3 import Maintainer as Maintainer
 122from ..generic.v0_3 import OrcidId as OrcidId
 123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 124from ..generic.v0_3 import ResourceId as ResourceId
 125from .v0_4 import Author as _Author_v0_4
 126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 127from .v0_4 import CallableFromDepencency as CallableFromDepencency
 128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 130from .v0_4 import ClipDescr as _ClipDescr_v0_4
 131from .v0_4 import ClipKwargs as ClipKwargs
 132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 134from .v0_4 import KnownRunMode as KnownRunMode
 135from .v0_4 import ModelDescr as _ModelDescr_v0_4
 136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 140from .v0_4 import ProcessingKwargs as ProcessingKwargs
 141from .v0_4 import RunMode as RunMode
 142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 146from .v0_4 import TensorName as _TensorName_v0_4
 147from .v0_4 import WeightsFormat as WeightsFormat
 148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 149from .v0_4 import package_weights
 150
 151SpaceUnit = Literal[
 152    "attometer",
 153    "angstrom",
 154    "centimeter",
 155    "decimeter",
 156    "exameter",
 157    "femtometer",
 158    "foot",
 159    "gigameter",
 160    "hectometer",
 161    "inch",
 162    "kilometer",
 163    "megameter",
 164    "meter",
 165    "micrometer",
 166    "mile",
 167    "millimeter",
 168    "nanometer",
 169    "parsec",
 170    "petameter",
 171    "picometer",
 172    "terameter",
 173    "yard",
 174    "yoctometer",
 175    "yottameter",
 176    "zeptometer",
 177    "zettameter",
 178]
 179"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 180
 181TimeUnit = Literal[
 182    "attosecond",
 183    "centisecond",
 184    "day",
 185    "decisecond",
 186    "exasecond",
 187    "femtosecond",
 188    "gigasecond",
 189    "hectosecond",
 190    "hour",
 191    "kilosecond",
 192    "megasecond",
 193    "microsecond",
 194    "millisecond",
 195    "minute",
 196    "nanosecond",
 197    "petasecond",
 198    "picosecond",
 199    "second",
 200    "terasecond",
 201    "yoctosecond",
 202    "yottasecond",
 203    "zeptosecond",
 204    "zettasecond",
 205]
 206"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 207
 208AxisType = Literal["batch", "channel", "index", "time", "space"]
 209
 210_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
 211    "b": "batch",
 212    "t": "time",
 213    "i": "index",
 214    "c": "channel",
 215    "x": "space",
 216    "y": "space",
 217    "z": "space",
 218}
 219
 220_AXIS_ID_MAP = {
 221    "b": "batch",
 222    "t": "time",
 223    "i": "index",
 224    "c": "channel",
 225}
 226
 227
 228class TensorId(LowerCaseIdentifier):
 229    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 230        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 231    ]
 232
 233
 234def _normalize_axis_id(a: str):
 235    a = str(a)
 236    normalized = _AXIS_ID_MAP.get(a, a)
 237    if a != normalized:
 238        logger.opt(depth=3).warning(
 239            "Normalized axis id from '{}' to '{}'.", a, normalized
 240        )
 241    return normalized
 242
 243
 244class AxisId(LowerCaseIdentifier):
 245    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 246        Annotated[
 247            LowerCaseIdentifierAnno,
 248            MaxLen(16),
 249            AfterValidator(_normalize_axis_id),
 250        ]
 251    ]
 252
 253
 254def _is_batch(a: str) -> bool:
 255    return str(a) == "batch"
 256
 257
 258def _is_not_batch(a: str) -> bool:
 259    return not _is_batch(a)
 260
 261
 262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
 263
 264PreprocessingId = Literal[
 265    "binarize",
 266    "clip",
 267    "ensure_dtype",
 268    "fixed_zero_mean_unit_variance",
 269    "scale_linear",
 270    "scale_range",
 271    "sigmoid",
 272    "softmax",
 273]
 274PostprocessingId = Literal[
 275    "binarize",
 276    "clip",
 277    "ensure_dtype",
 278    "fixed_zero_mean_unit_variance",
 279    "scale_linear",
 280    "scale_mean_variance",
 281    "scale_range",
 282    "sigmoid",
 283    "softmax",
 284    "zero_mean_unit_variance",
 285]
 286
 287
 288SAME_AS_TYPE = "<same as type>"
 289
 290
 291ParameterizedSize_N = int
 292"""
 293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
 294"""
 295
 296
 297class ParameterizedSize(Node):
 298    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
 299
 300    - **min** and **step** are given by the model description.
 301    - All blocksize paramters n = 0,1,2,... yield a valid `size`.
 302    - A greater blocksize paramter n = 0,1,2,... results in a greater **size**.
 303      This allows to adjust the axis size more generically.
 304    """
 305
 306    N: ClassVar[Type[int]] = ParameterizedSize_N
 307    """Positive integer to parameterize this axis"""
 308
 309    min: Annotated[int, Gt(0)]
 310    step: Annotated[int, Gt(0)]
 311
 312    def validate_size(self, size: int) -> int:
 313        if size < self.min:
 314            raise ValueError(f"size {size} < {self.min}")
 315        if (size - self.min) % self.step != 0:
 316            raise ValueError(
 317                f"axis of size {size} is not parameterized by `min + n*step` ="
 318                + f" `{self.min} + n*{self.step}`"
 319            )
 320
 321        return size
 322
 323    def get_size(self, n: ParameterizedSize_N) -> int:
 324        return self.min + self.step * n
 325
 326    def get_n(self, s: int) -> ParameterizedSize_N:
 327        """return smallest n parameterizing a size greater or equal than `s`"""
 328        return ceil((s - self.min) / self.step)
 329
 330
 331class DataDependentSize(Node):
 332    min: Annotated[int, Gt(0)] = 1
 333    max: Annotated[Optional[int], Gt(1)] = None
 334
 335    @model_validator(mode="after")
 336    def _validate_max_gt_min(self):
 337        if self.max is not None and self.min >= self.max:
 338            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 339
 340        return self
 341
 342    def validate_size(self, size: int) -> int:
 343        if size < self.min:
 344            raise ValueError(f"size {size} < {self.min}")
 345
 346        if self.max is not None and size > self.max:
 347            raise ValueError(f"size {size} > {self.max}")
 348
 349        return size
 350
 351
 352class SizeReference(Node):
 353    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 354
 355    `axis.size = reference.size * reference.scale / axis.scale + offset`
 356
 357    Note:
 358    1. The axis and the referenced axis need to have the same unit (or no unit).
 359    2. Batch axes may not be referenced.
 360    3. Fractions are rounded down.
 361    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 362        `concatenable` as well with the same block order.
 363
 364    Example:
 365    An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
 366    Let's assume that we want to express the image height h in relation to its width w
 367    instead of only accepting input images of exactly 100*49 pixels
 368    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
 369
 370    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 371    >>> h = SpaceInputAxis(
 372    ...     id=AxisId("h"),
 373    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 374    ...     unit="millimeter",
 375    ...     scale=4,
 376    ... )
 377    >>> print(h.size.get_size(h, w))
 378    49
 379
 380    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 381    """
 382
 383    tensor_id: TensorId
 384    """tensor id of the reference axis"""
 385
 386    axis_id: AxisId
 387    """axis id of the reference axis"""
 388
 389    offset: StrictInt = 0
 390
 391    def get_size(
 392        self,
 393        axis: Union[
 394            ChannelAxis,
 395            IndexInputAxis,
 396            IndexOutputAxis,
 397            TimeInputAxis,
 398            SpaceInputAxis,
 399            TimeOutputAxis,
 400            TimeOutputAxisWithHalo,
 401            SpaceOutputAxis,
 402            SpaceOutputAxisWithHalo,
 403        ],
 404        ref_axis: Union[
 405            ChannelAxis,
 406            IndexInputAxis,
 407            IndexOutputAxis,
 408            TimeInputAxis,
 409            SpaceInputAxis,
 410            TimeOutputAxis,
 411            TimeOutputAxisWithHalo,
 412            SpaceOutputAxis,
 413            SpaceOutputAxisWithHalo,
 414        ],
 415        n: ParameterizedSize_N = 0,
 416        ref_size: Optional[int] = None,
 417    ):
 418        """Compute the concrete size for a given axis and its reference axis.
 419
 420        Args:
 421            axis: The axis this `SizeReference` is the size of.
 422            ref_axis: The reference axis to compute the size from.
 423            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 424                and no fixed **ref_size** is given,
 425                **n** is used to compute the size of the parameterized **ref_axis**.
 426            ref_size: Overwrite the reference size instead of deriving it from
 427                **ref_axis**
 428                (**ref_axis.scale** is still used; any given **n** is ignored).
 429        """
 430        assert axis.size == self, (
 431            "Given `axis.size` is not defined by this `SizeReference`"
 432        )
 433
 434        assert ref_axis.id == self.axis_id, (
 435            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 436        )
 437
 438        assert axis.unit == ref_axis.unit, (
 439            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 440            f" but {axis.unit}!={ref_axis.unit}"
 441        )
 442        if ref_size is None:
 443            if isinstance(ref_axis.size, (int, float)):
 444                ref_size = ref_axis.size
 445            elif isinstance(ref_axis.size, ParameterizedSize):
 446                ref_size = ref_axis.size.get_size(n)
 447            elif isinstance(ref_axis.size, DataDependentSize):
 448                raise ValueError(
 449                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 450                )
 451            elif isinstance(ref_axis.size, SizeReference):
 452                raise ValueError(
 453                    "Reference axis referenced in `SizeReference` may not be sized by a"
 454                    + " `SizeReference` itself."
 455                )
 456            else:
 457                assert_never(ref_axis.size)
 458
 459        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 460
 461    @staticmethod
 462    def _get_unit(
 463        axis: Union[
 464            ChannelAxis,
 465            IndexInputAxis,
 466            IndexOutputAxis,
 467            TimeInputAxis,
 468            SpaceInputAxis,
 469            TimeOutputAxis,
 470            TimeOutputAxisWithHalo,
 471            SpaceOutputAxis,
 472            SpaceOutputAxisWithHalo,
 473        ],
 474    ):
 475        return axis.unit
 476
 477
 478class AxisBase(NodeWithExplicitlySetFields):
 479    id: AxisId
 480    """An axis id unique across all axes of one tensor."""
 481
 482    description: Annotated[str, MaxLen(128)] = ""
 483    """A short description of this axis beyond its type and id."""
 484
 485
 486class WithHalo(Node):
 487    halo: Annotated[int, Ge(1)]
 488    """The halo should be cropped from the output tensor to avoid boundary effects.
 489    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 490    To document a halo that is already cropped by the model use `size.offset` instead."""
 491
 492    size: Annotated[
 493        SizeReference,
 494        Field(
 495            examples=[
 496                10,
 497                SizeReference(
 498                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 499                ).model_dump(mode="json"),
 500            ]
 501        ),
 502    ]
 503    """reference to another axis with an optional offset (see `SizeReference`)"""
 504
 505
 506BATCH_AXIS_ID = AxisId("batch")
 507
 508
 509class BatchAxis(AxisBase):
 510    implemented_type: ClassVar[Literal["batch"]] = "batch"
 511    if TYPE_CHECKING:
 512        type: Literal["batch"] = "batch"
 513    else:
 514        type: Literal["batch"]
 515
 516    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 517    size: Optional[Literal[1]] = None
 518    """The batch size may be fixed to 1,
 519    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
 520
 521    @property
 522    def scale(self):
 523        return 1.0
 524
 525    @property
 526    def concatenable(self):
 527        return True
 528
 529    @property
 530    def unit(self):
 531        return None
 532
 533
 534class ChannelAxis(AxisBase):
 535    implemented_type: ClassVar[Literal["channel"]] = "channel"
 536    if TYPE_CHECKING:
 537        type: Literal["channel"] = "channel"
 538    else:
 539        type: Literal["channel"]
 540
 541    id: NonBatchAxisId = AxisId("channel")
 542
 543    channel_names: NotEmpty[List[Identifier]]
 544
 545    @property
 546    def size(self) -> int:
 547        return len(self.channel_names)
 548
 549    @property
 550    def concatenable(self):
 551        return False
 552
 553    @property
 554    def scale(self) -> float:
 555        return 1.0
 556
 557    @property
 558    def unit(self):
 559        return None
 560
 561
 562class IndexAxisBase(AxisBase):
 563    implemented_type: ClassVar[Literal["index"]] = "index"
 564    if TYPE_CHECKING:
 565        type: Literal["index"] = "index"
 566    else:
 567        type: Literal["index"]
 568
 569    id: NonBatchAxisId = AxisId("index")
 570
 571    @property
 572    def scale(self) -> float:
 573        return 1.0
 574
 575    @property
 576    def unit(self):
 577        return None
 578
 579
 580class _WithInputAxisSize(Node):
 581    size: Annotated[
 582        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 583        Field(
 584            examples=[
 585                10,
 586                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 587                SizeReference(
 588                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 589                ).model_dump(mode="json"),
 590            ]
 591        ),
 592    ]
 593    """The size/length of this axis can be specified as
 594    - fixed integer
 595    - parameterized series of valid sizes (`ParameterizedSize`)
 596    - reference to another axis with an optional offset (`SizeReference`)
 597    """
 598
 599
 600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 601    concatenable: bool = False
 602    """If a model has a `concatenable` input axis, it can be processed blockwise,
 603    splitting a longer sample axis into blocks matching its input tensor description.
 604    Output axes are concatenable if they have a `SizeReference` to a concatenable
 605    input axis.
 606    """
 607
 608
 609class IndexOutputAxis(IndexAxisBase):
 610    size: Annotated[
 611        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 612        Field(
 613            examples=[
 614                10,
 615                SizeReference(
 616                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 617                ).model_dump(mode="json"),
 618            ]
 619        ),
 620    ]
 621    """The size/length of this axis can be specified as
 622    - fixed integer
 623    - reference to another axis with an optional offset (`SizeReference`)
 624    - data dependent size using `DataDependentSize` (size is only known after model inference)
 625    """
 626
 627
 628class TimeAxisBase(AxisBase):
 629    implemented_type: ClassVar[Literal["time"]] = "time"
 630    if TYPE_CHECKING:
 631        type: Literal["time"] = "time"
 632    else:
 633        type: Literal["time"]
 634
 635    id: NonBatchAxisId = AxisId("time")
 636    unit: Optional[TimeUnit] = None
 637    scale: Annotated[float, Gt(0)] = 1.0
 638
 639
 640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 641    concatenable: bool = False
 642    """If a model has a `concatenable` input axis, it can be processed blockwise,
 643    splitting a longer sample axis into blocks matching its input tensor description.
 644    Output axes are concatenable if they have a `SizeReference` to a concatenable
 645    input axis.
 646    """
 647
 648
 649class SpaceAxisBase(AxisBase):
 650    implemented_type: ClassVar[Literal["space"]] = "space"
 651    if TYPE_CHECKING:
 652        type: Literal["space"] = "space"
 653    else:
 654        type: Literal["space"]
 655
 656    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 657    unit: Optional[SpaceUnit] = None
 658    scale: Annotated[float, Gt(0)] = 1.0
 659
 660
 661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 662    concatenable: bool = False
 663    """If a model has a `concatenable` input axis, it can be processed blockwise,
 664    splitting a longer sample axis into blocks matching its input tensor description.
 665    Output axes are concatenable if they have a `SizeReference` to a concatenable
 666    input axis.
 667    """
 668
 669
 670INPUT_AXIS_TYPES = (
 671    BatchAxis,
 672    ChannelAxis,
 673    IndexInputAxis,
 674    TimeInputAxis,
 675    SpaceInputAxis,
 676)
 677"""intended for isinstance comparisons in py<3.10"""
 678
 679_InputAxisUnion = Union[
 680    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 681]
 682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 683
 684
 685class _WithOutputAxisSize(Node):
 686    size: Annotated[
 687        Union[Annotated[int, Gt(0)], SizeReference],
 688        Field(
 689            examples=[
 690                10,
 691                SizeReference(
 692                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 693                ).model_dump(mode="json"),
 694            ]
 695        ),
 696    ]
 697    """The size/length of this axis can be specified as
 698    - fixed integer
 699    - reference to another axis with an optional offset (see `SizeReference`)
 700    """
 701
 702
 703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 704    pass
 705
 706
 707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 708    pass
 709
 710
 711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 712    if isinstance(v, dict):
 713        return "with_halo" if "halo" in v else "wo_halo"
 714    else:
 715        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 716
 717
 718_TimeOutputAxisUnion = Annotated[
 719    Union[
 720        Annotated[TimeOutputAxis, Tag("wo_halo")],
 721        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 722    ],
 723    Discriminator(_get_halo_axis_discriminator_value),
 724]
 725
 726
 727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 728    pass
 729
 730
 731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 732    pass
 733
 734
 735_SpaceOutputAxisUnion = Annotated[
 736    Union[
 737        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 738        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 739    ],
 740    Discriminator(_get_halo_axis_discriminator_value),
 741]
 742
 743
 744_OutputAxisUnion = Union[
 745    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 746]
 747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 748
 749OUTPUT_AXIS_TYPES = (
 750    BatchAxis,
 751    ChannelAxis,
 752    IndexOutputAxis,
 753    TimeOutputAxis,
 754    TimeOutputAxisWithHalo,
 755    SpaceOutputAxis,
 756    SpaceOutputAxisWithHalo,
 757)
 758"""intended for isinstance comparisons in py<3.10"""
 759
 760
 761AnyAxis = Union[InputAxis, OutputAxis]
 762
 763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
 764"""intended for isinstance comparisons in py<3.10"""
 765
 766TVs = Union[
 767    NotEmpty[List[int]],
 768    NotEmpty[List[float]],
 769    NotEmpty[List[bool]],
 770    NotEmpty[List[str]],
 771]
 772
 773
 774NominalOrOrdinalDType = Literal[
 775    "float32",
 776    "float64",
 777    "uint8",
 778    "int8",
 779    "uint16",
 780    "int16",
 781    "uint32",
 782    "int32",
 783    "uint64",
 784    "int64",
 785    "bool",
 786]
 787
 788
 789class NominalOrOrdinalDataDescr(Node):
 790    values: TVs
 791    """A fixed set of nominal or an ascending sequence of ordinal values.
 792    In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'.
 793    String `values` are interpreted as labels for tensor values 0, ..., N.
 794    Note: as YAML 1.2 does not natively support a "set" datatype,
 795    nominal values should be given as a sequence (aka list/array) as well.
 796    """
 797
 798    type: Annotated[
 799        NominalOrOrdinalDType,
 800        Field(
 801            examples=[
 802                "float32",
 803                "uint8",
 804                "uint16",
 805                "int64",
 806                "bool",
 807            ],
 808        ),
 809    ] = "uint8"
 810
 811    @model_validator(mode="after")
 812    def _validate_values_match_type(
 813        self,
 814    ) -> Self:
 815        incompatible: List[Any] = []
 816        for v in self.values:
 817            if self.type == "bool":
 818                if not isinstance(v, bool):
 819                    incompatible.append(v)
 820            elif self.type in DTYPE_LIMITS:
 821                if (
 822                    isinstance(v, (int, float))
 823                    and (
 824                        v < DTYPE_LIMITS[self.type].min
 825                        or v > DTYPE_LIMITS[self.type].max
 826                    )
 827                    or (isinstance(v, str) and "uint" not in self.type)
 828                    or (isinstance(v, float) and "int" in self.type)
 829                ):
 830                    incompatible.append(v)
 831            else:
 832                incompatible.append(v)
 833
 834            if len(incompatible) == 5:
 835                incompatible.append("...")
 836                break
 837
 838        if incompatible:
 839            raise ValueError(
 840                f"data type '{self.type}' incompatible with values {incompatible}"
 841            )
 842
 843        return self
 844
 845    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 846
 847    @property
 848    def range(self):
 849        if isinstance(self.values[0], str):
 850            return 0, len(self.values) - 1
 851        else:
 852            return min(self.values), max(self.values)
 853
 854
 855IntervalOrRatioDType = Literal[
 856    "float32",
 857    "float64",
 858    "uint8",
 859    "int8",
 860    "uint16",
 861    "int16",
 862    "uint32",
 863    "int32",
 864    "uint64",
 865    "int64",
 866]
 867
 868
 869class IntervalOrRatioDataDescr(Node):
 870    type: Annotated[  # TODO: rename to dtype
 871        IntervalOrRatioDType,
 872        Field(
 873            examples=["float32", "float64", "uint8", "uint16"],
 874        ),
 875    ] = "float32"
 876    range: Tuple[Optional[float], Optional[float]] = (
 877        None,
 878        None,
 879    )
 880    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 881    `None` corresponds to min/max of what can be expressed by **type**."""
 882    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 883    scale: float = 1.0
 884    """Scale for data on an interval (or ratio) scale."""
 885    offset: Optional[float] = None
 886    """Offset for data on a ratio scale."""
 887
 888    @model_validator(mode="before")
 889    def _replace_inf(cls, data: Any):
 890        if is_dict(data):
 891            if "range" in data and is_sequence(data["range"]):
 892                forbidden = (
 893                    "inf",
 894                    "-inf",
 895                    ".inf",
 896                    "-.inf",
 897                    float("inf"),
 898                    float("-inf"),
 899                )
 900                if any(v in forbidden for v in data["range"]):
 901                    issue_warning("replaced 'inf' value", value=data["range"])
 902
 903                data["range"] = tuple(
 904                    (None if v in forbidden else v) for v in data["range"]
 905                )
 906
 907        return data
 908
 909
 910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 911
 912
 913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 914    """processing base class"""
 915
 916
 917class BinarizeKwargs(ProcessingKwargs):
 918    """key word arguments for `BinarizeDescr`"""
 919
 920    threshold: float
 921    """The fixed threshold"""
 922
 923
 924class BinarizeAlongAxisKwargs(ProcessingKwargs):
 925    """key word arguments for `BinarizeDescr`"""
 926
 927    threshold: NotEmpty[List[float]]
 928    """The fixed threshold values along `axis`"""
 929
 930    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 931    """The `threshold` axis"""
 932
 933
 934class BinarizeDescr(ProcessingDescrBase):
 935    """Binarize the tensor with a fixed threshold.
 936
 937    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 938    will be set to one, values below the threshold to zero.
 939
 940    Examples:
 941    - in YAML
 942        ```yaml
 943        postprocessing:
 944          - id: binarize
 945            kwargs:
 946              axis: 'channel'
 947              threshold: [0.25, 0.5, 0.75]
 948        ```
 949    - in Python:
 950        >>> postprocessing = [BinarizeDescr(
 951        ...   kwargs=BinarizeAlongAxisKwargs(
 952        ...       axis=AxisId('channel'),
 953        ...       threshold=[0.25, 0.5, 0.75],
 954        ...   )
 955        ... )]
 956    """
 957
 958    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
 959    if TYPE_CHECKING:
 960        id: Literal["binarize"] = "binarize"
 961    else:
 962        id: Literal["binarize"]
 963    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 964
 965
 966class ClipDescr(ProcessingDescrBase):
 967    """Set tensor values below min to min and above max to max.
 968
 969    See `ScaleRangeDescr` for examples.
 970    """
 971
 972    implemented_id: ClassVar[Literal["clip"]] = "clip"
 973    if TYPE_CHECKING:
 974        id: Literal["clip"] = "clip"
 975    else:
 976        id: Literal["clip"]
 977
 978    kwargs: ClipKwargs
 979
 980
 981class EnsureDtypeKwargs(ProcessingKwargs):
 982    """key word arguments for `EnsureDtypeDescr`"""
 983
 984    dtype: Literal[
 985        "float32",
 986        "float64",
 987        "uint8",
 988        "int8",
 989        "uint16",
 990        "int16",
 991        "uint32",
 992        "int32",
 993        "uint64",
 994        "int64",
 995        "bool",
 996    ]
 997
 998
 999class EnsureDtypeDescr(ProcessingDescrBase):
1000    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1001
1002    This can for example be used to ensure the inner neural network model gets a
1003    different input tensor data type than the fully described bioimage.io model does.
1004
1005    Examples:
1006        The described bioimage.io model (incl. preprocessing) accepts any
1007        float32-compatible tensor, normalizes it with percentiles and clipping and then
1008        casts it to uint8, which is what the neural network in this example expects.
1009        - in YAML
1010            ```yaml
1011            inputs:
1012            - data:
1013                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1014              preprocessing:
1015              - id: scale_range
1016                  kwargs:
1017                  axes: ['y', 'x']
1018                  max_percentile: 99.8
1019                  min_percentile: 5.0
1020              - id: clip
1021                  kwargs:
1022                  min: 0.0
1023                  max: 1.0
1024              - id: ensure_dtype  # the neural network of the model requires uint8
1025                  kwargs:
1026                  dtype: uint8
1027            ```
1028        - in Python:
1029            >>> preprocessing = [
1030            ...     ScaleRangeDescr(
1031            ...         kwargs=ScaleRangeKwargs(
1032            ...           axes= (AxisId('y'), AxisId('x')),
1033            ...           max_percentile= 99.8,
1034            ...           min_percentile= 5.0,
1035            ...         )
1036            ...     ),
1037            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1038            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1039            ... ]
1040    """
1041
1042    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1043    if TYPE_CHECKING:
1044        id: Literal["ensure_dtype"] = "ensure_dtype"
1045    else:
1046        id: Literal["ensure_dtype"]
1047
1048    kwargs: EnsureDtypeKwargs
1049
1050
1051class ScaleLinearKwargs(ProcessingKwargs):
1052    """Key word arguments for `ScaleLinearDescr`"""
1053
1054    gain: float = 1.0
1055    """multiplicative factor"""
1056
1057    offset: float = 0.0
1058    """additive term"""
1059
1060    @model_validator(mode="after")
1061    def _validate(self) -> Self:
1062        if self.gain == 1.0 and self.offset == 0.0:
1063            raise ValueError(
1064                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1065                + " != 0.0."
1066            )
1067
1068        return self
1069
1070
1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1072    """Key word arguments for `ScaleLinearDescr`"""
1073
1074    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1075    """The axis of gain and offset values."""
1076
1077    gain: Union[float, NotEmpty[List[float]]] = 1.0
1078    """multiplicative factor"""
1079
1080    offset: Union[float, NotEmpty[List[float]]] = 0.0
1081    """additive term"""
1082
1083    @model_validator(mode="after")
1084    def _validate(self) -> Self:
1085        if isinstance(self.gain, list):
1086            if isinstance(self.offset, list):
1087                if len(self.gain) != len(self.offset):
1088                    raise ValueError(
1089                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1090                    )
1091            else:
1092                self.offset = [float(self.offset)] * len(self.gain)
1093        elif isinstance(self.offset, list):
1094            self.gain = [float(self.gain)] * len(self.offset)
1095        else:
1096            raise ValueError(
1097                "Do not specify an `axis` for scalar gain and offset values."
1098            )
1099
1100        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1101            raise ValueError(
1102                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1103                + " != 0.0."
1104            )
1105
1106        return self
1107
1108
1109class ScaleLinearDescr(ProcessingDescrBase):
1110    """Fixed linear scaling.
1111
1112    Examples:
1113      1. Scale with scalar gain and offset
1114        - in YAML
1115        ```yaml
1116        preprocessing:
1117          - id: scale_linear
1118            kwargs:
1119              gain: 2.0
1120              offset: 3.0
1121        ```
1122        - in Python:
1123        >>> preprocessing = [
1124        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1125        ... ]
1126
1127      2. Independent scaling along an axis
1128        - in YAML
1129        ```yaml
1130        preprocessing:
1131          - id: scale_linear
1132            kwargs:
1133              axis: 'channel'
1134              gain: [1.0, 2.0, 3.0]
1135        ```
1136        - in Python:
1137        >>> preprocessing = [
1138        ...     ScaleLinearDescr(
1139        ...         kwargs=ScaleLinearAlongAxisKwargs(
1140        ...             axis=AxisId("channel"),
1141        ...             gain=[1.0, 2.0, 3.0],
1142        ...         )
1143        ...     )
1144        ... ]
1145
1146    """
1147
1148    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1149    if TYPE_CHECKING:
1150        id: Literal["scale_linear"] = "scale_linear"
1151    else:
1152        id: Literal["scale_linear"]
1153    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1154
1155
1156class SigmoidDescr(ProcessingDescrBase):
1157    """The logistic sigmoid function, a.k.a. expit function.
1158
1159    Examples:
1160    - in YAML
1161        ```yaml
1162        postprocessing:
1163          - id: sigmoid
1164        ```
1165    - in Python:
1166        >>> postprocessing = [SigmoidDescr()]
1167    """
1168
1169    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1170    if TYPE_CHECKING:
1171        id: Literal["sigmoid"] = "sigmoid"
1172    else:
1173        id: Literal["sigmoid"]
1174
1175    @property
1176    def kwargs(self) -> ProcessingKwargs:
1177        """empty kwargs"""
1178        return ProcessingKwargs()
1179
1180
1181class SoftmaxKwargs(ProcessingKwargs):
1182    """key word arguments for `SoftmaxDescr`"""
1183
1184    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1185    """The axis to apply the softmax function along.
1186    Note:
1187        Defaults to 'channel' axis
1188        (which may not exist, in which case
1189        a different axis id has to be specified).
1190    """
1191
1192
1193class SoftmaxDescr(ProcessingDescrBase):
1194    """The softmax function.
1195
1196    Examples:
1197    - in YAML
1198        ```yaml
1199        postprocessing:
1200          - id: softmax
1201            kwargs:
1202              axis: channel
1203        ```
1204    - in Python:
1205        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1206    """
1207
1208    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1209    if TYPE_CHECKING:
1210        id: Literal["softmax"] = "softmax"
1211    else:
1212        id: Literal["softmax"]
1213
1214    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
1215
1216
1217class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1218    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1219
1220    mean: float
1221    """The mean value to normalize with."""
1222
1223    std: Annotated[float, Ge(1e-6)]
1224    """The standard deviation value to normalize with."""
1225
1226
1227class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1228    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1229
1230    mean: NotEmpty[List[float]]
1231    """The mean value(s) to normalize with."""
1232
1233    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1234    """The standard deviation value(s) to normalize with.
1235    Size must match `mean` values."""
1236
1237    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1238    """The axis of the mean/std values to normalize each entry along that dimension
1239    separately."""
1240
1241    @model_validator(mode="after")
1242    def _mean_and_std_match(self) -> Self:
1243        if len(self.mean) != len(self.std):
1244            raise ValueError(
1245                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1246                + " must match."
1247            )
1248
1249        return self
1250
1251
1252class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1253    """Subtract a given mean and divide by the standard deviation.
1254
1255    Normalize with fixed, precomputed values for
1256    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1257    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1258    axes.
1259
1260    Examples:
1261    1. scalar value for whole tensor
1262        - in YAML
1263        ```yaml
1264        preprocessing:
1265          - id: fixed_zero_mean_unit_variance
1266            kwargs:
1267              mean: 103.5
1268              std: 13.7
1269        ```
1270        - in Python
1271        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1272        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1273        ... )]
1274
1275    2. independently along an axis
1276        - in YAML
1277        ```yaml
1278        preprocessing:
1279          - id: fixed_zero_mean_unit_variance
1280            kwargs:
1281              axis: channel
1282              mean: [101.5, 102.5, 103.5]
1283              std: [11.7, 12.7, 13.7]
1284        ```
1285        - in Python
1286        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1287        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1288        ...     axis=AxisId("channel"),
1289        ...     mean=[101.5, 102.5, 103.5],
1290        ...     std=[11.7, 12.7, 13.7],
1291        ...   )
1292        ... )]
1293    """
1294
1295    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1296        "fixed_zero_mean_unit_variance"
1297    )
1298    if TYPE_CHECKING:
1299        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1300    else:
1301        id: Literal["fixed_zero_mean_unit_variance"]
1302
1303    kwargs: Union[
1304        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1305    ]
1306
1307
1308class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1309    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1310
1311    axes: Annotated[
1312        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1313    ] = None
1314    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1315    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1316    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1317    To normalize each sample independently leave out the 'batch' axis.
1318    Default: Scale all axes jointly."""
1319
1320    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1321    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1322
1323
1324class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1325    """Subtract mean and divide by variance.
1326
1327    Examples:
1328        Subtract tensor mean and variance
1329        - in YAML
1330        ```yaml
1331        preprocessing:
1332          - id: zero_mean_unit_variance
1333        ```
1334        - in Python
1335        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1336    """
1337
1338    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1339        "zero_mean_unit_variance"
1340    )
1341    if TYPE_CHECKING:
1342        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1343    else:
1344        id: Literal["zero_mean_unit_variance"]
1345
1346    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1347        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1348    )
1349
1350
1351class ScaleRangeKwargs(ProcessingKwargs):
1352    """key word arguments for `ScaleRangeDescr`
1353
1354    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1355    this processing step normalizes data to the [0, 1] intervall.
1356    For other percentiles the normalized values will partially be outside the [0, 1]
1357    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1358    normalized values to a range.
1359    """
1360
1361    axes: Annotated[
1362        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1363    ] = None
1364    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1365    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1366    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1367    To normalize samples independently, leave out the "batch" axis.
1368    Default: Scale all axes jointly."""
1369
1370    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1371    """The lower percentile used to determine the value to align with zero."""
1372
1373    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1374    """The upper percentile used to determine the value to align with one.
1375    Has to be bigger than `min_percentile`.
1376    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1377    accepting percentiles specified in the range 0.0 to 1.0."""
1378
1379    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1380    """Epsilon for numeric stability.
1381    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1382    with `v_lower,v_upper` values at the respective percentiles."""
1383
1384    reference_tensor: Optional[TensorId] = None
1385    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1386    For any tensor in `inputs` only input tensor references are allowed."""
1387
1388    @field_validator("max_percentile", mode="after")
1389    @classmethod
1390    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1391        if (min_p := info.data["min_percentile"]) >= value:
1392            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1393
1394        return value
1395
1396
1397class ScaleRangeDescr(ProcessingDescrBase):
1398    """Scale with percentiles.
1399
1400    Examples:
1401    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1402        - in YAML
1403        ```yaml
1404        preprocessing:
1405          - id: scale_range
1406            kwargs:
1407              axes: ['y', 'x']
1408              max_percentile: 99.8
1409              min_percentile: 5.0
1410        ```
1411        - in Python
1412        >>> preprocessing = [
1413        ...     ScaleRangeDescr(
1414        ...         kwargs=ScaleRangeKwargs(
1415        ...           axes= (AxisId('y'), AxisId('x')),
1416        ...           max_percentile= 99.8,
1417        ...           min_percentile= 5.0,
1418        ...         )
1419        ...     ),
1420        ...     ClipDescr(
1421        ...         kwargs=ClipKwargs(
1422        ...             min=0.0,
1423        ...             max=1.0,
1424        ...         )
1425        ...     ),
1426        ... ]
1427
1428      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1429        - in YAML
1430        ```yaml
1431        preprocessing:
1432          - id: scale_range
1433            kwargs:
1434              axes: ['y', 'x']
1435              max_percentile: 99.8
1436              min_percentile: 5.0
1437                  - id: scale_range
1438           - id: clip
1439             kwargs:
1440              min: 0.0
1441              max: 1.0
1442        ```
1443        - in Python
1444        >>> preprocessing = [ScaleRangeDescr(
1445        ...   kwargs=ScaleRangeKwargs(
1446        ...       axes= (AxisId('y'), AxisId('x')),
1447        ...       max_percentile= 99.8,
1448        ...       min_percentile= 5.0,
1449        ...   )
1450        ... )]
1451
1452    """
1453
1454    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1455    if TYPE_CHECKING:
1456        id: Literal["scale_range"] = "scale_range"
1457    else:
1458        id: Literal["scale_range"]
1459    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
1460
1461
1462class ScaleMeanVarianceKwargs(ProcessingKwargs):
1463    """key word arguments for `ScaleMeanVarianceKwargs`"""
1464
1465    reference_tensor: TensorId
1466    """Name of tensor to match."""
1467
1468    axes: Annotated[
1469        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1470    ] = None
1471    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1472    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1473    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1474    To normalize samples independently, leave out the 'batch' axis.
1475    Default: Scale all axes jointly."""
1476
1477    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1478    """Epsilon for numeric stability:
1479    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1480
1481
1482class ScaleMeanVarianceDescr(ProcessingDescrBase):
1483    """Scale a tensor's data distribution to match another tensor's mean/std.
1484    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1485    """
1486
1487    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1488    if TYPE_CHECKING:
1489        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1490    else:
1491        id: Literal["scale_mean_variance"]
1492    kwargs: ScaleMeanVarianceKwargs
1493
1494
1495PreprocessingDescr = Annotated[
1496    Union[
1497        BinarizeDescr,
1498        ClipDescr,
1499        EnsureDtypeDescr,
1500        FixedZeroMeanUnitVarianceDescr,
1501        ScaleLinearDescr,
1502        ScaleRangeDescr,
1503        SigmoidDescr,
1504        SoftmaxDescr,
1505        ZeroMeanUnitVarianceDescr,
1506    ],
1507    Discriminator("id"),
1508]
1509PostprocessingDescr = Annotated[
1510    Union[
1511        BinarizeDescr,
1512        ClipDescr,
1513        EnsureDtypeDescr,
1514        FixedZeroMeanUnitVarianceDescr,
1515        ScaleLinearDescr,
1516        ScaleMeanVarianceDescr,
1517        ScaleRangeDescr,
1518        SigmoidDescr,
1519        SoftmaxDescr,
1520        ZeroMeanUnitVarianceDescr,
1521    ],
1522    Discriminator("id"),
1523]
1524
1525IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1526
1527
1528class TensorDescrBase(Node, Generic[IO_AxisT]):
1529    id: TensorId
1530    """Tensor id. No duplicates are allowed."""
1531
1532    description: Annotated[str, MaxLen(128)] = ""
1533    """free text description"""
1534
1535    axes: NotEmpty[Sequence[IO_AxisT]]
1536    """tensor axes"""
1537
1538    @property
1539    def shape(self):
1540        return tuple(a.size for a in self.axes)
1541
1542    @field_validator("axes", mode="after", check_fields=False)
1543    @classmethod
1544    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1545        batch_axes = [a for a in axes if a.type == "batch"]
1546        if len(batch_axes) > 1:
1547            raise ValueError(
1548                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1549            )
1550
1551        seen_ids: Set[AxisId] = set()
1552        duplicate_axes_ids: Set[AxisId] = set()
1553        for a in axes:
1554            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1555
1556        if duplicate_axes_ids:
1557            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1558
1559        return axes
1560
1561    test_tensor: FAIR[Optional[FileDescr_]] = None
1562    """An example tensor to use for testing.
1563    Using the model with the test input tensors is expected to yield the test output tensors.
1564    Each test tensor has be a an ndarray in the
1565    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1566    The file extension must be '.npy'."""
1567
1568    sample_tensor: FAIR[Optional[FileDescr_]] = None
1569    """A sample tensor to illustrate a possible input/output for the model,
1570    The sample image primarily serves to inform a human user about an example use case
1571    and is typically stored as .hdf5, .png or .tiff.
1572    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1573    (numpy's `.npy` format is not supported).
1574    The image dimensionality has to match the number of axes specified in this tensor description.
1575    """
1576
1577    @model_validator(mode="after")
1578    def _validate_sample_tensor(self) -> Self:
1579        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1580            return self
1581
1582        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1583        tensor: NDArray[Any] = imread(
1584            reader.read(),
1585            extension=PurePosixPath(reader.original_file_name).suffix,
1586        )
1587        n_dims = len(tensor.squeeze().shape)
1588        n_dims_min = n_dims_max = len(self.axes)
1589
1590        for a in self.axes:
1591            if isinstance(a, BatchAxis):
1592                n_dims_min -= 1
1593            elif isinstance(a.size, int):
1594                if a.size == 1:
1595                    n_dims_min -= 1
1596            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1597                if a.size.min == 1:
1598                    n_dims_min -= 1
1599            elif isinstance(a.size, SizeReference):
1600                if a.size.offset < 2:
1601                    # size reference may result in singleton axis
1602                    n_dims_min -= 1
1603            else:
1604                assert_never(a.size)
1605
1606        n_dims_min = max(0, n_dims_min)
1607        if n_dims < n_dims_min or n_dims > n_dims_max:
1608            raise ValueError(
1609                f"Expected sample tensor to have {n_dims_min} to"
1610                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1611            )
1612
1613        return self
1614
1615    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1616        IntervalOrRatioDataDescr()
1617    )
1618    """Description of the tensor's data values, optionally per channel.
1619    If specified per channel, the data `type` needs to match across channels."""
1620
1621    @property
1622    def dtype(
1623        self,
1624    ) -> Literal[
1625        "float32",
1626        "float64",
1627        "uint8",
1628        "int8",
1629        "uint16",
1630        "int16",
1631        "uint32",
1632        "int32",
1633        "uint64",
1634        "int64",
1635        "bool",
1636    ]:
1637        """dtype as specified under `data.type` or `data[i].type`"""
1638        if isinstance(self.data, collections.abc.Sequence):
1639            return self.data[0].type
1640        else:
1641            return self.data.type
1642
1643    @field_validator("data", mode="after")
1644    @classmethod
1645    def _check_data_type_across_channels(
1646        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1647    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1648        if not isinstance(value, list):
1649            return value
1650
1651        dtypes = {t.type for t in value}
1652        if len(dtypes) > 1:
1653            raise ValueError(
1654                "Tensor data descriptions per channel need to agree in their data"
1655                + f" `type`, but found {dtypes}."
1656            )
1657
1658        return value
1659
1660    @model_validator(mode="after")
1661    def _check_data_matches_channelaxis(self) -> Self:
1662        if not isinstance(self.data, (list, tuple)):
1663            return self
1664
1665        for a in self.axes:
1666            if isinstance(a, ChannelAxis):
1667                size = a.size
1668                assert isinstance(size, int)
1669                break
1670        else:
1671            return self
1672
1673        if len(self.data) != size:
1674            raise ValueError(
1675                f"Got tensor data descriptions for {len(self.data)} channels, but"
1676                + f" '{a.id}' axis has size {size}."
1677            )
1678
1679        return self
1680
1681    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1682        if len(array.shape) != len(self.axes):
1683            raise ValueError(
1684                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1685                + f" incompatible with {len(self.axes)} axes."
1686            )
1687        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1688
1689
1690class InputTensorDescr(TensorDescrBase[InputAxis]):
1691    id: TensorId = TensorId("input")
1692    """Input tensor id.
1693    No duplicates are allowed across all inputs and outputs."""
1694
1695    optional: bool = False
1696    """indicates that this tensor may be `None`"""
1697
1698    preprocessing: List[PreprocessingDescr] = Field(
1699        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1700    )
1701
1702    """Description of how this input should be preprocessed.
1703
1704    notes:
1705    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1706      to ensure an input tensor's data type matches the input tensor's data description.
1707    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1708      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1709      changing the data type.
1710    """
1711
1712    @model_validator(mode="after")
1713    def _validate_preprocessing_kwargs(self) -> Self:
1714        axes_ids = [a.id for a in self.axes]
1715        for p in self.preprocessing:
1716            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1717            if kwargs_axes is None:
1718                continue
1719
1720            if not isinstance(kwargs_axes, collections.abc.Sequence):
1721                raise ValueError(
1722                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1723                )
1724
1725            if any(a not in axes_ids for a in kwargs_axes):
1726                raise ValueError(
1727                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1728                )
1729
1730        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1731            dtype = self.data.type
1732        else:
1733            dtype = self.data[0].type
1734
1735        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1736        if not self.preprocessing or not isinstance(
1737            self.preprocessing[0], EnsureDtypeDescr
1738        ):
1739            self.preprocessing.insert(
1740                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1741            )
1742
1743        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1744        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1745            self.preprocessing.append(
1746                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1747            )
1748
1749        return self
1750
1751
1752def convert_axes(
1753    axes: str,
1754    *,
1755    shape: Union[
1756        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1757    ],
1758    tensor_type: Literal["input", "output"],
1759    halo: Optional[Sequence[int]],
1760    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1761):
1762    ret: List[AnyAxis] = []
1763    for i, a in enumerate(axes):
1764        axis_type = _AXIS_TYPE_MAP.get(a, a)
1765        if axis_type == "batch":
1766            ret.append(BatchAxis())
1767            continue
1768
1769        scale = 1.0
1770        if isinstance(shape, _ParameterizedInputShape_v0_4):
1771            if shape.step[i] == 0:
1772                size = shape.min[i]
1773            else:
1774                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1775        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1776            ref_t = str(shape.reference_tensor)
1777            if ref_t.count(".") == 1:
1778                t_id, orig_a_id = ref_t.split(".")
1779            else:
1780                t_id = ref_t
1781                orig_a_id = a
1782
1783            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1784            if not (orig_scale := shape.scale[i]):
1785                # old way to insert a new axis dimension
1786                size = int(2 * shape.offset[i])
1787            else:
1788                scale = 1 / orig_scale
1789                if axis_type in ("channel", "index"):
1790                    # these axes no longer have a scale
1791                    offset_from_scale = orig_scale * size_refs.get(
1792                        _TensorName_v0_4(t_id), {}
1793                    ).get(orig_a_id, 0)
1794                else:
1795                    offset_from_scale = 0
1796                size = SizeReference(
1797                    tensor_id=TensorId(t_id),
1798                    axis_id=AxisId(a_id),
1799                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1800                )
1801        else:
1802            size = shape[i]
1803
1804        if axis_type == "time":
1805            if tensor_type == "input":
1806                ret.append(TimeInputAxis(size=size, scale=scale))
1807            else:
1808                assert not isinstance(size, ParameterizedSize)
1809                if halo is None:
1810                    ret.append(TimeOutputAxis(size=size, scale=scale))
1811                else:
1812                    assert not isinstance(size, int)
1813                    ret.append(
1814                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1815                    )
1816
1817        elif axis_type == "index":
1818            if tensor_type == "input":
1819                ret.append(IndexInputAxis(size=size))
1820            else:
1821                if isinstance(size, ParameterizedSize):
1822                    size = DataDependentSize(min=size.min)
1823
1824                ret.append(IndexOutputAxis(size=size))
1825        elif axis_type == "channel":
1826            assert not isinstance(size, ParameterizedSize)
1827            if isinstance(size, SizeReference):
1828                warnings.warn(
1829                    "Conversion of channel size from an implicit output shape may be"
1830                    + " wrong"
1831                )
1832                ret.append(
1833                    ChannelAxis(
1834                        channel_names=[
1835                            Identifier(f"channel{i}") for i in range(size.offset)
1836                        ]
1837                    )
1838                )
1839            else:
1840                ret.append(
1841                    ChannelAxis(
1842                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1843                    )
1844                )
1845        elif axis_type == "space":
1846            if tensor_type == "input":
1847                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1848            else:
1849                assert not isinstance(size, ParameterizedSize)
1850                if halo is None or halo[i] == 0:
1851                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1852                elif isinstance(size, int):
1853                    raise NotImplementedError(
1854                        f"output axis with halo and fixed size (here {size}) not allowed"
1855                    )
1856                else:
1857                    ret.append(
1858                        SpaceOutputAxisWithHalo(
1859                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1860                        )
1861                    )
1862
1863    return ret
1864
1865
1866def _axes_letters_to_ids(
1867    axes: Optional[str],
1868) -> Optional[List[AxisId]]:
1869    if axes is None:
1870        return None
1871
1872    return [AxisId(a) for a in axes]
1873
1874
1875def _get_complement_v04_axis(
1876    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1877) -> Optional[AxisId]:
1878    if axes is None:
1879        return None
1880
1881    non_complement_axes = set(axes) | {"b"}
1882    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1883    if len(complement_axes) > 1:
1884        raise ValueError(
1885            f"Expected none or a single complement axis, but axes '{axes}' "
1886            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1887        )
1888
1889    return None if not complement_axes else AxisId(complement_axes[0])
1890
1891
1892def _convert_proc(
1893    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1894    tensor_axes: Sequence[str],
1895) -> Union[PreprocessingDescr, PostprocessingDescr]:
1896    if isinstance(p, _BinarizeDescr_v0_4):
1897        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1898    elif isinstance(p, _ClipDescr_v0_4):
1899        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1900    elif isinstance(p, _SigmoidDescr_v0_4):
1901        return SigmoidDescr()
1902    elif isinstance(p, _ScaleLinearDescr_v0_4):
1903        axes = _axes_letters_to_ids(p.kwargs.axes)
1904        if p.kwargs.axes is None:
1905            axis = None
1906        else:
1907            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1908
1909        if axis is None:
1910            assert not isinstance(p.kwargs.gain, list)
1911            assert not isinstance(p.kwargs.offset, list)
1912            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1913        else:
1914            kwargs = ScaleLinearAlongAxisKwargs(
1915                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1916            )
1917        return ScaleLinearDescr(kwargs=kwargs)
1918    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1919        return ScaleMeanVarianceDescr(
1920            kwargs=ScaleMeanVarianceKwargs(
1921                axes=_axes_letters_to_ids(p.kwargs.axes),
1922                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1923                eps=p.kwargs.eps,
1924            )
1925        )
1926    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1927        if p.kwargs.mode == "fixed":
1928            mean = p.kwargs.mean
1929            std = p.kwargs.std
1930            assert mean is not None
1931            assert std is not None
1932
1933            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1934
1935            if axis is None:
1936                if isinstance(mean, list):
1937                    raise ValueError("Expected single float value for mean, not <list>")
1938                if isinstance(std, list):
1939                    raise ValueError("Expected single float value for std, not <list>")
1940                return FixedZeroMeanUnitVarianceDescr(
1941                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
1942                        mean=mean,
1943                        std=std,
1944                    )
1945                )
1946            else:
1947                if not isinstance(mean, list):
1948                    mean = [float(mean)]
1949                if not isinstance(std, list):
1950                    std = [float(std)]
1951
1952                return FixedZeroMeanUnitVarianceDescr(
1953                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1954                        axis=axis, mean=mean, std=std
1955                    )
1956                )
1957
1958        else:
1959            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1960            if p.kwargs.mode == "per_dataset":
1961                axes = [AxisId("batch")] + axes
1962            if not axes:
1963                axes = None
1964            return ZeroMeanUnitVarianceDescr(
1965                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1966            )
1967
1968    elif isinstance(p, _ScaleRangeDescr_v0_4):
1969        return ScaleRangeDescr(
1970            kwargs=ScaleRangeKwargs(
1971                axes=_axes_letters_to_ids(p.kwargs.axes),
1972                min_percentile=p.kwargs.min_percentile,
1973                max_percentile=p.kwargs.max_percentile,
1974                eps=p.kwargs.eps,
1975            )
1976        )
1977    else:
1978        assert_never(p)
1979
1980
1981class _InputTensorConv(
1982    Converter[
1983        _InputTensorDescr_v0_4,
1984        InputTensorDescr,
1985        FileSource_,
1986        Optional[FileSource_],
1987        Mapping[_TensorName_v0_4, Mapping[str, int]],
1988    ]
1989):
1990    def _convert(
1991        self,
1992        src: _InputTensorDescr_v0_4,
1993        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1994        test_tensor: FileSource_,
1995        sample_tensor: Optional[FileSource_],
1996        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1997    ) -> "InputTensorDescr | dict[str, Any]":
1998        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1999            src.axes,
2000            shape=src.shape,
2001            tensor_type="input",
2002            halo=None,
2003            size_refs=size_refs,
2004        )
2005        prep: List[PreprocessingDescr] = []
2006        for p in src.preprocessing:
2007            cp = _convert_proc(p, src.axes)
2008            assert not isinstance(cp, ScaleMeanVarianceDescr)
2009            prep.append(cp)
2010
2011        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
2012
2013        return tgt(
2014            axes=axes,
2015            id=TensorId(str(src.name)),
2016            test_tensor=FileDescr(source=test_tensor),
2017            sample_tensor=(
2018                None if sample_tensor is None else FileDescr(source=sample_tensor)
2019            ),
2020            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
2021            preprocessing=prep,
2022        )
2023
2024
2025_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
2026
2027
2028class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2029    id: TensorId = TensorId("output")
2030    """Output tensor id.
2031    No duplicates are allowed across all inputs and outputs."""
2032
2033    postprocessing: List[PostprocessingDescr] = Field(
2034        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2035    )
2036    """Description of how this output should be postprocessed.
2037
2038    note: `postprocessing` always ends with an 'ensure_dtype' operation.
2039          If not given this is added to cast to this tensor's `data.type`.
2040    """
2041
2042    @model_validator(mode="after")
2043    def _validate_postprocessing_kwargs(self) -> Self:
2044        axes_ids = [a.id for a in self.axes]
2045        for p in self.postprocessing:
2046            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2047            if kwargs_axes is None:
2048                continue
2049
2050            if not isinstance(kwargs_axes, collections.abc.Sequence):
2051                raise ValueError(
2052                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2053                )
2054
2055            if any(a not in axes_ids for a in kwargs_axes):
2056                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2057
2058        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2059            dtype = self.data.type
2060        else:
2061            dtype = self.data[0].type
2062
2063        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2064        if not self.postprocessing or not isinstance(
2065            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2066        ):
2067            self.postprocessing.append(
2068                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2069            )
2070        return self
2071
2072
2073class _OutputTensorConv(
2074    Converter[
2075        _OutputTensorDescr_v0_4,
2076        OutputTensorDescr,
2077        FileSource_,
2078        Optional[FileSource_],
2079        Mapping[_TensorName_v0_4, Mapping[str, int]],
2080    ]
2081):
2082    def _convert(
2083        self,
2084        src: _OutputTensorDescr_v0_4,
2085        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2086        test_tensor: FileSource_,
2087        sample_tensor: Optional[FileSource_],
2088        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2089    ) -> "OutputTensorDescr | dict[str, Any]":
2090        # TODO: split convert_axes into convert_output_axes and convert_input_axes
2091        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
2092            src.axes,
2093            shape=src.shape,
2094            tensor_type="output",
2095            halo=src.halo,
2096            size_refs=size_refs,
2097        )
2098        data_descr: Dict[str, Any] = dict(type=src.data_type)
2099        if data_descr["type"] == "bool":
2100            data_descr["values"] = [False, True]
2101
2102        return tgt(
2103            axes=axes,
2104            id=TensorId(str(src.name)),
2105            test_tensor=FileDescr(source=test_tensor),
2106            sample_tensor=(
2107                None if sample_tensor is None else FileDescr(source=sample_tensor)
2108            ),
2109            data=data_descr,  # pyright: ignore[reportArgumentType]
2110            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2111        )
2112
2113
2114_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2115
2116
2117TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2118
2119
2120def validate_tensors(
2121    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2122    tensor_origin: Literal[
2123        "test_tensor"
2124    ],  # for more precise error messages, e.g. 'test_tensor'
2125):
2126    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2127
2128    def e_msg(d: TensorDescr):
2129        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2130
2131    for descr, array in tensors.values():
2132        if array is None:
2133            axis_sizes = {a.id: None for a in descr.axes}
2134        else:
2135            try:
2136                axis_sizes = descr.get_axis_sizes_for_array(array)
2137            except ValueError as e:
2138                raise ValueError(f"{e_msg(descr)} {e}")
2139
2140        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2141
2142    for descr, array in tensors.values():
2143        if array is None:
2144            continue
2145
2146        if descr.dtype in ("float32", "float64"):
2147            invalid_test_tensor_dtype = array.dtype.name not in (
2148                "float32",
2149                "float64",
2150                "uint8",
2151                "int8",
2152                "uint16",
2153                "int16",
2154                "uint32",
2155                "int32",
2156                "uint64",
2157                "int64",
2158            )
2159        else:
2160            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2161
2162        if invalid_test_tensor_dtype:
2163            raise ValueError(
2164                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2165                + f" match described dtype '{descr.dtype}'"
2166            )
2167
2168        if array.min() > -1e-4 and array.max() < 1e-4:
2169            raise ValueError(
2170                "Output values are too small for reliable testing."
2171                + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2172            )
2173
2174        for a in descr.axes:
2175            actual_size = all_tensor_axes[descr.id][a.id][1]
2176            if actual_size is None:
2177                continue
2178
2179            if a.size is None:
2180                continue
2181
2182            if isinstance(a.size, int):
2183                if actual_size != a.size:
2184                    raise ValueError(
2185                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2186                        + f"has incompatible size {actual_size}, expected {a.size}"
2187                    )
2188            elif isinstance(a.size, ParameterizedSize):
2189                _ = a.size.validate_size(actual_size)
2190            elif isinstance(a.size, DataDependentSize):
2191                _ = a.size.validate_size(actual_size)
2192            elif isinstance(a.size, SizeReference):
2193                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2194                if ref_tensor_axes is None:
2195                    raise ValueError(
2196                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2197                        + f" reference '{a.size.tensor_id}'"
2198                    )
2199
2200                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2201                if ref_axis is None or ref_size is None:
2202                    raise ValueError(
2203                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2204                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2205                    )
2206
2207                if a.unit != ref_axis.unit:
2208                    raise ValueError(
2209                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2210                        + " axis and reference axis to have the same `unit`, but"
2211                        + f" {a.unit}!={ref_axis.unit}"
2212                    )
2213
2214                if actual_size != (
2215                    expected_size := (
2216                        ref_size * ref_axis.scale / a.scale + a.size.offset
2217                    )
2218                ):
2219                    raise ValueError(
2220                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2221                        + f" {actual_size} invalid for referenced size {ref_size};"
2222                        + f" expected {expected_size}"
2223                    )
2224            else:
2225                assert_never(a.size)
2226
2227
2228FileDescr_dependencies = Annotated[
2229    FileDescr_,
2230    WithSuffix((".yaml", ".yml"), case_sensitive=True),
2231    Field(examples=[dict(source="environment.yaml")]),
2232]
2233
2234
2235class _ArchitectureCallableDescr(Node):
2236    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2237    """Identifier of the callable that returns a torch.nn.Module instance."""
2238
2239    kwargs: Dict[str, YamlValue] = Field(
2240        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2241    )
2242    """key word arguments for the `callable`"""
2243
2244
2245class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2246    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2247    """Architecture source file"""
2248
2249    @model_serializer(mode="wrap", when_used="unless-none")
2250    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2251        return package_file_descr_serializer(self, nxt, info)
2252
2253
2254class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2255    import_from: str
2256    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2257
2258
2259class _ArchFileConv(
2260    Converter[
2261        _CallableFromFile_v0_4,
2262        ArchitectureFromFileDescr,
2263        Optional[Sha256],
2264        Dict[str, Any],
2265    ]
2266):
2267    def _convert(
2268        self,
2269        src: _CallableFromFile_v0_4,
2270        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2271        sha256: Optional[Sha256],
2272        kwargs: Dict[str, Any],
2273    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2274        if src.startswith("http") and src.count(":") == 2:
2275            http, source, callable_ = src.split(":")
2276            source = ":".join((http, source))
2277        elif not src.startswith("http") and src.count(":") == 1:
2278            source, callable_ = src.split(":")
2279        else:
2280            source = str(src)
2281            callable_ = str(src)
2282        return tgt(
2283            callable=Identifier(callable_),
2284            source=cast(FileSource_, source),
2285            sha256=sha256,
2286            kwargs=kwargs,
2287        )
2288
2289
2290_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2291
2292
2293class _ArchLibConv(
2294    Converter[
2295        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2296    ]
2297):
2298    def _convert(
2299        self,
2300        src: _CallableFromDepencency_v0_4,
2301        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2302        kwargs: Dict[str, Any],
2303    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2304        *mods, callable_ = src.split(".")
2305        import_from = ".".join(mods)
2306        return tgt(
2307            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2308        )
2309
2310
2311_arch_lib_conv = _ArchLibConv(
2312    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2313)
2314
2315
2316class WeightsEntryDescrBase(FileDescr):
2317    type: ClassVar[WeightsFormat]
2318    weights_format_name: ClassVar[str]  # human readable
2319
2320    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2321    """Source of the weights file."""
2322
2323    authors: Optional[List[Author]] = None
2324    """Authors
2325    Either the person(s) that have trained this model resulting in the original weights file.
2326        (If this is the initial weights entry, i.e. it does not have a `parent`)
2327    Or the person(s) who have converted the weights to this weights format.
2328        (If this is a child weight, i.e. it has a `parent` field)
2329    """
2330
2331    parent: Annotated[
2332        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2333    ] = None
2334    """The source weights these weights were converted from.
2335    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2336    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2337    All weight entries except one (the initial set of weights resulting from training the model),
2338    need to have this field."""
2339
2340    comment: str = ""
2341    """A comment about this weights entry, for example how these weights were created."""
2342
2343    @model_validator(mode="after")
2344    def _validate(self) -> Self:
2345        if self.type == self.parent:
2346            raise ValueError("Weights entry can't be it's own parent.")
2347
2348        return self
2349
2350    @model_serializer(mode="wrap", when_used="unless-none")
2351    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2352        return package_file_descr_serializer(self, nxt, info)
2353
2354
2355class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2356    type = "keras_hdf5"
2357    weights_format_name: ClassVar[str] = "Keras HDF5"
2358    tensorflow_version: Version
2359    """TensorFlow version used to create these weights."""
2360
2361
2362class OnnxWeightsDescr(WeightsEntryDescrBase):
2363    type = "onnx"
2364    weights_format_name: ClassVar[str] = "ONNX"
2365    opset_version: Annotated[int, Ge(7)]
2366    """ONNX opset version"""
2367
2368
2369class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2370    type = "pytorch_state_dict"
2371    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2372    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2373    pytorch_version: Version
2374    """Version of the PyTorch library used.
2375    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
2376    """
2377    dependencies: Optional[FileDescr_dependencies] = None
2378    """Custom depencies beyond pytorch described in a Conda environment file.
2379    Allows to specify custom dependencies, see conda docs:
2380    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2381    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2382
2383    The conda environment file should include pytorch and any version pinning has to be compatible with
2384    **pytorch_version**.
2385    """
2386
2387
2388class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2389    type = "tensorflow_js"
2390    weights_format_name: ClassVar[str] = "Tensorflow.js"
2391    tensorflow_version: Version
2392    """Version of the TensorFlow library used."""
2393
2394    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2395    """The multi-file weights.
2396    All required files/folders should be a zip archive."""
2397
2398
2399class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2400    type = "tensorflow_saved_model_bundle"
2401    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2402    tensorflow_version: Version
2403    """Version of the TensorFlow library used."""
2404
2405    dependencies: Optional[FileDescr_dependencies] = None
2406    """Custom dependencies beyond tensorflow.
2407    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2408
2409    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2410    """The multi-file weights.
2411    All required files/folders should be a zip archive."""
2412
2413
2414class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2415    type = "torchscript"
2416    weights_format_name: ClassVar[str] = "TorchScript"
2417    pytorch_version: Version
2418    """Version of the PyTorch library used."""
2419
2420
2421class WeightsDescr(Node):
2422    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2423    onnx: Optional[OnnxWeightsDescr] = None
2424    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2425    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2426    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2427        None
2428    )
2429    torchscript: Optional[TorchscriptWeightsDescr] = None
2430
2431    @model_validator(mode="after")
2432    def check_entries(self) -> Self:
2433        entries = {wtype for wtype, entry in self if entry is not None}
2434
2435        if not entries:
2436            raise ValueError("Missing weights entry")
2437
2438        entries_wo_parent = {
2439            wtype
2440            for wtype, entry in self
2441            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2442        }
2443        if len(entries_wo_parent) != 1:
2444            issue_warning(
2445                "Exactly one weights entry may not specify the `parent` field (got"
2446                + " {value}). That entry is considered the original set of model weights."
2447                + " Other weight formats are created through conversion of the orignal or"
2448                + " already converted weights. They have to reference the weights format"
2449                + " they were converted from as their `parent`.",
2450                value=len(entries_wo_parent),
2451                field="weights",
2452            )
2453
2454        for wtype, entry in self:
2455            if entry is None:
2456                continue
2457
2458            assert hasattr(entry, "type")
2459            assert hasattr(entry, "parent")
2460            assert wtype == entry.type
2461            if (
2462                entry.parent is not None and entry.parent not in entries
2463            ):  # self reference checked for `parent` field
2464                raise ValueError(
2465                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2466                    + f" formats: {entries}"
2467                )
2468
2469        return self
2470
2471    def __getitem__(
2472        self,
2473        key: Literal[
2474            "keras_hdf5",
2475            "onnx",
2476            "pytorch_state_dict",
2477            "tensorflow_js",
2478            "tensorflow_saved_model_bundle",
2479            "torchscript",
2480        ],
2481    ):
2482        if key == "keras_hdf5":
2483            ret = self.keras_hdf5
2484        elif key == "onnx":
2485            ret = self.onnx
2486        elif key == "pytorch_state_dict":
2487            ret = self.pytorch_state_dict
2488        elif key == "tensorflow_js":
2489            ret = self.tensorflow_js
2490        elif key == "tensorflow_saved_model_bundle":
2491            ret = self.tensorflow_saved_model_bundle
2492        elif key == "torchscript":
2493            ret = self.torchscript
2494        else:
2495            raise KeyError(key)
2496
2497        if ret is None:
2498            raise KeyError(key)
2499
2500        return ret
2501
2502    @property
2503    def available_formats(self):
2504        return {
2505            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2506            **({} if self.onnx is None else {"onnx": self.onnx}),
2507            **(
2508                {}
2509                if self.pytorch_state_dict is None
2510                else {"pytorch_state_dict": self.pytorch_state_dict}
2511            ),
2512            **(
2513                {}
2514                if self.tensorflow_js is None
2515                else {"tensorflow_js": self.tensorflow_js}
2516            ),
2517            **(
2518                {}
2519                if self.tensorflow_saved_model_bundle is None
2520                else {
2521                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2522                }
2523            ),
2524            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2525        }
2526
2527    @property
2528    def missing_formats(self):
2529        return {
2530            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2531        }
2532
2533
2534class ModelId(ResourceId):
2535    pass
2536
2537
2538class LinkedModel(LinkedResourceBase):
2539    """Reference to a bioimage.io model."""
2540
2541    id: ModelId
2542    """A valid model `id` from the bioimage.io collection."""
2543
2544
2545class _DataDepSize(NamedTuple):
2546    min: StrictInt
2547    max: Optional[StrictInt]
2548
2549
2550class _AxisSizes(NamedTuple):
2551    """the lenghts of all axes of model inputs and outputs"""
2552
2553    inputs: Dict[Tuple[TensorId, AxisId], int]
2554    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2555
2556
2557class _TensorSizes(NamedTuple):
2558    """_AxisSizes as nested dicts"""
2559
2560    inputs: Dict[TensorId, Dict[AxisId, int]]
2561    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2562
2563
2564class ReproducibilityTolerance(Node, extra="allow"):
2565    """Describes what small numerical differences -- if any -- may be tolerated
2566    in the generated output when executing in different environments.
2567
2568    A tensor element *output* is considered mismatched to the **test_tensor** if
2569    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2570    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2571
2572    Motivation:
2573        For testing we can request the respective deep learning frameworks to be as
2574        reproducible as possible by setting seeds and chosing deterministic algorithms,
2575        but differences in operating systems, available hardware and installed drivers
2576        may still lead to numerical differences.
2577    """
2578
2579    relative_tolerance: RelativeTolerance = 1e-3
2580    """Maximum relative tolerance of reproduced test tensor."""
2581
2582    absolute_tolerance: AbsoluteTolerance = 1e-4
2583    """Maximum absolute tolerance of reproduced test tensor."""
2584
2585    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2586    """Maximum number of mismatched elements/pixels per million to tolerate."""
2587
2588    output_ids: Sequence[TensorId] = ()
2589    """Limits the output tensor IDs these reproducibility details apply to."""
2590
2591    weights_formats: Sequence[WeightsFormat] = ()
2592    """Limits the weights formats these details apply to."""
2593
2594
2595class BioimageioConfig(Node, extra="allow"):
2596    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2597    """Tolerances to allow when reproducing the model's test outputs
2598    from the model's test inputs.
2599    Only the first entry matching tensor id and weights format is considered.
2600    """
2601
2602
2603class Config(Node, extra="allow"):
2604    bioimageio: BioimageioConfig = Field(
2605        default_factory=BioimageioConfig.model_construct
2606    )
2607
2608
2609class ModelDescr(GenericModelDescrBase):
2610    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2611    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2612    """
2613
2614    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2615    if TYPE_CHECKING:
2616        format_version: Literal["0.5.5"] = "0.5.5"
2617    else:
2618        format_version: Literal["0.5.5"]
2619        """Version of the bioimage.io model description specification used.
2620        When creating a new model always use the latest micro/patch version described here.
2621        The `format_version` is important for any consumer software to understand how to parse the fields.
2622        """
2623
2624    implemented_type: ClassVar[Literal["model"]] = "model"
2625    if TYPE_CHECKING:
2626        type: Literal["model"] = "model"
2627    else:
2628        type: Literal["model"]
2629        """Specialized resource type 'model'"""
2630
2631    id: Optional[ModelId] = None
2632    """bioimage.io-wide unique resource identifier
2633    assigned by bioimage.io; version **un**specific."""
2634
2635    authors: FAIR[List[Author]] = Field(
2636        default_factory=cast(Callable[[], List[Author]], list)
2637    )
2638    """The authors are the creators of the model RDF and the primary points of contact."""
2639
2640    documentation: FAIR[Optional[FileSource_documentation]] = None
2641    """URL or relative path to a markdown file with additional documentation.
2642    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2643    The documentation should include a '#[#] Validation' (sub)section
2644    with details on how to quantitatively validate the model on unseen data."""
2645
2646    @field_validator("documentation", mode="after")
2647    @classmethod
2648    def _validate_documentation(
2649        cls, value: Optional[FileSource_documentation]
2650    ) -> Optional[FileSource_documentation]:
2651        if not get_validation_context().perform_io_checks or value is None:
2652            return value
2653
2654        doc_reader = get_reader(value)
2655        doc_content = doc_reader.read().decode(encoding="utf-8")
2656        if not re.search("#.*[vV]alidation", doc_content):
2657            issue_warning(
2658                "No '# Validation' (sub)section found in {value}.",
2659                value=value,
2660                field="documentation",
2661            )
2662
2663        return value
2664
2665    inputs: NotEmpty[Sequence[InputTensorDescr]]
2666    """Describes the input tensors expected by this model."""
2667
2668    @field_validator("inputs", mode="after")
2669    @classmethod
2670    def _validate_input_axes(
2671        cls, inputs: Sequence[InputTensorDescr]
2672    ) -> Sequence[InputTensorDescr]:
2673        input_size_refs = cls._get_axes_with_independent_size(inputs)
2674
2675        for i, ipt in enumerate(inputs):
2676            valid_independent_refs: Dict[
2677                Tuple[TensorId, AxisId],
2678                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2679            ] = {
2680                **{
2681                    (ipt.id, a.id): (ipt, a, a.size)
2682                    for a in ipt.axes
2683                    if not isinstance(a, BatchAxis)
2684                    and isinstance(a.size, (int, ParameterizedSize))
2685                },
2686                **input_size_refs,
2687            }
2688            for a, ax in enumerate(ipt.axes):
2689                cls._validate_axis(
2690                    "inputs",
2691                    i=i,
2692                    tensor_id=ipt.id,
2693                    a=a,
2694                    axis=ax,
2695                    valid_independent_refs=valid_independent_refs,
2696                )
2697        return inputs
2698
2699    @staticmethod
2700    def _validate_axis(
2701        field_name: str,
2702        i: int,
2703        tensor_id: TensorId,
2704        a: int,
2705        axis: AnyAxis,
2706        valid_independent_refs: Dict[
2707            Tuple[TensorId, AxisId],
2708            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2709        ],
2710    ):
2711        if isinstance(axis, BatchAxis) or isinstance(
2712            axis.size, (int, ParameterizedSize, DataDependentSize)
2713        ):
2714            return
2715        elif not isinstance(axis.size, SizeReference):
2716            assert_never(axis.size)
2717
2718        # validate axis.size SizeReference
2719        ref = (axis.size.tensor_id, axis.size.axis_id)
2720        if ref not in valid_independent_refs:
2721            raise ValueError(
2722                "Invalid tensor axis reference at"
2723                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2724            )
2725        if ref == (tensor_id, axis.id):
2726            raise ValueError(
2727                "Self-referencing not allowed for"
2728                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2729            )
2730        if axis.type == "channel":
2731            if valid_independent_refs[ref][1].type != "channel":
2732                raise ValueError(
2733                    "A channel axis' size may only reference another fixed size"
2734                    + " channel axis."
2735                )
2736            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2737                ref_size = valid_independent_refs[ref][2]
2738                assert isinstance(ref_size, int), (
2739                    "channel axis ref (another channel axis) has to specify fixed"
2740                    + " size"
2741                )
2742                generated_channel_names = [
2743                    Identifier(axis.channel_names.format(i=i))
2744                    for i in range(1, ref_size + 1)
2745                ]
2746                axis.channel_names = generated_channel_names
2747
2748        if (ax_unit := getattr(axis, "unit", None)) != (
2749            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2750        ):
2751            raise ValueError(
2752                "The units of an axis and its reference axis need to match, but"
2753                + f" '{ax_unit}' != '{ref_unit}'."
2754            )
2755        ref_axis = valid_independent_refs[ref][1]
2756        if isinstance(ref_axis, BatchAxis):
2757            raise ValueError(
2758                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2759                + " (a batch axis is not allowed as reference)."
2760            )
2761
2762        if isinstance(axis, WithHalo):
2763            min_size = axis.size.get_size(axis, ref_axis, n=0)
2764            if (min_size - 2 * axis.halo) < 1:
2765                raise ValueError(
2766                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2767                    + f" {axis.halo}."
2768                )
2769
2770            input_halo = axis.halo * axis.scale / ref_axis.scale
2771            if input_halo != int(input_halo) or input_halo % 2 == 1:
2772                raise ValueError(
2773                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2774                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2775                    + f"     {tensor_id}.{axis.id}."
2776                )
2777
2778    @model_validator(mode="after")
2779    def _validate_test_tensors(self) -> Self:
2780        if not get_validation_context().perform_io_checks:
2781            return self
2782
2783        test_output_arrays = [
2784            None if descr.test_tensor is None else load_array(descr.test_tensor)
2785            for descr in self.outputs
2786        ]
2787        test_input_arrays = [
2788            None if descr.test_tensor is None else load_array(descr.test_tensor)
2789            for descr in self.inputs
2790        ]
2791
2792        tensors = {
2793            descr.id: (descr, array)
2794            for descr, array in zip(
2795                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2796            )
2797        }
2798        validate_tensors(tensors, tensor_origin="test_tensor")
2799
2800        output_arrays = {
2801            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2802        }
2803        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2804            if not rep_tol.absolute_tolerance:
2805                continue
2806
2807            if rep_tol.output_ids:
2808                out_arrays = {
2809                    oid: a
2810                    for oid, a in output_arrays.items()
2811                    if oid in rep_tol.output_ids
2812                }
2813            else:
2814                out_arrays = output_arrays
2815
2816            for out_id, array in out_arrays.items():
2817                if array is None:
2818                    continue
2819
2820                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2821                    raise ValueError(
2822                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2823                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2824                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2825                    )
2826
2827        return self
2828
2829    @model_validator(mode="after")
2830    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2831        ipt_refs = {t.id for t in self.inputs}
2832        out_refs = {t.id for t in self.outputs}
2833        for ipt in self.inputs:
2834            for p in ipt.preprocessing:
2835                ref = p.kwargs.get("reference_tensor")
2836                if ref is None:
2837                    continue
2838                if ref not in ipt_refs:
2839                    raise ValueError(
2840                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2841                        + f" references are: {ipt_refs}."
2842                    )
2843
2844        for out in self.outputs:
2845            for p in out.postprocessing:
2846                ref = p.kwargs.get("reference_tensor")
2847                if ref is None:
2848                    continue
2849
2850                if ref not in ipt_refs and ref not in out_refs:
2851                    raise ValueError(
2852                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2853                        + f" are: {ipt_refs | out_refs}."
2854                    )
2855
2856        return self
2857
2858    # TODO: use validate funcs in validate_test_tensors
2859    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2860
2861    name: Annotated[
2862        str,
2863        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2864        MinLen(5),
2865        MaxLen(128),
2866        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2867    ]
2868    """A human-readable name of this model.
2869    It should be no longer than 64 characters
2870    and may only contain letter, number, underscore, minus, parentheses and spaces.
2871    We recommend to chose a name that refers to the model's task and image modality.
2872    """
2873
2874    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2875    """Describes the output tensors."""
2876
2877    @field_validator("outputs", mode="after")
2878    @classmethod
2879    def _validate_tensor_ids(
2880        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2881    ) -> Sequence[OutputTensorDescr]:
2882        tensor_ids = [
2883            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2884        ]
2885        duplicate_tensor_ids: List[str] = []
2886        seen: Set[str] = set()
2887        for t in tensor_ids:
2888            if t in seen:
2889                duplicate_tensor_ids.append(t)
2890
2891            seen.add(t)
2892
2893        if duplicate_tensor_ids:
2894            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2895
2896        return outputs
2897
2898    @staticmethod
2899    def _get_axes_with_parameterized_size(
2900        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2901    ):
2902        return {
2903            f"{t.id}.{a.id}": (t, a, a.size)
2904            for t in io
2905            for a in t.axes
2906            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2907        }
2908
2909    @staticmethod
2910    def _get_axes_with_independent_size(
2911        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2912    ):
2913        return {
2914            (t.id, a.id): (t, a, a.size)
2915            for t in io
2916            for a in t.axes
2917            if not isinstance(a, BatchAxis)
2918            and isinstance(a.size, (int, ParameterizedSize))
2919        }
2920
2921    @field_validator("outputs", mode="after")
2922    @classmethod
2923    def _validate_output_axes(
2924        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2925    ) -> List[OutputTensorDescr]:
2926        input_size_refs = cls._get_axes_with_independent_size(
2927            info.data.get("inputs", [])
2928        )
2929        output_size_refs = cls._get_axes_with_independent_size(outputs)
2930
2931        for i, out in enumerate(outputs):
2932            valid_independent_refs: Dict[
2933                Tuple[TensorId, AxisId],
2934                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2935            ] = {
2936                **{
2937                    (out.id, a.id): (out, a, a.size)
2938                    for a in out.axes
2939                    if not isinstance(a, BatchAxis)
2940                    and isinstance(a.size, (int, ParameterizedSize))
2941                },
2942                **input_size_refs,
2943                **output_size_refs,
2944            }
2945            for a, ax in enumerate(out.axes):
2946                cls._validate_axis(
2947                    "outputs",
2948                    i,
2949                    out.id,
2950                    a,
2951                    ax,
2952                    valid_independent_refs=valid_independent_refs,
2953                )
2954
2955        return outputs
2956
2957    packaged_by: List[Author] = Field(
2958        default_factory=cast(Callable[[], List[Author]], list)
2959    )
2960    """The persons that have packaged and uploaded this model.
2961    Only required if those persons differ from the `authors`."""
2962
2963    parent: Optional[LinkedModel] = None
2964    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2965
2966    @model_validator(mode="after")
2967    def _validate_parent_is_not_self(self) -> Self:
2968        if self.parent is not None and self.parent.id == self.id:
2969            raise ValueError("A model description may not reference itself as parent.")
2970
2971        return self
2972
2973    run_mode: Annotated[
2974        Optional[RunMode],
2975        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2976    ] = None
2977    """Custom run mode for this model: for more complex prediction procedures like test time
2978    data augmentation that currently cannot be expressed in the specification.
2979    No standard run modes are defined yet."""
2980
2981    timestamp: Datetime = Field(default_factory=Datetime.now)
2982    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2983    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2984    (In Python a datetime object is valid, too)."""
2985
2986    training_data: Annotated[
2987        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2988        Field(union_mode="left_to_right"),
2989    ] = None
2990    """The dataset used to train this model"""
2991
2992    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2993    """The weights for this model.
2994    Weights can be given for different formats, but should otherwise be equivalent.
2995    The available weight formats determine which consumers can use this model."""
2996
2997    config: Config = Field(default_factory=Config.model_construct)
2998
2999    @model_validator(mode="after")
3000    def _add_default_cover(self) -> Self:
3001        if not get_validation_context().perform_io_checks or self.covers:
3002            return self
3003
3004        try:
3005            generated_covers = generate_covers(
3006                [
3007                    (t, load_array(t.test_tensor))
3008                    for t in self.inputs
3009                    if t.test_tensor is not None
3010                ],
3011                [
3012                    (t, load_array(t.test_tensor))
3013                    for t in self.outputs
3014                    if t.test_tensor is not None
3015                ],
3016            )
3017        except Exception as e:
3018            issue_warning(
3019                "Failed to generate cover image(s): {e}",
3020                value=self.covers,
3021                msg_context=dict(e=e),
3022                field="covers",
3023            )
3024        else:
3025            self.covers.extend(generated_covers)
3026
3027        return self
3028
3029    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3030        return self._get_test_arrays(self.inputs)
3031
3032    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3033        return self._get_test_arrays(self.outputs)
3034
3035    @staticmethod
3036    def _get_test_arrays(
3037        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3038    ):
3039        ts: List[FileDescr] = []
3040        for d in io_descr:
3041            if d.test_tensor is None:
3042                raise ValueError(
3043                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3044                )
3045            ts.append(d.test_tensor)
3046
3047        data = [load_array(t) for t in ts]
3048        assert all(isinstance(d, np.ndarray) for d in data)
3049        return data
3050
3051    @staticmethod
3052    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3053        batch_size = 1
3054        tensor_with_batchsize: Optional[TensorId] = None
3055        for tid in tensor_sizes:
3056            for aid, s in tensor_sizes[tid].items():
3057                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3058                    continue
3059
3060                if batch_size != 1:
3061                    assert tensor_with_batchsize is not None
3062                    raise ValueError(
3063                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3064                    )
3065
3066                batch_size = s
3067                tensor_with_batchsize = tid
3068
3069        return batch_size
3070
3071    def get_output_tensor_sizes(
3072        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3073    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3074        """Returns the tensor output sizes for given **input_sizes**.
3075        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
3076        Otherwise it might be larger than the actual (valid) output"""
3077        batch_size = self.get_batch_size(input_sizes)
3078        ns = self.get_ns(input_sizes)
3079
3080        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3081        return tensor_sizes.outputs
3082
3083    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3084        """get parameter `n` for each parameterized axis
3085        such that the valid input size is >= the given input size"""
3086        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3087        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3088        for tid in input_sizes:
3089            for aid, s in input_sizes[tid].items():
3090                size_descr = axes[tid][aid].size
3091                if isinstance(size_descr, ParameterizedSize):
3092                    ret[(tid, aid)] = size_descr.get_n(s)
3093                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3094                    pass
3095                else:
3096                    assert_never(size_descr)
3097
3098        return ret
3099
3100    def get_tensor_sizes(
3101        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3102    ) -> _TensorSizes:
3103        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3104        return _TensorSizes(
3105            {
3106                t: {
3107                    aa: axis_sizes.inputs[(tt, aa)]
3108                    for tt, aa in axis_sizes.inputs
3109                    if tt == t
3110                }
3111                for t in {tt for tt, _ in axis_sizes.inputs}
3112            },
3113            {
3114                t: {
3115                    aa: axis_sizes.outputs[(tt, aa)]
3116                    for tt, aa in axis_sizes.outputs
3117                    if tt == t
3118                }
3119                for t in {tt for tt, _ in axis_sizes.outputs}
3120            },
3121        )
3122
3123    def get_axis_sizes(
3124        self,
3125        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3126        batch_size: Optional[int] = None,
3127        *,
3128        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3129    ) -> _AxisSizes:
3130        """Determine input and output block shape for scale factors **ns**
3131        of parameterized input sizes.
3132
3133        Args:
3134            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3135                that is parameterized as `size = min + n * step`.
3136            batch_size: The desired size of the batch dimension.
3137                If given **batch_size** overwrites any batch size present in
3138                **max_input_shape**. Default 1.
3139            max_input_shape: Limits the derived block shapes.
3140                Each axis for which the input size, parameterized by `n`, is larger
3141                than **max_input_shape** is set to the minimal value `n_min` for which
3142                this is still true.
3143                Use this for small input samples or large values of **ns**.
3144                Or simply whenever you know the full input shape.
3145
3146        Returns:
3147            Resolved axis sizes for model inputs and outputs.
3148        """
3149        max_input_shape = max_input_shape or {}
3150        if batch_size is None:
3151            for (_t_id, a_id), s in max_input_shape.items():
3152                if a_id == BATCH_AXIS_ID:
3153                    batch_size = s
3154                    break
3155            else:
3156                batch_size = 1
3157
3158        all_axes = {
3159            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3160        }
3161
3162        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3163        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3164
3165        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3166            if isinstance(a, BatchAxis):
3167                if (t_descr.id, a.id) in ns:
3168                    logger.warning(
3169                        "Ignoring unexpected size increment factor (n) for batch axis"
3170                        + " of tensor '{}'.",
3171                        t_descr.id,
3172                    )
3173                return batch_size
3174            elif isinstance(a.size, int):
3175                if (t_descr.id, a.id) in ns:
3176                    logger.warning(
3177                        "Ignoring unexpected size increment factor (n) for fixed size"
3178                        + " axis '{}' of tensor '{}'.",
3179                        a.id,
3180                        t_descr.id,
3181                    )
3182                return a.size
3183            elif isinstance(a.size, ParameterizedSize):
3184                if (t_descr.id, a.id) not in ns:
3185                    raise ValueError(
3186                        "Size increment factor (n) missing for parametrized axis"
3187                        + f" '{a.id}' of tensor '{t_descr.id}'."
3188                    )
3189                n = ns[(t_descr.id, a.id)]
3190                s_max = max_input_shape.get((t_descr.id, a.id))
3191                if s_max is not None:
3192                    n = min(n, a.size.get_n(s_max))
3193
3194                return a.size.get_size(n)
3195
3196            elif isinstance(a.size, SizeReference):
3197                if (t_descr.id, a.id) in ns:
3198                    logger.warning(
3199                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3200                        + " of tensor '{}' with size reference.",
3201                        a.id,
3202                        t_descr.id,
3203                    )
3204                assert not isinstance(a, BatchAxis)
3205                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3206                assert not isinstance(ref_axis, BatchAxis)
3207                ref_key = (a.size.tensor_id, a.size.axis_id)
3208                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3209                assert ref_size is not None, ref_key
3210                assert not isinstance(ref_size, _DataDepSize), ref_key
3211                return a.size.get_size(
3212                    axis=a,
3213                    ref_axis=ref_axis,
3214                    ref_size=ref_size,
3215                )
3216            elif isinstance(a.size, DataDependentSize):
3217                if (t_descr.id, a.id) in ns:
3218                    logger.warning(
3219                        "Ignoring unexpected increment factor (n) for data dependent"
3220                        + " size axis '{}' of tensor '{}'.",
3221                        a.id,
3222                        t_descr.id,
3223                    )
3224                return _DataDepSize(a.size.min, a.size.max)
3225            else:
3226                assert_never(a.size)
3227
3228        # first resolve all , but the `SizeReference` input sizes
3229        for t_descr in self.inputs:
3230            for a in t_descr.axes:
3231                if not isinstance(a.size, SizeReference):
3232                    s = get_axis_size(a)
3233                    assert not isinstance(s, _DataDepSize)
3234                    inputs[t_descr.id, a.id] = s
3235
3236        # resolve all other input axis sizes
3237        for t_descr in self.inputs:
3238            for a in t_descr.axes:
3239                if isinstance(a.size, SizeReference):
3240                    s = get_axis_size(a)
3241                    assert not isinstance(s, _DataDepSize)
3242                    inputs[t_descr.id, a.id] = s
3243
3244        # resolve all output axis sizes
3245        for t_descr in self.outputs:
3246            for a in t_descr.axes:
3247                assert not isinstance(a.size, ParameterizedSize)
3248                s = get_axis_size(a)
3249                outputs[t_descr.id, a.id] = s
3250
3251        return _AxisSizes(inputs=inputs, outputs=outputs)
3252
3253    @model_validator(mode="before")
3254    @classmethod
3255    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3256        cls.convert_from_old_format_wo_validation(data)
3257        return data
3258
3259    @classmethod
3260    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3261        """Convert metadata following an older format version to this classes' format
3262        without validating the result.
3263        """
3264        if (
3265            data.get("type") == "model"
3266            and isinstance(fv := data.get("format_version"), str)
3267            and fv.count(".") == 2
3268        ):
3269            fv_parts = fv.split(".")
3270            if any(not p.isdigit() for p in fv_parts):
3271                return
3272
3273            fv_tuple = tuple(map(int, fv_parts))
3274
3275            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3276            if fv_tuple[:2] in ((0, 3), (0, 4)):
3277                m04 = _ModelDescr_v0_4.load(data)
3278                if isinstance(m04, InvalidDescr):
3279                    try:
3280                        updated = _model_conv.convert_as_dict(
3281                            m04  # pyright: ignore[reportArgumentType]
3282                        )
3283                    except Exception as e:
3284                        logger.error(
3285                            "Failed to convert from invalid model 0.4 description."
3286                            + f"\nerror: {e}"
3287                            + "\nProceeding with model 0.5 validation without conversion."
3288                        )
3289                        updated = None
3290                else:
3291                    updated = _model_conv.convert_as_dict(m04)
3292
3293                if updated is not None:
3294                    data.clear()
3295                    data.update(updated)
3296
3297            elif fv_tuple[:2] == (0, 5):
3298                # bump patch version
3299                data["format_version"] = cls.implemented_format_version
3300
3301
3302class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3303    def _convert(
3304        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3305    ) -> "ModelDescr | dict[str, Any]":
3306        name = "".join(
3307            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3308            for c in src.name
3309        )
3310
3311        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3312            conv = (
3313                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3314            )
3315            return None if auths is None else [conv(a) for a in auths]
3316
3317        if TYPE_CHECKING:
3318            arch_file_conv = _arch_file_conv.convert
3319            arch_lib_conv = _arch_lib_conv.convert
3320        else:
3321            arch_file_conv = _arch_file_conv.convert_as_dict
3322            arch_lib_conv = _arch_lib_conv.convert_as_dict
3323
3324        input_size_refs = {
3325            ipt.name: {
3326                a: s
3327                for a, s in zip(
3328                    ipt.axes,
3329                    (
3330                        ipt.shape.min
3331                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3332                        else ipt.shape
3333                    ),
3334                )
3335            }
3336            for ipt in src.inputs
3337            if ipt.shape
3338        }
3339        output_size_refs = {
3340            **{
3341                out.name: {a: s for a, s in zip(out.axes, out.shape)}
3342                for out in src.outputs
3343                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3344            },
3345            **input_size_refs,
3346        }
3347
3348        return tgt(
3349            attachments=(
3350                []
3351                if src.attachments is None
3352                else [FileDescr(source=f) for f in src.attachments.files]
3353            ),
3354            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
3355            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
3356            config=src.config,  # pyright: ignore[reportArgumentType]
3357            covers=src.covers,
3358            description=src.description,
3359            documentation=src.documentation,
3360            format_version="0.5.5",
3361            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3362            icon=src.icon,
3363            id=None if src.id is None else ModelId(src.id),
3364            id_emoji=src.id_emoji,
3365            license=src.license,  # type: ignore
3366            links=src.links,
3367            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
3368            name=name,
3369            tags=src.tags,
3370            type=src.type,
3371            uploader=src.uploader,
3372            version=src.version,
3373            inputs=[  # pyright: ignore[reportArgumentType]
3374                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3375                for ipt, tt, st in zip(
3376                    src.inputs,
3377                    src.test_inputs,
3378                    src.sample_inputs or [None] * len(src.test_inputs),
3379                )
3380            ],
3381            outputs=[  # pyright: ignore[reportArgumentType]
3382                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3383                for out, tt, st in zip(
3384                    src.outputs,
3385                    src.test_outputs,
3386                    src.sample_outputs or [None] * len(src.test_outputs),
3387                )
3388            ],
3389            parent=(
3390                None
3391                if src.parent is None
3392                else LinkedModel(
3393                    id=ModelId(
3394                        str(src.parent.id)
3395                        + (
3396                            ""
3397                            if src.parent.version_number is None
3398                            else f"/{src.parent.version_number}"
3399                        )
3400                    )
3401                )
3402            ),
3403            training_data=(
3404                None
3405                if src.training_data is None
3406                else (
3407                    LinkedDataset(
3408                        id=DatasetId(
3409                            str(src.training_data.id)
3410                            + (
3411                                ""
3412                                if src.training_data.version_number is None
3413                                else f"/{src.training_data.version_number}"
3414                            )
3415                        )
3416                    )
3417                    if isinstance(src.training_data, LinkedDataset02)
3418                    else src.training_data
3419                )
3420            ),
3421            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
3422            run_mode=src.run_mode,
3423            timestamp=src.timestamp,
3424            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3425                keras_hdf5=(w := src.weights.keras_hdf5)
3426                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3427                    authors=conv_authors(w.authors),
3428                    source=w.source,
3429                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3430                    parent=w.parent,
3431                ),
3432                onnx=(w := src.weights.onnx)
3433                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3434                    source=w.source,
3435                    authors=conv_authors(w.authors),
3436                    parent=w.parent,
3437                    opset_version=w.opset_version or 15,
3438                ),
3439                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3440                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3441                    source=w.source,
3442                    authors=conv_authors(w.authors),
3443                    parent=w.parent,
3444                    architecture=(
3445                        arch_file_conv(
3446                            w.architecture,
3447                            w.architecture_sha256,
3448                            w.kwargs,
3449                        )
3450                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3451                        else arch_lib_conv(w.architecture, w.kwargs)
3452                    ),
3453                    pytorch_version=w.pytorch_version or Version("1.10"),
3454                    dependencies=(
3455                        None
3456                        if w.dependencies is None
3457                        else (FileDescr if TYPE_CHECKING else dict)(
3458                            source=cast(
3459                                FileSource,
3460                                str(deps := w.dependencies)[
3461                                    (
3462                                        len("conda:")
3463                                        if str(deps).startswith("conda:")
3464                                        else 0
3465                                    ) :
3466                                ],
3467                            )
3468                        )
3469                    ),
3470                ),
3471                tensorflow_js=(w := src.weights.tensorflow_js)
3472                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3473                    source=w.source,
3474                    authors=conv_authors(w.authors),
3475                    parent=w.parent,
3476                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3477                ),
3478                tensorflow_saved_model_bundle=(
3479                    w := src.weights.tensorflow_saved_model_bundle
3480                )
3481                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3482                    authors=conv_authors(w.authors),
3483                    parent=w.parent,
3484                    source=w.source,
3485                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3486                    dependencies=(
3487                        None
3488                        if w.dependencies is None
3489                        else (FileDescr if TYPE_CHECKING else dict)(
3490                            source=cast(
3491                                FileSource,
3492                                (
3493                                    str(w.dependencies)[len("conda:") :]
3494                                    if str(w.dependencies).startswith("conda:")
3495                                    else str(w.dependencies)
3496                                ),
3497                            )
3498                        )
3499                    ),
3500                ),
3501                torchscript=(w := src.weights.torchscript)
3502                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3503                    source=w.source,
3504                    authors=conv_authors(w.authors),
3505                    parent=w.parent,
3506                    pytorch_version=w.pytorch_version or Version("1.10"),
3507                ),
3508            ),
3509        )
3510
3511
3512_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3513
3514
3515# create better cover images for 3d data and non-image outputs
3516def generate_covers(
3517    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3518    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3519) -> List[Path]:
3520    def squeeze(
3521        data: NDArray[Any], axes: Sequence[AnyAxis]
3522    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3523        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3524        if data.ndim != len(axes):
3525            raise ValueError(
3526                f"tensor shape {data.shape} does not match described axes"
3527                + f" {[a.id for a in axes]}"
3528            )
3529
3530        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3531        return data.squeeze(), axes
3532
3533    def normalize(
3534        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3535    ) -> NDArray[np.float32]:
3536        data = data.astype("float32")
3537        data -= data.min(axis=axis, keepdims=True)
3538        data /= data.max(axis=axis, keepdims=True) + eps
3539        return data
3540
3541    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3542        original_shape = data.shape
3543        data, axes = squeeze(data, axes)
3544
3545        # take slice fom any batch or index axis if needed
3546        # and convert the first channel axis and take a slice from any additional channel axes
3547        slices: Tuple[slice, ...] = ()
3548        ndim = data.ndim
3549        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3550        has_c_axis = False
3551        for i, a in enumerate(axes):
3552            s = data.shape[i]
3553            assert s > 1
3554            if (
3555                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3556                and ndim > ndim_need
3557            ):
3558                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3559                ndim -= 1
3560            elif isinstance(a, ChannelAxis):
3561                if has_c_axis:
3562                    # second channel axis
3563                    data = data[slices + (slice(0, 1),)]
3564                    ndim -= 1
3565                else:
3566                    has_c_axis = True
3567                    if s == 2:
3568                        # visualize two channels with cyan and magenta
3569                        data = np.concatenate(
3570                            [
3571                                data[slices + (slice(1, 2),)],
3572                                data[slices + (slice(0, 1),)],
3573                                (
3574                                    data[slices + (slice(0, 1),)]
3575                                    + data[slices + (slice(1, 2),)]
3576                                )
3577                                / 2,  # TODO: take maximum instead?
3578                            ],
3579                            axis=i,
3580                        )
3581                    elif data.shape[i] == 3:
3582                        pass  # visualize 3 channels as RGB
3583                    else:
3584                        # visualize first 3 channels as RGB
3585                        data = data[slices + (slice(3),)]
3586
3587                    assert data.shape[i] == 3
3588
3589            slices += (slice(None),)
3590
3591        data, axes = squeeze(data, axes)
3592        assert len(axes) == ndim
3593        # take slice from z axis if needed
3594        slices = ()
3595        if ndim > ndim_need:
3596            for i, a in enumerate(axes):
3597                s = data.shape[i]
3598                if a.id == AxisId("z"):
3599                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3600                    data, axes = squeeze(data, axes)
3601                    ndim -= 1
3602                    break
3603
3604            slices += (slice(None),)
3605
3606        # take slice from any space or time axis
3607        slices = ()
3608
3609        for i, a in enumerate(axes):
3610            if ndim <= ndim_need:
3611                break
3612
3613            s = data.shape[i]
3614            assert s > 1
3615            if isinstance(
3616                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3617            ):
3618                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3619                ndim -= 1
3620
3621            slices += (slice(None),)
3622
3623        del slices
3624        data, axes = squeeze(data, axes)
3625        assert len(axes) == ndim
3626
3627        if (has_c_axis and ndim != 3) or ndim != 2:
3628            raise ValueError(
3629                f"Failed to construct cover image from shape {original_shape}"
3630            )
3631
3632        if not has_c_axis:
3633            assert ndim == 2
3634            data = np.repeat(data[:, :, None], 3, axis=2)
3635            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3636            ndim += 1
3637
3638        assert ndim == 3
3639
3640        # transpose axis order such that longest axis comes first...
3641        axis_order: List[int] = list(np.argsort(list(data.shape)))
3642        axis_order.reverse()
3643        # ... and channel axis is last
3644        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3645        axis_order.append(axis_order.pop(c))
3646        axes = [axes[ao] for ao in axis_order]
3647        data = data.transpose(axis_order)
3648
3649        # h, w = data.shape[:2]
3650        # if h / w  in (1.0 or 2.0):
3651        #     pass
3652        # elif h / w < 2:
3653        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3654
3655        norm_along = (
3656            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3657        )
3658        # normalize the data and map to 8 bit
3659        data = normalize(data, norm_along)
3660        data = (data * 255).astype("uint8")
3661
3662        return data
3663
3664    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3665        assert im0.dtype == im1.dtype == np.uint8
3666        assert im0.shape == im1.shape
3667        assert im0.ndim == 3
3668        N, M, C = im0.shape
3669        assert C == 3
3670        out = np.ones((N, M, C), dtype="uint8")
3671        for c in range(C):
3672            outc = np.tril(im0[..., c])
3673            mask = outc == 0
3674            outc[mask] = np.triu(im1[..., c])[mask]
3675            out[..., c] = outc
3676
3677        return out
3678
3679    if not inputs:
3680        raise ValueError("Missing test input tensor for cover generation.")
3681
3682    if not outputs:
3683        raise ValueError("Missing test output tensor for cover generation.")
3684
3685    ipt_descr, ipt = inputs[0]
3686    out_descr, out = outputs[0]
3687
3688    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3689    out_img = to_2d_image(out, out_descr.axes)
3690
3691    cover_folder = Path(mkdtemp())
3692    if ipt_img.shape == out_img.shape:
3693        covers = [cover_folder / "cover.png"]
3694        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3695    else:
3696        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3697        imwrite(covers[0], ipt_img)
3698        imwrite(covers[1], out_img)
3699
3700    return covers
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible to the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible to the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
229class TensorId(LowerCaseIdentifier):
230    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
231        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
232    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

245class AxisId(LowerCaseIdentifier):
246    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
247        Annotated[
248            LowerCaseIdentifierAnno,
249            MaxLen(16),
250            AfterValidator(_normalize_axis_id),
251        ]
252    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

the pydantic root model to validate the string

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_range', 'sigmoid', 'softmax']
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'softmax', 'zero_mean_unit_variance']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>

Annotates an integer to calculate a concrete axis size from a ParameterizedSize.

class ParameterizedSize(bioimageio.spec._internal.node.Node):
298class ParameterizedSize(Node):
299    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
300
301    - **min** and **step** are given by the model description.
302    - All blocksize paramters n = 0,1,2,... yield a valid `size`.
303    - A greater blocksize paramter n = 0,1,2,... results in a greater **size**.
304      This allows to adjust the axis size more generically.
305    """
306
307    N: ClassVar[Type[int]] = ParameterizedSize_N
308    """Positive integer to parameterize this axis"""
309
310    min: Annotated[int, Gt(0)]
311    step: Annotated[int, Gt(0)]
312
313    def validate_size(self, size: int) -> int:
314        if size < self.min:
315            raise ValueError(f"size {size} < {self.min}")
316        if (size - self.min) % self.step != 0:
317            raise ValueError(
318                f"axis of size {size} is not parameterized by `min + n*step` ="
319                + f" `{self.min} + n*{self.step}`"
320            )
321
322        return size
323
324    def get_size(self, n: ParameterizedSize_N) -> int:
325        return self.min + self.step * n
326
327    def get_n(self, s: int) -> ParameterizedSize_N:
328        """return smallest n parameterizing a size greater or equal than `s`"""
329        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

  • min and step are given by the model description.
  • All blocksize paramters n = 0,1,2,... yield a valid size.
  • A greater blocksize paramter n = 0,1,2,... results in a greater size. This allows to adjust the axis size more generically.
N: ClassVar[Type[int]] = <class 'int'>

Positive integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
313    def validate_size(self, size: int) -> int:
314        if size < self.min:
315            raise ValueError(f"size {size} < {self.min}")
316        if (size - self.min) % self.step != 0:
317            raise ValueError(
318                f"axis of size {size} is not parameterized by `min + n*step` ="
319                + f" `{self.min} + n*{self.step}`"
320            )
321
322        return size
def get_size(self, n: int) -> int:
324    def get_size(self, n: ParameterizedSize_N) -> int:
325        return self.min + self.step * n
def get_n(self, s: int) -> int:
327    def get_n(self, s: int) -> ParameterizedSize_N:
328        """return smallest n parameterizing a size greater or equal than `s`"""
329        return ceil((s - self.min) / self.step)

return smallest n parameterizing a size greater or equal than s

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DataDependentSize(bioimageio.spec._internal.node.Node):
332class DataDependentSize(Node):
333    min: Annotated[int, Gt(0)] = 1
334    max: Annotated[Optional[int], Gt(1)] = None
335
336    @model_validator(mode="after")
337    def _validate_max_gt_min(self):
338        if self.max is not None and self.min >= self.max:
339            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
340
341        return self
342
343    def validate_size(self, size: int) -> int:
344        if size < self.min:
345            raise ValueError(f"size {size} < {self.min}")
346
347        if self.max is not None and size > self.max:
348            raise ValueError(f"size {size} > {self.max}")
349
350        return size
min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
343    def validate_size(self, size: int) -> int:
344        if size < self.min:
345            raise ValueError(f"size {size} < {self.min}")
346
347        if self.max is not None and size > self.max:
348            raise ValueError(f"size {size} > {self.max}")
349
350        return size
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class SizeReference(bioimageio.spec._internal.node.Node):
353class SizeReference(Node):
354    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
355
356    `axis.size = reference.size * reference.scale / axis.scale + offset`
357
358    Note:
359    1. The axis and the referenced axis need to have the same unit (or no unit).
360    2. Batch axes may not be referenced.
361    3. Fractions are rounded down.
362    4. If the reference axis is `concatenable` the referencing axis is assumed to be
363        `concatenable` as well with the same block order.
364
365    Example:
366    An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
367    Let's assume that we want to express the image height h in relation to its width w
368    instead of only accepting input images of exactly 100*49 pixels
369    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
370
371    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
372    >>> h = SpaceInputAxis(
373    ...     id=AxisId("h"),
374    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
375    ...     unit="millimeter",
376    ...     scale=4,
377    ... )
378    >>> print(h.size.get_size(h, w))
379    49
380
381    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
382    """
383
384    tensor_id: TensorId
385    """tensor id of the reference axis"""
386
387    axis_id: AxisId
388    """axis id of the reference axis"""
389
390    offset: StrictInt = 0
391
392    def get_size(
393        self,
394        axis: Union[
395            ChannelAxis,
396            IndexInputAxis,
397            IndexOutputAxis,
398            TimeInputAxis,
399            SpaceInputAxis,
400            TimeOutputAxis,
401            TimeOutputAxisWithHalo,
402            SpaceOutputAxis,
403            SpaceOutputAxisWithHalo,
404        ],
405        ref_axis: Union[
406            ChannelAxis,
407            IndexInputAxis,
408            IndexOutputAxis,
409            TimeInputAxis,
410            SpaceInputAxis,
411            TimeOutputAxis,
412            TimeOutputAxisWithHalo,
413            SpaceOutputAxis,
414            SpaceOutputAxisWithHalo,
415        ],
416        n: ParameterizedSize_N = 0,
417        ref_size: Optional[int] = None,
418    ):
419        """Compute the concrete size for a given axis and its reference axis.
420
421        Args:
422            axis: The axis this `SizeReference` is the size of.
423            ref_axis: The reference axis to compute the size from.
424            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
425                and no fixed **ref_size** is given,
426                **n** is used to compute the size of the parameterized **ref_axis**.
427            ref_size: Overwrite the reference size instead of deriving it from
428                **ref_axis**
429                (**ref_axis.scale** is still used; any given **n** is ignored).
430        """
431        assert axis.size == self, (
432            "Given `axis.size` is not defined by this `SizeReference`"
433        )
434
435        assert ref_axis.id == self.axis_id, (
436            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
437        )
438
439        assert axis.unit == ref_axis.unit, (
440            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
441            f" but {axis.unit}!={ref_axis.unit}"
442        )
443        if ref_size is None:
444            if isinstance(ref_axis.size, (int, float)):
445                ref_size = ref_axis.size
446            elif isinstance(ref_axis.size, ParameterizedSize):
447                ref_size = ref_axis.size.get_size(n)
448            elif isinstance(ref_axis.size, DataDependentSize):
449                raise ValueError(
450                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
451                )
452            elif isinstance(ref_axis.size, SizeReference):
453                raise ValueError(
454                    "Reference axis referenced in `SizeReference` may not be sized by a"
455                    + " `SizeReference` itself."
456                )
457            else:
458                assert_never(ref_axis.size)
459
460        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
461
462    @staticmethod
463    def _get_unit(
464        axis: Union[
465            ChannelAxis,
466            IndexInputAxis,
467            IndexOutputAxis,
468            TimeInputAxis,
469            SpaceInputAxis,
470            TimeOutputAxis,
471            TimeOutputAxisWithHalo,
472            SpaceOutputAxis,
473            SpaceOutputAxisWithHalo,
474        ],
475    ):
476        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An unisotropic input image of wh=10049 pixels depicts a phsical space of 200196mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 10049 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: Annotated[int, Strict(strict=True)]
392    def get_size(
393        self,
394        axis: Union[
395            ChannelAxis,
396            IndexInputAxis,
397            IndexOutputAxis,
398            TimeInputAxis,
399            SpaceInputAxis,
400            TimeOutputAxis,
401            TimeOutputAxisWithHalo,
402            SpaceOutputAxis,
403            SpaceOutputAxisWithHalo,
404        ],
405        ref_axis: Union[
406            ChannelAxis,
407            IndexInputAxis,
408            IndexOutputAxis,
409            TimeInputAxis,
410            SpaceInputAxis,
411            TimeOutputAxis,
412            TimeOutputAxisWithHalo,
413            SpaceOutputAxis,
414            SpaceOutputAxisWithHalo,
415        ],
416        n: ParameterizedSize_N = 0,
417        ref_size: Optional[int] = None,
418    ):
419        """Compute the concrete size for a given axis and its reference axis.
420
421        Args:
422            axis: The axis this `SizeReference` is the size of.
423            ref_axis: The reference axis to compute the size from.
424            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
425                and no fixed **ref_size** is given,
426                **n** is used to compute the size of the parameterized **ref_axis**.
427            ref_size: Overwrite the reference size instead of deriving it from
428                **ref_axis**
429                (**ref_axis.scale** is still used; any given **n** is ignored).
430        """
431        assert axis.size == self, (
432            "Given `axis.size` is not defined by this `SizeReference`"
433        )
434
435        assert ref_axis.id == self.axis_id, (
436            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
437        )
438
439        assert axis.unit == ref_axis.unit, (
440            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
441            f" but {axis.unit}!={ref_axis.unit}"
442        )
443        if ref_size is None:
444            if isinstance(ref_axis.size, (int, float)):
445                ref_size = ref_axis.size
446            elif isinstance(ref_axis.size, ParameterizedSize):
447                ref_size = ref_axis.size.get_size(n)
448            elif isinstance(ref_axis.size, DataDependentSize):
449                raise ValueError(
450                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
451                )
452            elif isinstance(ref_axis.size, SizeReference):
453                raise ValueError(
454                    "Reference axis referenced in `SizeReference` may not be sized by a"
455                    + " `SizeReference` itself."
456                )
457            else:
458                assert_never(ref_axis.size)
459
460        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

479class AxisBase(NodeWithExplicitlySetFields):
480    id: AxisId
481    """An axis id unique across all axes of one tensor."""
482
483    description: Annotated[str, MaxLen(128)] = ""
484    """A short description of this axis beyond its type and id."""
id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]

A short description of this axis beyond its type and id.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WithHalo(bioimageio.spec._internal.node.Node):
487class WithHalo(Node):
488    halo: Annotated[int, Ge(1)]
489    """The halo should be cropped from the output tensor to avoid boundary effects.
490    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
491    To document a halo that is already cropped by the model use `size.offset` instead."""
492
493    size: Annotated[
494        SizeReference,
495        Field(
496            examples=[
497                10,
498                SizeReference(
499                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
500                ).model_dump(mode="json"),
501            ]
502        ),
503    ]
504    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
510class BatchAxis(AxisBase):
511    implemented_type: ClassVar[Literal["batch"]] = "batch"
512    if TYPE_CHECKING:
513        type: Literal["batch"] = "batch"
514    else:
515        type: Literal["batch"]
516
517    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
518    size: Optional[Literal[1]] = None
519    """The batch size may be fixed to 1,
520    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
521
522    @property
523    def scale(self):
524        return 1.0
525
526    @property
527    def concatenable(self):
528        return True
529
530    @property
531    def unit(self):
532        return None
implemented_type: ClassVar[Literal['batch']] = 'batch'
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
522    @property
523    def scale(self):
524        return 1.0
concatenable
526    @property
527    def concatenable(self):
528        return True
unit
530    @property
531    def unit(self):
532        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['batch']
class ChannelAxis(AxisBase):
535class ChannelAxis(AxisBase):
536    implemented_type: ClassVar[Literal["channel"]] = "channel"
537    if TYPE_CHECKING:
538        type: Literal["channel"] = "channel"
539    else:
540        type: Literal["channel"]
541
542    id: NonBatchAxisId = AxisId("channel")
543
544    channel_names: NotEmpty[List[Identifier]]
545
546    @property
547    def size(self) -> int:
548        return len(self.channel_names)
549
550    @property
551    def concatenable(self):
552        return False
553
554    @property
555    def scale(self) -> float:
556        return 1.0
557
558    @property
559    def unit(self):
560        return None
implemented_type: ClassVar[Literal['channel']] = 'channel'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
546    @property
547    def size(self) -> int:
548        return len(self.channel_names)
concatenable
550    @property
551    def concatenable(self):
552        return False
scale: float
554    @property
555    def scale(self) -> float:
556        return 1.0
unit
558    @property
559    def unit(self):
560        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['channel']
class IndexAxisBase(AxisBase):
563class IndexAxisBase(AxisBase):
564    implemented_type: ClassVar[Literal["index"]] = "index"
565    if TYPE_CHECKING:
566        type: Literal["index"] = "index"
567    else:
568        type: Literal["index"]
569
570    id: NonBatchAxisId = AxisId("index")
571
572    @property
573    def scale(self) -> float:
574        return 1.0
575
576    @property
577    def unit(self):
578        return None
implemented_type: ClassVar[Literal['index']] = 'index'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
572    @property
573    def scale(self) -> float:
574        return 1.0
unit
576    @property
577    def unit(self):
578        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
601class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
602    concatenable: bool = False
603    """If a model has a `concatenable` input axis, it can be processed blockwise,
604    splitting a longer sample axis into blocks matching its input tensor description.
605    Output axes are concatenable if they have a `SizeReference` to a concatenable
606    input axis.
607    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexOutputAxis(IndexAxisBase):
610class IndexOutputAxis(IndexAxisBase):
611    size: Annotated[
612        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
613        Field(
614            examples=[
615                10,
616                SizeReference(
617                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
618                ).model_dump(mode="json"),
619            ]
620        ),
621    ]
622    """The size/length of this axis can be specified as
623    - fixed integer
624    - reference to another axis with an optional offset (`SizeReference`)
625    - data dependent size using `DataDependentSize` (size is only known after model inference)
626    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class TimeAxisBase(AxisBase):
629class TimeAxisBase(AxisBase):
630    implemented_type: ClassVar[Literal["time"]] = "time"
631    if TYPE_CHECKING:
632        type: Literal["time"] = "time"
633    else:
634        type: Literal["time"]
635
636    id: NonBatchAxisId = AxisId("time")
637    unit: Optional[TimeUnit] = None
638    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['time']] = 'time'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
641class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
642    concatenable: bool = False
643    """If a model has a `concatenable` input axis, it can be processed blockwise,
644    splitting a longer sample axis into blocks matching its input tensor description.
645    Output axes are concatenable if they have a `SizeReference` to a concatenable
646    input axis.
647    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceAxisBase(AxisBase):
650class SpaceAxisBase(AxisBase):
651    implemented_type: ClassVar[Literal["space"]] = "space"
652    if TYPE_CHECKING:
653        type: Literal["space"] = "space"
654    else:
655        type: Literal["space"]
656
657    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
658    unit: Optional[SpaceUnit] = None
659    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['space']] = 'space'
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
662class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
663    concatenable: bool = False
664    """If a model has a `concatenable` input axis, it can be processed blockwise,
665    splitting a longer sample axis into blocks matching its input tensor description.
666    Output axes are concatenable if they have a `SizeReference` to a concatenable
667    input axis.
668    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
INPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>)

intended for isinstance comparisons in py<3.10

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
704class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
705    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
708class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
709    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
728class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
729    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
732class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
733    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
OUTPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
ANY_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>, <class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
790class NominalOrOrdinalDataDescr(Node):
791    values: TVs
792    """A fixed set of nominal or an ascending sequence of ordinal values.
793    In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'.
794    String `values` are interpreted as labels for tensor values 0, ..., N.
795    Note: as YAML 1.2 does not natively support a "set" datatype,
796    nominal values should be given as a sequence (aka list/array) as well.
797    """
798
799    type: Annotated[
800        NominalOrOrdinalDType,
801        Field(
802            examples=[
803                "float32",
804                "uint8",
805                "uint16",
806                "int64",
807                "bool",
808            ],
809        ),
810    ] = "uint8"
811
812    @model_validator(mode="after")
813    def _validate_values_match_type(
814        self,
815    ) -> Self:
816        incompatible: List[Any] = []
817        for v in self.values:
818            if self.type == "bool":
819                if not isinstance(v, bool):
820                    incompatible.append(v)
821            elif self.type in DTYPE_LIMITS:
822                if (
823                    isinstance(v, (int, float))
824                    and (
825                        v < DTYPE_LIMITS[self.type].min
826                        or v > DTYPE_LIMITS[self.type].max
827                    )
828                    or (isinstance(v, str) and "uint" not in self.type)
829                    or (isinstance(v, float) and "int" in self.type)
830                ):
831                    incompatible.append(v)
832            else:
833                incompatible.append(v)
834
835            if len(incompatible) == 5:
836                incompatible.append("...")
837                break
838
839        if incompatible:
840            raise ValueError(
841                f"data type '{self.type}' incompatible with values {incompatible}"
842            )
843
844        return self
845
846    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
847
848    @property
849    def range(self):
850        if isinstance(self.values[0], str):
851            return 0, len(self.values) - 1
852        else:
853            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data.type is required to be an unsigend integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
848    @property
849    def range(self):
850        if isinstance(self.values[0], str):
851            return 0, len(self.values) - 1
852        else:
853            return min(self.values), max(self.values)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
870class IntervalOrRatioDataDescr(Node):
871    type: Annotated[  # TODO: rename to dtype
872        IntervalOrRatioDType,
873        Field(
874            examples=["float32", "float64", "uint8", "uint16"],
875        ),
876    ] = "float32"
877    range: Tuple[Optional[float], Optional[float]] = (
878        None,
879        None,
880    )
881    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
882    `None` corresponds to min/max of what can be expressed by **type**."""
883    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
884    scale: float = 1.0
885    """Scale for data on an interval (or ratio) scale."""
886    offset: Optional[float] = None
887    """Offset for data on a ratio scale."""
888
889    @model_validator(mode="before")
890    def _replace_inf(cls, data: Any):
891        if is_dict(data):
892            if "range" in data and is_sequence(data["range"]):
893                forbidden = (
894                    "inf",
895                    "-inf",
896                    ".inf",
897                    "-.inf",
898                    float("inf"),
899                    float("-inf"),
900                )
901                if any(v in forbidden for v in data["range"]):
902                    issue_warning("replaced 'inf' value", value=data["range"])
903
904                data["range"] = tuple(
905                    (None if v in forbidden else v) for v in data["range"]
906                )
907
908        return data
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
914class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
915    """processing base class"""

processing base class

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
918class BinarizeKwargs(ProcessingKwargs):
919    """key word arguments for `BinarizeDescr`"""
920
921    threshold: float
922    """The fixed threshold"""

key word arguments for BinarizeDescr

threshold: float

The fixed threshold

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
925class BinarizeAlongAxisKwargs(ProcessingKwargs):
926    """key word arguments for `BinarizeDescr`"""
927
928    threshold: NotEmpty[List[float]]
929    """The fixed threshold values along `axis`"""
930
931    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
932    """The `threshold` axis"""

key word arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeDescr(ProcessingDescrBase):
935class BinarizeDescr(ProcessingDescrBase):
936    """Binarize the tensor with a fixed threshold.
937
938    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
939    will be set to one, values below the threshold to zero.
940
941    Examples:
942    - in YAML
943        ```yaml
944        postprocessing:
945          - id: binarize
946            kwargs:
947              axis: 'channel'
948              threshold: [0.25, 0.5, 0.75]
949        ```
950    - in Python:
951        >>> postprocessing = [BinarizeDescr(
952        ...   kwargs=BinarizeAlongAxisKwargs(
953        ...       axis=AxisId('channel'),
954        ...       threshold=[0.25, 0.5, 0.75],
955        ...   )
956        ... )]
957    """
958
959    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
960    if TYPE_CHECKING:
961        id: Literal["binarize"] = "binarize"
962    else:
963        id: Literal["binarize"]
964    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
implemented_id: ClassVar[Literal['binarize']] = 'binarize'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['binarize']
class ClipDescr(ProcessingDescrBase):
967class ClipDescr(ProcessingDescrBase):
968    """Set tensor values below min to min and above max to max.
969
970    See `ScaleRangeDescr` for examples.
971    """
972
973    implemented_id: ClassVar[Literal["clip"]] = "clip"
974    if TYPE_CHECKING:
975        id: Literal["clip"] = "clip"
976    else:
977        id: Literal["clip"]
978
979    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

implemented_id: ClassVar[Literal['clip']] = 'clip'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['clip']
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
982class EnsureDtypeKwargs(ProcessingKwargs):
983    """key word arguments for `EnsureDtypeDescr`"""
984
985    dtype: Literal[
986        "float32",
987        "float64",
988        "uint8",
989        "int8",
990        "uint16",
991        "int16",
992        "uint32",
993        "int32",
994        "uint64",
995        "int64",
996        "bool",
997    ]

key word arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class EnsureDtypeDescr(ProcessingDescrBase):
1000class EnsureDtypeDescr(ProcessingDescrBase):
1001    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1002
1003    This can for example be used to ensure the inner neural network model gets a
1004    different input tensor data type than the fully described bioimage.io model does.
1005
1006    Examples:
1007        The described bioimage.io model (incl. preprocessing) accepts any
1008        float32-compatible tensor, normalizes it with percentiles and clipping and then
1009        casts it to uint8, which is what the neural network in this example expects.
1010        - in YAML
1011            ```yaml
1012            inputs:
1013            - data:
1014                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1015              preprocessing:
1016              - id: scale_range
1017                  kwargs:
1018                  axes: ['y', 'x']
1019                  max_percentile: 99.8
1020                  min_percentile: 5.0
1021              - id: clip
1022                  kwargs:
1023                  min: 0.0
1024                  max: 1.0
1025              - id: ensure_dtype  # the neural network of the model requires uint8
1026                  kwargs:
1027                  dtype: uint8
1028            ```
1029        - in Python:
1030            >>> preprocessing = [
1031            ...     ScaleRangeDescr(
1032            ...         kwargs=ScaleRangeKwargs(
1033            ...           axes= (AxisId('y'), AxisId('x')),
1034            ...           max_percentile= 99.8,
1035            ...           min_percentile= 5.0,
1036            ...         )
1037            ...     ),
1038            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1039            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1040            ... ]
1041    """
1042
1043    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1044    if TYPE_CHECKING:
1045        id: Literal["ensure_dtype"] = "ensure_dtype"
1046    else:
1047        id: Literal["ensure_dtype"]
1048
1049    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
      kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
      kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
      kwargs:
      dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
implemented_id: ClassVar[Literal['ensure_dtype']] = 'ensure_dtype'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1052class ScaleLinearKwargs(ProcessingKwargs):
1053    """Key word arguments for `ScaleLinearDescr`"""
1054
1055    gain: float = 1.0
1056    """multiplicative factor"""
1057
1058    offset: float = 0.0
1059    """additive term"""
1060
1061    @model_validator(mode="after")
1062    def _validate(self) -> Self:
1063        if self.gain == 1.0 and self.offset == 0.0:
1064            raise ValueError(
1065                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1066                + " != 0.0."
1067            )
1068
1069        return self

Key word arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1072class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1073    """Key word arguments for `ScaleLinearDescr`"""
1074
1075    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1076    """The axis of gain and offset values."""
1077
1078    gain: Union[float, NotEmpty[List[float]]] = 1.0
1079    """multiplicative factor"""
1080
1081    offset: Union[float, NotEmpty[List[float]]] = 0.0
1082    """additive term"""
1083
1084    @model_validator(mode="after")
1085    def _validate(self) -> Self:
1086        if isinstance(self.gain, list):
1087            if isinstance(self.offset, list):
1088                if len(self.gain) != len(self.offset):
1089                    raise ValueError(
1090                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1091                    )
1092            else:
1093                self.offset = [float(self.offset)] * len(self.gain)
1094        elif isinstance(self.offset, list):
1095            self.gain = [float(self.gain)] * len(self.offset)
1096        else:
1097            raise ValueError(
1098                "Do not specify an `axis` for scalar gain and offset values."
1099            )
1100
1101        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1102            raise ValueError(
1103                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1104                + " != 0.0."
1105            )
1106
1107        return self

Key word arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of gain and offset values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleLinearDescr(ProcessingDescrBase):
1110class ScaleLinearDescr(ProcessingDescrBase):
1111    """Fixed linear scaling.
1112
1113    Examples:
1114      1. Scale with scalar gain and offset
1115        - in YAML
1116        ```yaml
1117        preprocessing:
1118          - id: scale_linear
1119            kwargs:
1120              gain: 2.0
1121              offset: 3.0
1122        ```
1123        - in Python:
1124        >>> preprocessing = [
1125        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1126        ... ]
1127
1128      2. Independent scaling along an axis
1129        - in YAML
1130        ```yaml
1131        preprocessing:
1132          - id: scale_linear
1133            kwargs:
1134              axis: 'channel'
1135              gain: [1.0, 2.0, 3.0]
1136        ```
1137        - in Python:
1138        >>> preprocessing = [
1139        ...     ScaleLinearDescr(
1140        ...         kwargs=ScaleLinearAlongAxisKwargs(
1141        ...             axis=AxisId("channel"),
1142        ...             gain=[1.0, 2.0, 3.0],
1143        ...         )
1144        ...     )
1145        ... ]
1146
1147    """
1148
1149    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1150    if TYPE_CHECKING:
1151        id: Literal["scale_linear"] = "scale_linear"
1152    else:
1153        id: Literal["scale_linear"]
1154    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
implemented_id: ClassVar[Literal['scale_linear']] = 'scale_linear'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
1157class SigmoidDescr(ProcessingDescrBase):
1158    """The logistic sigmoid function, a.k.a. expit function.
1159
1160    Examples:
1161    - in YAML
1162        ```yaml
1163        postprocessing:
1164          - id: sigmoid
1165        ```
1166    - in Python:
1167        >>> postprocessing = [SigmoidDescr()]
1168    """
1169
1170    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1171    if TYPE_CHECKING:
1172        id: Literal["sigmoid"] = "sigmoid"
1173    else:
1174        id: Literal["sigmoid"]
1175
1176    @property
1177    def kwargs(self) -> ProcessingKwargs:
1178        """empty kwargs"""
1179        return ProcessingKwargs()

The logistic sigmoid function, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
implemented_id: ClassVar[Literal['sigmoid']] = 'sigmoid'
1176    @property
1177    def kwargs(self) -> ProcessingKwargs:
1178        """empty kwargs"""
1179        return ProcessingKwargs()

empty kwargs

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['sigmoid']
class SoftmaxKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1182class SoftmaxKwargs(ProcessingKwargs):
1183    """key word arguments for `SoftmaxDescr`"""
1184
1185    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1186    """The axis to apply the softmax function along.
1187    Note:
1188        Defaults to 'channel' axis
1189        (which may not exist, in which case
1190        a different axis id has to be specified).
1191    """

key word arguments for SoftmaxDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis to apply the softmax function along.

Note:

Defaults to 'channel' axis (which may not exist, in which case a different axis id has to be specified).

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class SoftmaxDescr(ProcessingDescrBase):
1194class SoftmaxDescr(ProcessingDescrBase):
1195    """The softmax function.
1196
1197    Examples:
1198    - in YAML
1199        ```yaml
1200        postprocessing:
1201          - id: softmax
1202            kwargs:
1203              axis: channel
1204        ```
1205    - in Python:
1206        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1207    """
1208
1209    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1210    if TYPE_CHECKING:
1211        id: Literal["softmax"] = "softmax"
1212    else:
1213        id: Literal["softmax"]
1214
1215    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

The softmax function.

Examples:

  • in YAML
postprocessing:
  - id: softmax
    kwargs:
      axis: channel
  • in Python:
    >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    
implemented_id: ClassVar[Literal['softmax']] = 'softmax'
kwargs: SoftmaxKwargs
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['softmax']
class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1219    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1220
1221    mean: float
1222    """The mean value to normalize with."""
1223
1224    std: Annotated[float, Ge(1e-6)]
1225    """The standard deviation value to normalize with."""

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1229    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1230
1231    mean: NotEmpty[List[float]]
1232    """The mean value(s) to normalize with."""
1233
1234    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1235    """The standard deviation value(s) to normalize with.
1236    Size must match `mean` values."""
1237
1238    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1239    """The axis of the mean/std values to normalize each entry along that dimension
1240    separately."""
1241
1242    @model_validator(mode="after")
1243    def _mean_and_std_match(self) -> Self:
1244        if len(self.mean) != len(self.std):
1245            raise ValueError(
1246                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1247                + " must match."
1248            )
1249
1250        return self

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1254    """Subtract a given mean and divide by the standard deviation.
1255
1256    Normalize with fixed, precomputed values for
1257    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1258    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1259    axes.
1260
1261    Examples:
1262    1. scalar value for whole tensor
1263        - in YAML
1264        ```yaml
1265        preprocessing:
1266          - id: fixed_zero_mean_unit_variance
1267            kwargs:
1268              mean: 103.5
1269              std: 13.7
1270        ```
1271        - in Python
1272        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1273        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1274        ... )]
1275
1276    2. independently along an axis
1277        - in YAML
1278        ```yaml
1279        preprocessing:
1280          - id: fixed_zero_mean_unit_variance
1281            kwargs:
1282              axis: channel
1283              mean: [101.5, 102.5, 103.5]
1284              std: [11.7, 12.7, 13.7]
1285        ```
1286        - in Python
1287        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1288        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1289        ...     axis=AxisId("channel"),
1290        ...     mean=[101.5, 102.5, 103.5],
1291        ...     std=[11.7, 12.7, 13.7],
1292        ...   )
1293        ... )]
1294    """
1295
1296    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1297        "fixed_zero_mean_unit_variance"
1298    )
1299    if TYPE_CHECKING:
1300        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1301    else:
1302        id: Literal["fixed_zero_mean_unit_variance"]
1303
1304    kwargs: Union[
1305        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1306    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor
    • in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      mean: 103.5
      std: 13.7
- in Python >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) ... )]
  1. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['fixed_zero_mean_unit_variance']] = 'fixed_zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1309class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1310    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1311
1312    axes: Annotated[
1313        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1314    ] = None
1315    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1316    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1317    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1318    To normalize each sample independently leave out the 'batch' axis.
1319    Default: Scale all axes jointly."""
1320
1321    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1322    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

key word arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1325class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1326    """Subtract mean and divide by variance.
1327
1328    Examples:
1329        Subtract tensor mean and variance
1330        - in YAML
1331        ```yaml
1332        preprocessing:
1333          - id: zero_mean_unit_variance
1334        ```
1335        - in Python
1336        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1337    """
1338
1339    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1340        "zero_mean_unit_variance"
1341    )
1342    if TYPE_CHECKING:
1343        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1344    else:
1345        id: Literal["zero_mean_unit_variance"]
1346
1347    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1348        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1349    )

Subtract mean and divide by variance.

Examples:

Subtract tensor mean and variance

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
implemented_id: ClassVar[Literal['zero_mean_unit_variance']] = 'zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['zero_mean_unit_variance']
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1352class ScaleRangeKwargs(ProcessingKwargs):
1353    """key word arguments for `ScaleRangeDescr`
1354
1355    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1356    this processing step normalizes data to the [0, 1] intervall.
1357    For other percentiles the normalized values will partially be outside the [0, 1]
1358    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1359    normalized values to a range.
1360    """
1361
1362    axes: Annotated[
1363        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1364    ] = None
1365    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1366    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1367    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1368    To normalize samples independently, leave out the "batch" axis.
1369    Default: Scale all axes jointly."""
1370
1371    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1372    """The lower percentile used to determine the value to align with zero."""
1373
1374    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1375    """The upper percentile used to determine the value to align with one.
1376    Has to be bigger than `min_percentile`.
1377    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1378    accepting percentiles specified in the range 0.0 to 1.0."""
1379
1380    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1381    """Epsilon for numeric stability.
1382    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1383    with `v_lower,v_upper` values at the respective percentiles."""
1384
1385    reference_tensor: Optional[TensorId] = None
1386    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1387    For any tensor in `inputs` only input tensor references are allowed."""
1388
1389    @field_validator("max_percentile", mode="after")
1390    @classmethod
1391    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1392        if (min_p := info.data["min_percentile"]) >= value:
1393            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1394
1395        return value

key word arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] intervall. For other percentiles the normalized values will partially be outside the [0, 1] intervall. Use ScaleRange followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1389    @field_validator("max_percentile", mode="after")
1390    @classmethod
1391    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1392        if (min_p := info.data["min_percentile"]) >= value:
1393            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1394
1395        return value
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleRangeDescr(ProcessingDescrBase):
1398class ScaleRangeDescr(ProcessingDescrBase):
1399    """Scale with percentiles.
1400
1401    Examples:
1402    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1403        - in YAML
1404        ```yaml
1405        preprocessing:
1406          - id: scale_range
1407            kwargs:
1408              axes: ['y', 'x']
1409              max_percentile: 99.8
1410              min_percentile: 5.0
1411        ```
1412        - in Python
1413        >>> preprocessing = [
1414        ...     ScaleRangeDescr(
1415        ...         kwargs=ScaleRangeKwargs(
1416        ...           axes= (AxisId('y'), AxisId('x')),
1417        ...           max_percentile= 99.8,
1418        ...           min_percentile= 5.0,
1419        ...         )
1420        ...     ),
1421        ...     ClipDescr(
1422        ...         kwargs=ClipKwargs(
1423        ...             min=0.0,
1424        ...             max=1.0,
1425        ...         )
1426        ...     ),
1427        ... ]
1428
1429      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1430        - in YAML
1431        ```yaml
1432        preprocessing:
1433          - id: scale_range
1434            kwargs:
1435              axes: ['y', 'x']
1436              max_percentile: 99.8
1437              min_percentile: 5.0
1438                  - id: scale_range
1439           - id: clip
1440             kwargs:
1441              min: 0.0
1442              max: 1.0
1443        ```
1444        - in Python
1445        >>> preprocessing = [ScaleRangeDescr(
1446        ...   kwargs=ScaleRangeKwargs(
1447        ...       axes= (AxisId('y'), AxisId('x')),
1448        ...       max_percentile= 99.8,
1449        ...       min_percentile= 5.0,
1450        ...   )
1451        ... )]
1452
1453    """
1454
1455    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1456    if TYPE_CHECKING:
1457        id: Literal["scale_range"] = "scale_range"
1458    else:
1459        id: Literal["scale_range"]
1460    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
    • in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
- in Python >>> preprocessing = [ ... ScaleRangeDescr( ... kwargs=ScaleRangeKwargs( ... axes= (AxisId('y'), AxisId('x')), ... max_percentile= 99.8, ... min_percentile= 5.0, ... ) ... ), ... ClipDescr( ... kwargs=ClipKwargs( ... min=0.0, ... max=1.0, ... ) ... ), ... ]
  1. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
              - id: scale_range
       - id: clip
         kwargs:
          min: 0.0
          max: 1.0
    
    • in Python
      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['scale_range']] = 'scale_range'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1463class ScaleMeanVarianceKwargs(ProcessingKwargs):
1464    """key word arguments for `ScaleMeanVarianceKwargs`"""
1465
1466    reference_tensor: TensorId
1467    """Name of tensor to match."""
1468
1469    axes: Annotated[
1470        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1471    ] = None
1472    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1473    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1474    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1475    To normalize samples independently, leave out the 'batch' axis.
1476    Default: Scale all axes jointly."""
1477
1478    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1479    """Epsilon for numeric stability:
1480    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

key word arguments for ScaleMeanVarianceKwargs

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1483class ScaleMeanVarianceDescr(ProcessingDescrBase):
1484    """Scale a tensor's data distribution to match another tensor's mean/std.
1485    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1486    """
1487
1488    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1489    if TYPE_CHECKING:
1490        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1491    else:
1492        id: Literal["scale_mean_variance"]
1493    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

implemented_id: ClassVar[Literal['scale_mean_variance']] = 'scale_mean_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_mean_variance']
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleMeanVarianceDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1529class TensorDescrBase(Node, Generic[IO_AxisT]):
1530    id: TensorId
1531    """Tensor id. No duplicates are allowed."""
1532
1533    description: Annotated[str, MaxLen(128)] = ""
1534    """free text description"""
1535
1536    axes: NotEmpty[Sequence[IO_AxisT]]
1537    """tensor axes"""
1538
1539    @property
1540    def shape(self):
1541        return tuple(a.size for a in self.axes)
1542
1543    @field_validator("axes", mode="after", check_fields=False)
1544    @classmethod
1545    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1546        batch_axes = [a for a in axes if a.type == "batch"]
1547        if len(batch_axes) > 1:
1548            raise ValueError(
1549                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1550            )
1551
1552        seen_ids: Set[AxisId] = set()
1553        duplicate_axes_ids: Set[AxisId] = set()
1554        for a in axes:
1555            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1556
1557        if duplicate_axes_ids:
1558            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1559
1560        return axes
1561
1562    test_tensor: FAIR[Optional[FileDescr_]] = None
1563    """An example tensor to use for testing.
1564    Using the model with the test input tensors is expected to yield the test output tensors.
1565    Each test tensor has be a an ndarray in the
1566    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1567    The file extension must be '.npy'."""
1568
1569    sample_tensor: FAIR[Optional[FileDescr_]] = None
1570    """A sample tensor to illustrate a possible input/output for the model,
1571    The sample image primarily serves to inform a human user about an example use case
1572    and is typically stored as .hdf5, .png or .tiff.
1573    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1574    (numpy's `.npy` format is not supported).
1575    The image dimensionality has to match the number of axes specified in this tensor description.
1576    """
1577
1578    @model_validator(mode="after")
1579    def _validate_sample_tensor(self) -> Self:
1580        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1581            return self
1582
1583        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1584        tensor: NDArray[Any] = imread(
1585            reader.read(),
1586            extension=PurePosixPath(reader.original_file_name).suffix,
1587        )
1588        n_dims = len(tensor.squeeze().shape)
1589        n_dims_min = n_dims_max = len(self.axes)
1590
1591        for a in self.axes:
1592            if isinstance(a, BatchAxis):
1593                n_dims_min -= 1
1594            elif isinstance(a.size, int):
1595                if a.size == 1:
1596                    n_dims_min -= 1
1597            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1598                if a.size.min == 1:
1599                    n_dims_min -= 1
1600            elif isinstance(a.size, SizeReference):
1601                if a.size.offset < 2:
1602                    # size reference may result in singleton axis
1603                    n_dims_min -= 1
1604            else:
1605                assert_never(a.size)
1606
1607        n_dims_min = max(0, n_dims_min)
1608        if n_dims < n_dims_min or n_dims > n_dims_max:
1609            raise ValueError(
1610                f"Expected sample tensor to have {n_dims_min} to"
1611                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1612            )
1613
1614        return self
1615
1616    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1617        IntervalOrRatioDataDescr()
1618    )
1619    """Description of the tensor's data values, optionally per channel.
1620    If specified per channel, the data `type` needs to match across channels."""
1621
1622    @property
1623    def dtype(
1624        self,
1625    ) -> Literal[
1626        "float32",
1627        "float64",
1628        "uint8",
1629        "int8",
1630        "uint16",
1631        "int16",
1632        "uint32",
1633        "int32",
1634        "uint64",
1635        "int64",
1636        "bool",
1637    ]:
1638        """dtype as specified under `data.type` or `data[i].type`"""
1639        if isinstance(self.data, collections.abc.Sequence):
1640            return self.data[0].type
1641        else:
1642            return self.data.type
1643
1644    @field_validator("data", mode="after")
1645    @classmethod
1646    def _check_data_type_across_channels(
1647        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1648    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1649        if not isinstance(value, list):
1650            return value
1651
1652        dtypes = {t.type for t in value}
1653        if len(dtypes) > 1:
1654            raise ValueError(
1655                "Tensor data descriptions per channel need to agree in their data"
1656                + f" `type`, but found {dtypes}."
1657            )
1658
1659        return value
1660
1661    @model_validator(mode="after")
1662    def _check_data_matches_channelaxis(self) -> Self:
1663        if not isinstance(self.data, (list, tuple)):
1664            return self
1665
1666        for a in self.axes:
1667            if isinstance(a, ChannelAxis):
1668                size = a.size
1669                assert isinstance(size, int)
1670                break
1671        else:
1672            return self
1673
1674        if len(self.data) != size:
1675            raise ValueError(
1676                f"Got tensor data descriptions for {len(self.data)} channels, but"
1677                + f" '{a.id}' axis has size {size}."
1678            )
1679
1680        return self
1681
1682    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1683        if len(array.shape) != len(self.axes):
1684            raise ValueError(
1685                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1686                + f" incompatible with {len(self.axes)} axes."
1687            )
1688        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1539    @property
1540    def shape(self):
1541        return tuple(a.size for a in self.axes)
test_tensor: Annotated[Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>), WrapSerializer(func=<function package_file_descr_serializer at 0x7fe58e0b1c60>, return_type=PydanticUndefined, when_used='unless-none')]], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fe58e0ccfe0>, severity=35, msg=None, context=None)]

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has be a an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Annotated[Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>), WrapSerializer(func=<function package_file_descr_serializer at 0x7fe58e0b1c60>, return_type=PydanticUndefined, when_used='unless-none')]], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fe58e0ccfe0>, severity=35, msg=None, context=None)]

A sample tensor to illustrate a possible input/output for the model, The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1622    @property
1623    def dtype(
1624        self,
1625    ) -> Literal[
1626        "float32",
1627        "float64",
1628        "uint8",
1629        "int8",
1630        "uint16",
1631        "int16",
1632        "uint32",
1633        "int32",
1634        "uint64",
1635        "int64",
1636        "bool",
1637    ]:
1638        """dtype as specified under `data.type` or `data[i].type`"""
1639        if isinstance(self.data, collections.abc.Sequence):
1640            return self.data[0].type
1641        else:
1642            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1682    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1683        if len(array.shape) != len(self.axes):
1684            raise ValueError(
1685                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1686                + f" incompatible with {len(self.axes)} axes."
1687            )
1688        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1691class InputTensorDescr(TensorDescrBase[InputAxis]):
1692    id: TensorId = TensorId("input")
1693    """Input tensor id.
1694    No duplicates are allowed across all inputs and outputs."""
1695
1696    optional: bool = False
1697    """indicates that this tensor may be `None`"""
1698
1699    preprocessing: List[PreprocessingDescr] = Field(
1700        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1701    )
1702
1703    """Description of how this input should be preprocessed.
1704
1705    notes:
1706    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1707      to ensure an input tensor's data type matches the input tensor's data description.
1708    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1709      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1710      changing the data type.
1711    """
1712
1713    @model_validator(mode="after")
1714    def _validate_preprocessing_kwargs(self) -> Self:
1715        axes_ids = [a.id for a in self.axes]
1716        for p in self.preprocessing:
1717            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1718            if kwargs_axes is None:
1719                continue
1720
1721            if not isinstance(kwargs_axes, collections.abc.Sequence):
1722                raise ValueError(
1723                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1724                )
1725
1726            if any(a not in axes_ids for a in kwargs_axes):
1727                raise ValueError(
1728                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1729                )
1730
1731        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1732            dtype = self.data.type
1733        else:
1734            dtype = self.data[0].type
1735
1736        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1737        if not self.preprocessing or not isinstance(
1738            self.preprocessing[0], EnsureDtypeDescr
1739        ):
1740            self.preprocessing.insert(
1741                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1742            )
1743
1744        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1745        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1746            self.preprocessing.append(
1747                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1748            )
1749
1750        return self
id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1753def convert_axes(
1754    axes: str,
1755    *,
1756    shape: Union[
1757        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1758    ],
1759    tensor_type: Literal["input", "output"],
1760    halo: Optional[Sequence[int]],
1761    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1762):
1763    ret: List[AnyAxis] = []
1764    for i, a in enumerate(axes):
1765        axis_type = _AXIS_TYPE_MAP.get(a, a)
1766        if axis_type == "batch":
1767            ret.append(BatchAxis())
1768            continue
1769
1770        scale = 1.0
1771        if isinstance(shape, _ParameterizedInputShape_v0_4):
1772            if shape.step[i] == 0:
1773                size = shape.min[i]
1774            else:
1775                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1776        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1777            ref_t = str(shape.reference_tensor)
1778            if ref_t.count(".") == 1:
1779                t_id, orig_a_id = ref_t.split(".")
1780            else:
1781                t_id = ref_t
1782                orig_a_id = a
1783
1784            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1785            if not (orig_scale := shape.scale[i]):
1786                # old way to insert a new axis dimension
1787                size = int(2 * shape.offset[i])
1788            else:
1789                scale = 1 / orig_scale
1790                if axis_type in ("channel", "index"):
1791                    # these axes no longer have a scale
1792                    offset_from_scale = orig_scale * size_refs.get(
1793                        _TensorName_v0_4(t_id), {}
1794                    ).get(orig_a_id, 0)
1795                else:
1796                    offset_from_scale = 0
1797                size = SizeReference(
1798                    tensor_id=TensorId(t_id),
1799                    axis_id=AxisId(a_id),
1800                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1801                )
1802        else:
1803            size = shape[i]
1804
1805        if axis_type == "time":
1806            if tensor_type == "input":
1807                ret.append(TimeInputAxis(size=size, scale=scale))
1808            else:
1809                assert not isinstance(size, ParameterizedSize)
1810                if halo is None:
1811                    ret.append(TimeOutputAxis(size=size, scale=scale))
1812                else:
1813                    assert not isinstance(size, int)
1814                    ret.append(
1815                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1816                    )
1817
1818        elif axis_type == "index":
1819            if tensor_type == "input":
1820                ret.append(IndexInputAxis(size=size))
1821            else:
1822                if isinstance(size, ParameterizedSize):
1823                    size = DataDependentSize(min=size.min)
1824
1825                ret.append(IndexOutputAxis(size=size))
1826        elif axis_type == "channel":
1827            assert not isinstance(size, ParameterizedSize)
1828            if isinstance(size, SizeReference):
1829                warnings.warn(
1830                    "Conversion of channel size from an implicit output shape may be"
1831                    + " wrong"
1832                )
1833                ret.append(
1834                    ChannelAxis(
1835                        channel_names=[
1836                            Identifier(f"channel{i}") for i in range(size.offset)
1837                        ]
1838                    )
1839                )
1840            else:
1841                ret.append(
1842                    ChannelAxis(
1843                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1844                    )
1845                )
1846        elif axis_type == "space":
1847            if tensor_type == "input":
1848                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1849            else:
1850                assert not isinstance(size, ParameterizedSize)
1851                if halo is None or halo[i] == 0:
1852                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1853                elif isinstance(size, int):
1854                    raise NotImplementedError(
1855                        f"output axis with halo and fixed size (here {size}) not allowed"
1856                    )
1857                else:
1858                    ret.append(
1859                        SpaceOutputAxisWithHalo(
1860                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1861                        )
1862                    )
1863
1864    return ret
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
2029class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2030    id: TensorId = TensorId("output")
2031    """Output tensor id.
2032    No duplicates are allowed across all inputs and outputs."""
2033
2034    postprocessing: List[PostprocessingDescr] = Field(
2035        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2036    )
2037    """Description of how this output should be postprocessed.
2038
2039    note: `postprocessing` always ends with an 'ensure_dtype' operation.
2040          If not given this is added to cast to this tensor's `data.type`.
2041    """
2042
2043    @model_validator(mode="after")
2044    def _validate_postprocessing_kwargs(self) -> Self:
2045        axes_ids = [a.id for a in self.axes]
2046        for p in self.postprocessing:
2047            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2048            if kwargs_axes is None:
2049                continue
2050
2051            if not isinstance(kwargs_axes, collections.abc.Sequence):
2052                raise ValueError(
2053                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2054                )
2055
2056            if any(a not in axes_ids for a in kwargs_axes):
2057                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2058
2059        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2060            dtype = self.data.type
2061        else:
2062            dtype = self.data[0].type
2063
2064        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2065        if not self.postprocessing or not isinstance(
2066            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2067        ):
2068            self.postprocessing.append(
2069                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2070            )
2071        return self
id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleMeanVarianceDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], Optional[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]], tensor_origin: Literal['test_tensor']):
2121def validate_tensors(
2122    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2123    tensor_origin: Literal[
2124        "test_tensor"
2125    ],  # for more precise error messages, e.g. 'test_tensor'
2126):
2127    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2128
2129    def e_msg(d: TensorDescr):
2130        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2131
2132    for descr, array in tensors.values():
2133        if array is None:
2134            axis_sizes = {a.id: None for a in descr.axes}
2135        else:
2136            try:
2137                axis_sizes = descr.get_axis_sizes_for_array(array)
2138            except ValueError as e:
2139                raise ValueError(f"{e_msg(descr)} {e}")
2140
2141        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2142
2143    for descr, array in tensors.values():
2144        if array is None:
2145            continue
2146
2147        if descr.dtype in ("float32", "float64"):
2148            invalid_test_tensor_dtype = array.dtype.name not in (
2149                "float32",
2150                "float64",
2151                "uint8",
2152                "int8",
2153                "uint16",
2154                "int16",
2155                "uint32",
2156                "int32",
2157                "uint64",
2158                "int64",
2159            )
2160        else:
2161            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2162
2163        if invalid_test_tensor_dtype:
2164            raise ValueError(
2165                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2166                + f" match described dtype '{descr.dtype}'"
2167            )
2168
2169        if array.min() > -1e-4 and array.max() < 1e-4:
2170            raise ValueError(
2171                "Output values are too small for reliable testing."
2172                + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2173            )
2174
2175        for a in descr.axes:
2176            actual_size = all_tensor_axes[descr.id][a.id][1]
2177            if actual_size is None:
2178                continue
2179
2180            if a.size is None:
2181                continue
2182
2183            if isinstance(a.size, int):
2184                if actual_size != a.size:
2185                    raise ValueError(
2186                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2187                        + f"has incompatible size {actual_size}, expected {a.size}"
2188                    )
2189            elif isinstance(a.size, ParameterizedSize):
2190                _ = a.size.validate_size(actual_size)
2191            elif isinstance(a.size, DataDependentSize):
2192                _ = a.size.validate_size(actual_size)
2193            elif isinstance(a.size, SizeReference):
2194                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2195                if ref_tensor_axes is None:
2196                    raise ValueError(
2197                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2198                        + f" reference '{a.size.tensor_id}'"
2199                    )
2200
2201                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2202                if ref_axis is None or ref_size is None:
2203                    raise ValueError(
2204                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2205                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2206                    )
2207
2208                if a.unit != ref_axis.unit:
2209                    raise ValueError(
2210                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2211                        + " axis and reference axis to have the same `unit`, but"
2212                        + f" {a.unit}!={ref_axis.unit}"
2213                    )
2214
2215                if actual_size != (
2216                    expected_size := (
2217                        ref_size * ref_axis.scale / a.scale + a.size.offset
2218                    )
2219                ):
2220                    raise ValueError(
2221                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2222                        + f" {actual_size} invalid for referenced size {ref_size};"
2223                        + f" expected {expected_size}"
2224                    )
2225            else:
2226                assert_never(a.size)
FileDescr_dependencies = typing.Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2246class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2247    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2248    """Architecture source file"""
2249
2250    @model_serializer(mode="wrap", when_used="unless-none")
2251    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2252        return package_file_descr_serializer(self, nxt, info)

A file description

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>)]

Architecture source file

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2255class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2256    import_from: str
2257    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2317class WeightsEntryDescrBase(FileDescr):
2318    type: ClassVar[WeightsFormat]
2319    weights_format_name: ClassVar[str]  # human readable
2320
2321    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2322    """Source of the weights file."""
2323
2324    authors: Optional[List[Author]] = None
2325    """Authors
2326    Either the person(s) that have trained this model resulting in the original weights file.
2327        (If this is the initial weights entry, i.e. it does not have a `parent`)
2328    Or the person(s) who have converted the weights to this weights format.
2329        (If this is a child weight, i.e. it has a `parent` field)
2330    """
2331
2332    parent: Annotated[
2333        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2334    ] = None
2335    """The source weights these weights were converted from.
2336    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2337    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2338    All weight entries except one (the initial set of weights resulting from training the model),
2339    need to have this field."""
2340
2341    comment: str = ""
2342    """A comment about this weights entry, for example how these weights were created."""
2343
2344    @model_validator(mode="after")
2345    def _validate(self) -> Self:
2346        if self.type == self.parent:
2347            raise ValueError("Weights entry can't be it's own parent.")
2348
2349        return self
2350
2351    @model_serializer(mode="wrap", when_used="unless-none")
2352    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2353        return package_file_descr_serializer(self, nxt, info)

A file description

type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>)]

Source of the weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors Either the person(s) that have trained this model resulting in the original weights file. (If this is the initial weights entry, i.e. it does not have a parent) Or the person(s) who have converted the weights to this weights format. (If this is a child weight, i.e. it has a parent field)

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, The pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model), need to have this field.

comment: str

A comment about this weights entry, for example how these weights were created.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2356class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2357    type = "keras_hdf5"
2358    weights_format_name: ClassVar[str] = "Keras HDF5"
2359    tensorflow_version: Version
2360    """TensorFlow version used to create these weights."""

A file description

type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'

TensorFlow version used to create these weights.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class OnnxWeightsDescr(WeightsEntryDescrBase):
2363class OnnxWeightsDescr(WeightsEntryDescrBase):
2364    type = "onnx"
2365    weights_format_name: ClassVar[str] = "ONNX"
2366    opset_version: Annotated[int, Ge(7)]
2367    """ONNX opset version"""

A file description

type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2370class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2371    type = "pytorch_state_dict"
2372    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2373    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2374    pytorch_version: Version
2375    """Version of the PyTorch library used.
2376    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
2377    """
2378    dependencies: Optional[FileDescr_dependencies] = None
2379    """Custom depencies beyond pytorch described in a Conda environment file.
2380    Allows to specify custom dependencies, see conda docs:
2381    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2382    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2383
2384    The conda environment file should include pytorch and any version pinning has to be compatible with
2385    **pytorch_version**.
2386    """

A file description

type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'

Version of the PyTorch library used. If architecture.depencencies is specified it has to include pytorch and any version pinning has to be compatible.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>), WrapSerializer(func=<function package_file_descr_serializer at 0x7fe58e0b1c60>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom depencies beyond pytorch described in a Conda environment file. Allows to specify custom dependencies, see conda docs:

The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2389class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2390    type = "tensorflow_js"
2391    weights_format_name: ClassVar[str] = "Tensorflow.js"
2392    tensorflow_version: Version
2393    """Version of the TensorFlow library used."""
2394
2395    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2396    """The multi-file weights.
2397    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>)]

The multi-file weights. All required files/folders should be a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2400class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2401    type = "tensorflow_saved_model_bundle"
2402    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2403    tensorflow_version: Version
2404    """Version of the TensorFlow library used."""
2405
2406    dependencies: Optional[FileDescr_dependencies] = None
2407    """Custom dependencies beyond tensorflow.
2408    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2409
2410    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2411    """The multi-file weights.
2412    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'

Version of the TensorFlow library used.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>), WrapSerializer(func=<function package_file_descr_serializer at 0x7fe58e0b1c60>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>)]

The multi-file weights. All required files/folders should be a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2415class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2416    type = "torchscript"
2417    weights_format_name: ClassVar[str] = "TorchScript"
2418    pytorch_version: Version
2419    """Version of the PyTorch library used."""

A file description

type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'

Version of the PyTorch library used.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WeightsDescr(bioimageio.spec._internal.node.Node):
2422class WeightsDescr(Node):
2423    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2424    onnx: Optional[OnnxWeightsDescr] = None
2425    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2426    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2427    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2428        None
2429    )
2430    torchscript: Optional[TorchscriptWeightsDescr] = None
2431
2432    @model_validator(mode="after")
2433    def check_entries(self) -> Self:
2434        entries = {wtype for wtype, entry in self if entry is not None}
2435
2436        if not entries:
2437            raise ValueError("Missing weights entry")
2438
2439        entries_wo_parent = {
2440            wtype
2441            for wtype, entry in self
2442            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2443        }
2444        if len(entries_wo_parent) != 1:
2445            issue_warning(
2446                "Exactly one weights entry may not specify the `parent` field (got"
2447                + " {value}). That entry is considered the original set of model weights."
2448                + " Other weight formats are created through conversion of the orignal or"
2449                + " already converted weights. They have to reference the weights format"
2450                + " they were converted from as their `parent`.",
2451                value=len(entries_wo_parent),
2452                field="weights",
2453            )
2454
2455        for wtype, entry in self:
2456            if entry is None:
2457                continue
2458
2459            assert hasattr(entry, "type")
2460            assert hasattr(entry, "parent")
2461            assert wtype == entry.type
2462            if (
2463                entry.parent is not None and entry.parent not in entries
2464            ):  # self reference checked for `parent` field
2465                raise ValueError(
2466                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2467                    + f" formats: {entries}"
2468                )
2469
2470        return self
2471
2472    def __getitem__(
2473        self,
2474        key: Literal[
2475            "keras_hdf5",
2476            "onnx",
2477            "pytorch_state_dict",
2478            "tensorflow_js",
2479            "tensorflow_saved_model_bundle",
2480            "torchscript",
2481        ],
2482    ):
2483        if key == "keras_hdf5":
2484            ret = self.keras_hdf5
2485        elif key == "onnx":
2486            ret = self.onnx
2487        elif key == "pytorch_state_dict":
2488            ret = self.pytorch_state_dict
2489        elif key == "tensorflow_js":
2490            ret = self.tensorflow_js
2491        elif key == "tensorflow_saved_model_bundle":
2492            ret = self.tensorflow_saved_model_bundle
2493        elif key == "torchscript":
2494            ret = self.torchscript
2495        else:
2496            raise KeyError(key)
2497
2498        if ret is None:
2499            raise KeyError(key)
2500
2501        return ret
2502
2503    @property
2504    def available_formats(self):
2505        return {
2506            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2507            **({} if self.onnx is None else {"onnx": self.onnx}),
2508            **(
2509                {}
2510                if self.pytorch_state_dict is None
2511                else {"pytorch_state_dict": self.pytorch_state_dict}
2512            ),
2513            **(
2514                {}
2515                if self.tensorflow_js is None
2516                else {"tensorflow_js": self.tensorflow_js}
2517            ),
2518            **(
2519                {}
2520                if self.tensorflow_saved_model_bundle is None
2521                else {
2522                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2523                }
2524            ),
2525            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2526        }
2527
2528    @property
2529    def missing_formats(self):
2530        return {
2531            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2532        }
keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
2432    @model_validator(mode="after")
2433    def check_entries(self) -> Self:
2434        entries = {wtype for wtype, entry in self if entry is not None}
2435
2436        if not entries:
2437            raise ValueError("Missing weights entry")
2438
2439        entries_wo_parent = {
2440            wtype
2441            for wtype, entry in self
2442            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2443        }
2444        if len(entries_wo_parent) != 1:
2445            issue_warning(
2446                "Exactly one weights entry may not specify the `parent` field (got"
2447                + " {value}). That entry is considered the original set of model weights."
2448                + " Other weight formats are created through conversion of the orignal or"
2449                + " already converted weights. They have to reference the weights format"
2450                + " they were converted from as their `parent`.",
2451                value=len(entries_wo_parent),
2452                field="weights",
2453            )
2454
2455        for wtype, entry in self:
2456            if entry is None:
2457                continue
2458
2459            assert hasattr(entry, "type")
2460            assert hasattr(entry, "parent")
2461            assert wtype == entry.type
2462            if (
2463                entry.parent is not None and entry.parent not in entries
2464            ):  # self reference checked for `parent` field
2465                raise ValueError(
2466                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2467                    + f" formats: {entries}"
2468                )
2469
2470        return self
available_formats
2503    @property
2504    def available_formats(self):
2505        return {
2506            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2507            **({} if self.onnx is None else {"onnx": self.onnx}),
2508            **(
2509                {}
2510                if self.pytorch_state_dict is None
2511                else {"pytorch_state_dict": self.pytorch_state_dict}
2512            ),
2513            **(
2514                {}
2515                if self.tensorflow_js is None
2516                else {"tensorflow_js": self.tensorflow_js}
2517            ),
2518            **(
2519                {}
2520                if self.tensorflow_saved_model_bundle is None
2521                else {
2522                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2523                }
2524            ),
2525            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2526        }
missing_formats
2528    @property
2529    def missing_formats(self):
2530        return {
2531            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2532        }
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2535class ModelId(ResourceId):
2536    pass

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2539class LinkedModel(LinkedResourceBase):
2540    """Reference to a bioimage.io model."""
2541
2542    id: ModelId
2543    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ReproducibilityTolerance(bioimageio.spec._internal.node.Node):
2565class ReproducibilityTolerance(Node, extra="allow"):
2566    """Describes what small numerical differences -- if any -- may be tolerated
2567    in the generated output when executing in different environments.
2568
2569    A tensor element *output* is considered mismatched to the **test_tensor** if
2570    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2571    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2572
2573    Motivation:
2574        For testing we can request the respective deep learning frameworks to be as
2575        reproducible as possible by setting seeds and chosing deterministic algorithms,
2576        but differences in operating systems, available hardware and installed drivers
2577        may still lead to numerical differences.
2578    """
2579
2580    relative_tolerance: RelativeTolerance = 1e-3
2581    """Maximum relative tolerance of reproduced test tensor."""
2582
2583    absolute_tolerance: AbsoluteTolerance = 1e-4
2584    """Maximum absolute tolerance of reproduced test tensor."""
2585
2586    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2587    """Maximum number of mismatched elements/pixels per million to tolerate."""
2588
2589    output_ids: Sequence[TensorId] = ()
2590    """Limits the output tensor IDs these reproducibility details apply to."""
2591
2592    weights_formats: Sequence[WeightsFormat] = ()
2593    """Limits the weights formats these details apply to."""

Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.

A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)

Motivation:

For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and chosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.

relative_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=0.01)]

Maximum relative tolerance of reproduced test tensor.

absolute_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=None)]

Maximum absolute tolerance of reproduced test tensor.

mismatched_elements_per_million: Annotated[int, Interval(gt=None, ge=0, lt=None, le=1000)]

Maximum number of mismatched elements/pixels per million to tolerate.

output_ids: Sequence[TensorId]

Limits the output tensor IDs these reproducibility details apply to.

weights_formats: Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]

Limits the weights formats these details apply to.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BioimageioConfig(bioimageio.spec._internal.node.Node):
2596class BioimageioConfig(Node, extra="allow"):
2597    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2598    """Tolerances to allow when reproducing the model's test outputs
2599    from the model's test inputs.
2600    Only the first entry matching tensor id and weights format is considered.
2601    """
reproducibility_tolerance: Sequence[ReproducibilityTolerance]

Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Config(bioimageio.spec._internal.node.Node):
2604class Config(Node, extra="allow"):
2605    bioimageio: BioimageioConfig = Field(
2606        default_factory=BioimageioConfig.model_construct
2607    )
bioimageio: BioimageioConfig
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

2610class ModelDescr(GenericModelDescrBase):
2611    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2612    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2613    """
2614
2615    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2616    if TYPE_CHECKING:
2617        format_version: Literal["0.5.5"] = "0.5.5"
2618    else:
2619        format_version: Literal["0.5.5"]
2620        """Version of the bioimage.io model description specification used.
2621        When creating a new model always use the latest micro/patch version described here.
2622        The `format_version` is important for any consumer software to understand how to parse the fields.
2623        """
2624
2625    implemented_type: ClassVar[Literal["model"]] = "model"
2626    if TYPE_CHECKING:
2627        type: Literal["model"] = "model"
2628    else:
2629        type: Literal["model"]
2630        """Specialized resource type 'model'"""
2631
2632    id: Optional[ModelId] = None
2633    """bioimage.io-wide unique resource identifier
2634    assigned by bioimage.io; version **un**specific."""
2635
2636    authors: FAIR[List[Author]] = Field(
2637        default_factory=cast(Callable[[], List[Author]], list)
2638    )
2639    """The authors are the creators of the model RDF and the primary points of contact."""
2640
2641    documentation: FAIR[Optional[FileSource_documentation]] = None
2642    """URL or relative path to a markdown file with additional documentation.
2643    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2644    The documentation should include a '#[#] Validation' (sub)section
2645    with details on how to quantitatively validate the model on unseen data."""
2646
2647    @field_validator("documentation", mode="after")
2648    @classmethod
2649    def _validate_documentation(
2650        cls, value: Optional[FileSource_documentation]
2651    ) -> Optional[FileSource_documentation]:
2652        if not get_validation_context().perform_io_checks or value is None:
2653            return value
2654
2655        doc_reader = get_reader(value)
2656        doc_content = doc_reader.read().decode(encoding="utf-8")
2657        if not re.search("#.*[vV]alidation", doc_content):
2658            issue_warning(
2659                "No '# Validation' (sub)section found in {value}.",
2660                value=value,
2661                field="documentation",
2662            )
2663
2664        return value
2665
2666    inputs: NotEmpty[Sequence[InputTensorDescr]]
2667    """Describes the input tensors expected by this model."""
2668
2669    @field_validator("inputs", mode="after")
2670    @classmethod
2671    def _validate_input_axes(
2672        cls, inputs: Sequence[InputTensorDescr]
2673    ) -> Sequence[InputTensorDescr]:
2674        input_size_refs = cls._get_axes_with_independent_size(inputs)
2675
2676        for i, ipt in enumerate(inputs):
2677            valid_independent_refs: Dict[
2678                Tuple[TensorId, AxisId],
2679                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2680            ] = {
2681                **{
2682                    (ipt.id, a.id): (ipt, a, a.size)
2683                    for a in ipt.axes
2684                    if not isinstance(a, BatchAxis)
2685                    and isinstance(a.size, (int, ParameterizedSize))
2686                },
2687                **input_size_refs,
2688            }
2689            for a, ax in enumerate(ipt.axes):
2690                cls._validate_axis(
2691                    "inputs",
2692                    i=i,
2693                    tensor_id=ipt.id,
2694                    a=a,
2695                    axis=ax,
2696                    valid_independent_refs=valid_independent_refs,
2697                )
2698        return inputs
2699
2700    @staticmethod
2701    def _validate_axis(
2702        field_name: str,
2703        i: int,
2704        tensor_id: TensorId,
2705        a: int,
2706        axis: AnyAxis,
2707        valid_independent_refs: Dict[
2708            Tuple[TensorId, AxisId],
2709            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2710        ],
2711    ):
2712        if isinstance(axis, BatchAxis) or isinstance(
2713            axis.size, (int, ParameterizedSize, DataDependentSize)
2714        ):
2715            return
2716        elif not isinstance(axis.size, SizeReference):
2717            assert_never(axis.size)
2718
2719        # validate axis.size SizeReference
2720        ref = (axis.size.tensor_id, axis.size.axis_id)
2721        if ref not in valid_independent_refs:
2722            raise ValueError(
2723                "Invalid tensor axis reference at"
2724                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2725            )
2726        if ref == (tensor_id, axis.id):
2727            raise ValueError(
2728                "Self-referencing not allowed for"
2729                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2730            )
2731        if axis.type == "channel":
2732            if valid_independent_refs[ref][1].type != "channel":
2733                raise ValueError(
2734                    "A channel axis' size may only reference another fixed size"
2735                    + " channel axis."
2736                )
2737            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2738                ref_size = valid_independent_refs[ref][2]
2739                assert isinstance(ref_size, int), (
2740                    "channel axis ref (another channel axis) has to specify fixed"
2741                    + " size"
2742                )
2743                generated_channel_names = [
2744                    Identifier(axis.channel_names.format(i=i))
2745                    for i in range(1, ref_size + 1)
2746                ]
2747                axis.channel_names = generated_channel_names
2748
2749        if (ax_unit := getattr(axis, "unit", None)) != (
2750            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2751        ):
2752            raise ValueError(
2753                "The units of an axis and its reference axis need to match, but"
2754                + f" '{ax_unit}' != '{ref_unit}'."
2755            )
2756        ref_axis = valid_independent_refs[ref][1]
2757        if isinstance(ref_axis, BatchAxis):
2758            raise ValueError(
2759                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2760                + " (a batch axis is not allowed as reference)."
2761            )
2762
2763        if isinstance(axis, WithHalo):
2764            min_size = axis.size.get_size(axis, ref_axis, n=0)
2765            if (min_size - 2 * axis.halo) < 1:
2766                raise ValueError(
2767                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2768                    + f" {axis.halo}."
2769                )
2770
2771            input_halo = axis.halo * axis.scale / ref_axis.scale
2772            if input_halo != int(input_halo) or input_halo % 2 == 1:
2773                raise ValueError(
2774                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2775                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2776                    + f"     {tensor_id}.{axis.id}."
2777                )
2778
2779    @model_validator(mode="after")
2780    def _validate_test_tensors(self) -> Self:
2781        if not get_validation_context().perform_io_checks:
2782            return self
2783
2784        test_output_arrays = [
2785            None if descr.test_tensor is None else load_array(descr.test_tensor)
2786            for descr in self.outputs
2787        ]
2788        test_input_arrays = [
2789            None if descr.test_tensor is None else load_array(descr.test_tensor)
2790            for descr in self.inputs
2791        ]
2792
2793        tensors = {
2794            descr.id: (descr, array)
2795            for descr, array in zip(
2796                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2797            )
2798        }
2799        validate_tensors(tensors, tensor_origin="test_tensor")
2800
2801        output_arrays = {
2802            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2803        }
2804        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2805            if not rep_tol.absolute_tolerance:
2806                continue
2807
2808            if rep_tol.output_ids:
2809                out_arrays = {
2810                    oid: a
2811                    for oid, a in output_arrays.items()
2812                    if oid in rep_tol.output_ids
2813                }
2814            else:
2815                out_arrays = output_arrays
2816
2817            for out_id, array in out_arrays.items():
2818                if array is None:
2819                    continue
2820
2821                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2822                    raise ValueError(
2823                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2824                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2825                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2826                    )
2827
2828        return self
2829
2830    @model_validator(mode="after")
2831    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2832        ipt_refs = {t.id for t in self.inputs}
2833        out_refs = {t.id for t in self.outputs}
2834        for ipt in self.inputs:
2835            for p in ipt.preprocessing:
2836                ref = p.kwargs.get("reference_tensor")
2837                if ref is None:
2838                    continue
2839                if ref not in ipt_refs:
2840                    raise ValueError(
2841                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2842                        + f" references are: {ipt_refs}."
2843                    )
2844
2845        for out in self.outputs:
2846            for p in out.postprocessing:
2847                ref = p.kwargs.get("reference_tensor")
2848                if ref is None:
2849                    continue
2850
2851                if ref not in ipt_refs and ref not in out_refs:
2852                    raise ValueError(
2853                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2854                        + f" are: {ipt_refs | out_refs}."
2855                    )
2856
2857        return self
2858
2859    # TODO: use validate funcs in validate_test_tensors
2860    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2861
2862    name: Annotated[
2863        str,
2864        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2865        MinLen(5),
2866        MaxLen(128),
2867        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2868    ]
2869    """A human-readable name of this model.
2870    It should be no longer than 64 characters
2871    and may only contain letter, number, underscore, minus, parentheses and spaces.
2872    We recommend to chose a name that refers to the model's task and image modality.
2873    """
2874
2875    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2876    """Describes the output tensors."""
2877
2878    @field_validator("outputs", mode="after")
2879    @classmethod
2880    def _validate_tensor_ids(
2881        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2882    ) -> Sequence[OutputTensorDescr]:
2883        tensor_ids = [
2884            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2885        ]
2886        duplicate_tensor_ids: List[str] = []
2887        seen: Set[str] = set()
2888        for t in tensor_ids:
2889            if t in seen:
2890                duplicate_tensor_ids.append(t)
2891
2892            seen.add(t)
2893
2894        if duplicate_tensor_ids:
2895            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2896
2897        return outputs
2898
2899    @staticmethod
2900    def _get_axes_with_parameterized_size(
2901        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2902    ):
2903        return {
2904            f"{t.id}.{a.id}": (t, a, a.size)
2905            for t in io
2906            for a in t.axes
2907            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2908        }
2909
2910    @staticmethod
2911    def _get_axes_with_independent_size(
2912        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2913    ):
2914        return {
2915            (t.id, a.id): (t, a, a.size)
2916            for t in io
2917            for a in t.axes
2918            if not isinstance(a, BatchAxis)
2919            and isinstance(a.size, (int, ParameterizedSize))
2920        }
2921
2922    @field_validator("outputs", mode="after")
2923    @classmethod
2924    def _validate_output_axes(
2925        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2926    ) -> List[OutputTensorDescr]:
2927        input_size_refs = cls._get_axes_with_independent_size(
2928            info.data.get("inputs", [])
2929        )
2930        output_size_refs = cls._get_axes_with_independent_size(outputs)
2931
2932        for i, out in enumerate(outputs):
2933            valid_independent_refs: Dict[
2934                Tuple[TensorId, AxisId],
2935                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2936            ] = {
2937                **{
2938                    (out.id, a.id): (out, a, a.size)
2939                    for a in out.axes
2940                    if not isinstance(a, BatchAxis)
2941                    and isinstance(a.size, (int, ParameterizedSize))
2942                },
2943                **input_size_refs,
2944                **output_size_refs,
2945            }
2946            for a, ax in enumerate(out.axes):
2947                cls._validate_axis(
2948                    "outputs",
2949                    i,
2950                    out.id,
2951                    a,
2952                    ax,
2953                    valid_independent_refs=valid_independent_refs,
2954                )
2955
2956        return outputs
2957
2958    packaged_by: List[Author] = Field(
2959        default_factory=cast(Callable[[], List[Author]], list)
2960    )
2961    """The persons that have packaged and uploaded this model.
2962    Only required if those persons differ from the `authors`."""
2963
2964    parent: Optional[LinkedModel] = None
2965    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2966
2967    @model_validator(mode="after")
2968    def _validate_parent_is_not_self(self) -> Self:
2969        if self.parent is not None and self.parent.id == self.id:
2970            raise ValueError("A model description may not reference itself as parent.")
2971
2972        return self
2973
2974    run_mode: Annotated[
2975        Optional[RunMode],
2976        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2977    ] = None
2978    """Custom run mode for this model: for more complex prediction procedures like test time
2979    data augmentation that currently cannot be expressed in the specification.
2980    No standard run modes are defined yet."""
2981
2982    timestamp: Datetime = Field(default_factory=Datetime.now)
2983    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2984    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2985    (In Python a datetime object is valid, too)."""
2986
2987    training_data: Annotated[
2988        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2989        Field(union_mode="left_to_right"),
2990    ] = None
2991    """The dataset used to train this model"""
2992
2993    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2994    """The weights for this model.
2995    Weights can be given for different formats, but should otherwise be equivalent.
2996    The available weight formats determine which consumers can use this model."""
2997
2998    config: Config = Field(default_factory=Config.model_construct)
2999
3000    @model_validator(mode="after")
3001    def _add_default_cover(self) -> Self:
3002        if not get_validation_context().perform_io_checks or self.covers:
3003            return self
3004
3005        try:
3006            generated_covers = generate_covers(
3007                [
3008                    (t, load_array(t.test_tensor))
3009                    for t in self.inputs
3010                    if t.test_tensor is not None
3011                ],
3012                [
3013                    (t, load_array(t.test_tensor))
3014                    for t in self.outputs
3015                    if t.test_tensor is not None
3016                ],
3017            )
3018        except Exception as e:
3019            issue_warning(
3020                "Failed to generate cover image(s): {e}",
3021                value=self.covers,
3022                msg_context=dict(e=e),
3023                field="covers",
3024            )
3025        else:
3026            self.covers.extend(generated_covers)
3027
3028        return self
3029
3030    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3031        return self._get_test_arrays(self.inputs)
3032
3033    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3034        return self._get_test_arrays(self.outputs)
3035
3036    @staticmethod
3037    def _get_test_arrays(
3038        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3039    ):
3040        ts: List[FileDescr] = []
3041        for d in io_descr:
3042            if d.test_tensor is None:
3043                raise ValueError(
3044                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3045                )
3046            ts.append(d.test_tensor)
3047
3048        data = [load_array(t) for t in ts]
3049        assert all(isinstance(d, np.ndarray) for d in data)
3050        return data
3051
3052    @staticmethod
3053    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3054        batch_size = 1
3055        tensor_with_batchsize: Optional[TensorId] = None
3056        for tid in tensor_sizes:
3057            for aid, s in tensor_sizes[tid].items():
3058                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3059                    continue
3060
3061                if batch_size != 1:
3062                    assert tensor_with_batchsize is not None
3063                    raise ValueError(
3064                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3065                    )
3066
3067                batch_size = s
3068                tensor_with_batchsize = tid
3069
3070        return batch_size
3071
3072    def get_output_tensor_sizes(
3073        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3074    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3075        """Returns the tensor output sizes for given **input_sizes**.
3076        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
3077        Otherwise it might be larger than the actual (valid) output"""
3078        batch_size = self.get_batch_size(input_sizes)
3079        ns = self.get_ns(input_sizes)
3080
3081        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3082        return tensor_sizes.outputs
3083
3084    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3085        """get parameter `n` for each parameterized axis
3086        such that the valid input size is >= the given input size"""
3087        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3088        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3089        for tid in input_sizes:
3090            for aid, s in input_sizes[tid].items():
3091                size_descr = axes[tid][aid].size
3092                if isinstance(size_descr, ParameterizedSize):
3093                    ret[(tid, aid)] = size_descr.get_n(s)
3094                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3095                    pass
3096                else:
3097                    assert_never(size_descr)
3098
3099        return ret
3100
3101    def get_tensor_sizes(
3102        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3103    ) -> _TensorSizes:
3104        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3105        return _TensorSizes(
3106            {
3107                t: {
3108                    aa: axis_sizes.inputs[(tt, aa)]
3109                    for tt, aa in axis_sizes.inputs
3110                    if tt == t
3111                }
3112                for t in {tt for tt, _ in axis_sizes.inputs}
3113            },
3114            {
3115                t: {
3116                    aa: axis_sizes.outputs[(tt, aa)]
3117                    for tt, aa in axis_sizes.outputs
3118                    if tt == t
3119                }
3120                for t in {tt for tt, _ in axis_sizes.outputs}
3121            },
3122        )
3123
3124    def get_axis_sizes(
3125        self,
3126        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3127        batch_size: Optional[int] = None,
3128        *,
3129        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3130    ) -> _AxisSizes:
3131        """Determine input and output block shape for scale factors **ns**
3132        of parameterized input sizes.
3133
3134        Args:
3135            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3136                that is parameterized as `size = min + n * step`.
3137            batch_size: The desired size of the batch dimension.
3138                If given **batch_size** overwrites any batch size present in
3139                **max_input_shape**. Default 1.
3140            max_input_shape: Limits the derived block shapes.
3141                Each axis for which the input size, parameterized by `n`, is larger
3142                than **max_input_shape** is set to the minimal value `n_min` for which
3143                this is still true.
3144                Use this for small input samples or large values of **ns**.
3145                Or simply whenever you know the full input shape.
3146
3147        Returns:
3148            Resolved axis sizes for model inputs and outputs.
3149        """
3150        max_input_shape = max_input_shape or {}
3151        if batch_size is None:
3152            for (_t_id, a_id), s in max_input_shape.items():
3153                if a_id == BATCH_AXIS_ID:
3154                    batch_size = s
3155                    break
3156            else:
3157                batch_size = 1
3158
3159        all_axes = {
3160            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3161        }
3162
3163        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3164        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3165
3166        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3167            if isinstance(a, BatchAxis):
3168                if (t_descr.id, a.id) in ns:
3169                    logger.warning(
3170                        "Ignoring unexpected size increment factor (n) for batch axis"
3171                        + " of tensor '{}'.",
3172                        t_descr.id,
3173                    )
3174                return batch_size
3175            elif isinstance(a.size, int):
3176                if (t_descr.id, a.id) in ns:
3177                    logger.warning(
3178                        "Ignoring unexpected size increment factor (n) for fixed size"
3179                        + " axis '{}' of tensor '{}'.",
3180                        a.id,
3181                        t_descr.id,
3182                    )
3183                return a.size
3184            elif isinstance(a.size, ParameterizedSize):
3185                if (t_descr.id, a.id) not in ns:
3186                    raise ValueError(
3187                        "Size increment factor (n) missing for parametrized axis"
3188                        + f" '{a.id}' of tensor '{t_descr.id}'."
3189                    )
3190                n = ns[(t_descr.id, a.id)]
3191                s_max = max_input_shape.get((t_descr.id, a.id))
3192                if s_max is not None:
3193                    n = min(n, a.size.get_n(s_max))
3194
3195                return a.size.get_size(n)
3196
3197            elif isinstance(a.size, SizeReference):
3198                if (t_descr.id, a.id) in ns:
3199                    logger.warning(
3200                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3201                        + " of tensor '{}' with size reference.",
3202                        a.id,
3203                        t_descr.id,
3204                    )
3205                assert not isinstance(a, BatchAxis)
3206                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3207                assert not isinstance(ref_axis, BatchAxis)
3208                ref_key = (a.size.tensor_id, a.size.axis_id)
3209                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3210                assert ref_size is not None, ref_key
3211                assert not isinstance(ref_size, _DataDepSize), ref_key
3212                return a.size.get_size(
3213                    axis=a,
3214                    ref_axis=ref_axis,
3215                    ref_size=ref_size,
3216                )
3217            elif isinstance(a.size, DataDependentSize):
3218                if (t_descr.id, a.id) in ns:
3219                    logger.warning(
3220                        "Ignoring unexpected increment factor (n) for data dependent"
3221                        + " size axis '{}' of tensor '{}'.",
3222                        a.id,
3223                        t_descr.id,
3224                    )
3225                return _DataDepSize(a.size.min, a.size.max)
3226            else:
3227                assert_never(a.size)
3228
3229        # first resolve all , but the `SizeReference` input sizes
3230        for t_descr in self.inputs:
3231            for a in t_descr.axes:
3232                if not isinstance(a.size, SizeReference):
3233                    s = get_axis_size(a)
3234                    assert not isinstance(s, _DataDepSize)
3235                    inputs[t_descr.id, a.id] = s
3236
3237        # resolve all other input axis sizes
3238        for t_descr in self.inputs:
3239            for a in t_descr.axes:
3240                if isinstance(a.size, SizeReference):
3241                    s = get_axis_size(a)
3242                    assert not isinstance(s, _DataDepSize)
3243                    inputs[t_descr.id, a.id] = s
3244
3245        # resolve all output axis sizes
3246        for t_descr in self.outputs:
3247            for a in t_descr.axes:
3248                assert not isinstance(a.size, ParameterizedSize)
3249                s = get_axis_size(a)
3250                outputs[t_descr.id, a.id] = s
3251
3252        return _AxisSizes(inputs=inputs, outputs=outputs)
3253
3254    @model_validator(mode="before")
3255    @classmethod
3256    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3257        cls.convert_from_old_format_wo_validation(data)
3258        return data
3259
3260    @classmethod
3261    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3262        """Convert metadata following an older format version to this classes' format
3263        without validating the result.
3264        """
3265        if (
3266            data.get("type") == "model"
3267            and isinstance(fv := data.get("format_version"), str)
3268            and fv.count(".") == 2
3269        ):
3270            fv_parts = fv.split(".")
3271            if any(not p.isdigit() for p in fv_parts):
3272                return
3273
3274            fv_tuple = tuple(map(int, fv_parts))
3275
3276            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3277            if fv_tuple[:2] in ((0, 3), (0, 4)):
3278                m04 = _ModelDescr_v0_4.load(data)
3279                if isinstance(m04, InvalidDescr):
3280                    try:
3281                        updated = _model_conv.convert_as_dict(
3282                            m04  # pyright: ignore[reportArgumentType]
3283                        )
3284                    except Exception as e:
3285                        logger.error(
3286                            "Failed to convert from invalid model 0.4 description."
3287                            + f"\nerror: {e}"
3288                            + "\nProceeding with model 0.5 validation without conversion."
3289                        )
3290                        updated = None
3291                else:
3292                    updated = _model_conv.convert_as_dict(m04)
3293
3294                if updated is not None:
3295                    data.clear()
3296                    data.update(updated)
3297
3298            elif fv_tuple[:2] == (0, 5):
3299                # bump patch version
3300                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.5']] = '0.5.5'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fe58e0ccfe0>, severity=35, msg=None, context=None)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Optional[Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fe59c528d60>), PlainSerializer(func=<function _package_serializer at 0x7fe58e0b3060>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fe58e0ccfe0>, severity=35, msg=None, context=None)]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fe58b1d7600>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letter, number, underscore, minus, parentheses and spaces. We recommend to chose a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fe58b1d6480>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

Timestamp in ISO 8601 format with a few restrictions listed here. (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights at 0x7fe58dd0e480>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: Config
def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
3030    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3031        return self._get_test_arrays(self.inputs)
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
3033    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3034        return self._get_test_arrays(self.outputs)
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3052    @staticmethod
3053    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3054        batch_size = 1
3055        tensor_with_batchsize: Optional[TensorId] = None
3056        for tid in tensor_sizes:
3057            for aid, s in tensor_sizes[tid].items():
3058                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3059                    continue
3060
3061                if batch_size != 1:
3062                    assert tensor_with_batchsize is not None
3063                    raise ValueError(
3064                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3065                    )
3066
3067                batch_size = s
3068                tensor_with_batchsize = tid
3069
3070        return batch_size
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
3072    def get_output_tensor_sizes(
3073        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3074    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3075        """Returns the tensor output sizes for given **input_sizes**.
3076        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
3077        Otherwise it might be larger than the actual (valid) output"""
3078        batch_size = self.get_batch_size(input_sizes)
3079        ns = self.get_ns(input_sizes)
3080
3081        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3082        return tensor_sizes.outputs

Returns the tensor output sizes for given input_sizes. Only if input_sizes has a valid input shape, the tensor output size is exact. Otherwise it might be larger than the actual (valid) output

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3084    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3085        """get parameter `n` for each parameterized axis
3086        such that the valid input size is >= the given input size"""
3087        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3088        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3089        for tid in input_sizes:
3090            for aid, s in input_sizes[tid].items():
3091                size_descr = axes[tid][aid].size
3092                if isinstance(size_descr, ParameterizedSize):
3093                    ret[(tid, aid)] = size_descr.get_n(s)
3094                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3095                    pass
3096                else:
3097                    assert_never(size_descr)
3098
3099        return ret

get parameter n for each parameterized axis such that the valid input size is >= the given input size

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
3101    def get_tensor_sizes(
3102        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3103    ) -> _TensorSizes:
3104        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3105        return _TensorSizes(
3106            {
3107                t: {
3108                    aa: axis_sizes.inputs[(tt, aa)]
3109                    for tt, aa in axis_sizes.inputs
3110                    if tt == t
3111                }
3112                for t in {tt for tt, _ in axis_sizes.inputs}
3113            },
3114            {
3115                t: {
3116                    aa: axis_sizes.outputs[(tt, aa)]
3117                    for tt, aa in axis_sizes.outputs
3118                    if tt == t
3119                }
3120                for t in {tt for tt, _ in axis_sizes.outputs}
3121            },
3122        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
3124    def get_axis_sizes(
3125        self,
3126        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3127        batch_size: Optional[int] = None,
3128        *,
3129        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3130    ) -> _AxisSizes:
3131        """Determine input and output block shape for scale factors **ns**
3132        of parameterized input sizes.
3133
3134        Args:
3135            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3136                that is parameterized as `size = min + n * step`.
3137            batch_size: The desired size of the batch dimension.
3138                If given **batch_size** overwrites any batch size present in
3139                **max_input_shape**. Default 1.
3140            max_input_shape: Limits the derived block shapes.
3141                Each axis for which the input size, parameterized by `n`, is larger
3142                than **max_input_shape** is set to the minimal value `n_min` for which
3143                this is still true.
3144                Use this for small input samples or large values of **ns**.
3145                Or simply whenever you know the full input shape.
3146
3147        Returns:
3148            Resolved axis sizes for model inputs and outputs.
3149        """
3150        max_input_shape = max_input_shape or {}
3151        if batch_size is None:
3152            for (_t_id, a_id), s in max_input_shape.items():
3153                if a_id == BATCH_AXIS_ID:
3154                    batch_size = s
3155                    break
3156            else:
3157                batch_size = 1
3158
3159        all_axes = {
3160            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3161        }
3162
3163        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3164        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3165
3166        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3167            if isinstance(a, BatchAxis):
3168                if (t_descr.id, a.id) in ns:
3169                    logger.warning(
3170                        "Ignoring unexpected size increment factor (n) for batch axis"
3171                        + " of tensor '{}'.",
3172                        t_descr.id,
3173                    )
3174                return batch_size
3175            elif isinstance(a.size, int):
3176                if (t_descr.id, a.id) in ns:
3177                    logger.warning(
3178                        "Ignoring unexpected size increment factor (n) for fixed size"
3179                        + " axis '{}' of tensor '{}'.",
3180                        a.id,
3181                        t_descr.id,
3182                    )
3183                return a.size
3184            elif isinstance(a.size, ParameterizedSize):
3185                if (t_descr.id, a.id) not in ns:
3186                    raise ValueError(
3187                        "Size increment factor (n) missing for parametrized axis"
3188                        + f" '{a.id}' of tensor '{t_descr.id}'."
3189                    )
3190                n = ns[(t_descr.id, a.id)]
3191                s_max = max_input_shape.get((t_descr.id, a.id))
3192                if s_max is not None:
3193                    n = min(n, a.size.get_n(s_max))
3194
3195                return a.size.get_size(n)
3196
3197            elif isinstance(a.size, SizeReference):
3198                if (t_descr.id, a.id) in ns:
3199                    logger.warning(
3200                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3201                        + " of tensor '{}' with size reference.",
3202                        a.id,
3203                        t_descr.id,
3204                    )
3205                assert not isinstance(a, BatchAxis)
3206                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3207                assert not isinstance(ref_axis, BatchAxis)
3208                ref_key = (a.size.tensor_id, a.size.axis_id)
3209                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3210                assert ref_size is not None, ref_key
3211                assert not isinstance(ref_size, _DataDepSize), ref_key
3212                return a.size.get_size(
3213                    axis=a,
3214                    ref_axis=ref_axis,
3215                    ref_size=ref_size,
3216                )
3217            elif isinstance(a.size, DataDependentSize):
3218                if (t_descr.id, a.id) in ns:
3219                    logger.warning(
3220                        "Ignoring unexpected increment factor (n) for data dependent"
3221                        + " size axis '{}' of tensor '{}'.",
3222                        a.id,
3223                        t_descr.id,
3224                    )
3225                return _DataDepSize(a.size.min, a.size.max)
3226            else:
3227                assert_never(a.size)
3228
3229        # first resolve all , but the `SizeReference` input sizes
3230        for t_descr in self.inputs:
3231            for a in t_descr.axes:
3232                if not isinstance(a.size, SizeReference):
3233                    s = get_axis_size(a)
3234                    assert not isinstance(s, _DataDepSize)
3235                    inputs[t_descr.id, a.id] = s
3236
3237        # resolve all other input axis sizes
3238        for t_descr in self.inputs:
3239            for a in t_descr.axes:
3240                if isinstance(a.size, SizeReference):
3241                    s = get_axis_size(a)
3242                    assert not isinstance(s, _DataDepSize)
3243                    inputs[t_descr.id, a.id] = s
3244
3245        # resolve all output axis sizes
3246        for t_descr in self.outputs:
3247            for a in t_descr.axes:
3248                assert not isinstance(a.size, ParameterizedSize)
3249                s = get_axis_size(a)
3250                outputs[t_descr.id, a.id] = s
3251
3252        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3260    @classmethod
3261    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3262        """Convert metadata following an older format version to this classes' format
3263        without validating the result.
3264        """
3265        if (
3266            data.get("type") == "model"
3267            and isinstance(fv := data.get("format_version"), str)
3268            and fv.count(".") == 2
3269        ):
3270            fv_parts = fv.split(".")
3271            if any(not p.isdigit() for p in fv_parts):
3272                return
3273
3274            fv_tuple = tuple(map(int, fv_parts))
3275
3276            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3277            if fv_tuple[:2] in ((0, 3), (0, 4)):
3278                m04 = _ModelDescr_v0_4.load(data)
3279                if isinstance(m04, InvalidDescr):
3280                    try:
3281                        updated = _model_conv.convert_as_dict(
3282                            m04  # pyright: ignore[reportArgumentType]
3283                        )
3284                    except Exception as e:
3285                        logger.error(
3286                            "Failed to convert from invalid model 0.4 description."
3287                            + f"\nerror: {e}"
3288                            + "\nProceeding with model 0.5 validation without conversion."
3289                        )
3290                        updated = None
3291                else:
3292                    updated = _model_conv.convert_as_dict(m04)
3293
3294                if updated is not None:
3295                    data.clear()
3296                    data.update(updated)
3297
3298            elif fv_tuple[:2] == (0, 5):
3299                # bump patch version
3300                data["format_version"] = cls.implemented_format_version

Convert metadata following an older format version to this classes' format without validating the result.

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 5)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
337def init_private_attributes(self: BaseModel, context: Any, /) -> None:
338    """This function is meant to behave like a BaseModel method to initialise private attributes.
339
340    It takes context as an argument since that's what pydantic-core passes when calling it.
341
342    Args:
343        self: The BaseModel instance.
344        context: The context.
345    """
346    if getattr(self, '__pydantic_private__', None) is None:
347        pydantic_private = {}
348        for name, private_attr in self.__private_attributes__.items():
349            default = private_attr.get_default()
350            if default is not PydanticUndefined:
351                pydantic_private[name] = default
352        object_setattr(self, '__pydantic_private__', pydantic_private)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3517def generate_covers(
3518    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3519    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3520) -> List[Path]:
3521    def squeeze(
3522        data: NDArray[Any], axes: Sequence[AnyAxis]
3523    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3524        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3525        if data.ndim != len(axes):
3526            raise ValueError(
3527                f"tensor shape {data.shape} does not match described axes"
3528                + f" {[a.id for a in axes]}"
3529            )
3530
3531        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3532        return data.squeeze(), axes
3533
3534    def normalize(
3535        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3536    ) -> NDArray[np.float32]:
3537        data = data.astype("float32")
3538        data -= data.min(axis=axis, keepdims=True)
3539        data /= data.max(axis=axis, keepdims=True) + eps
3540        return data
3541
3542    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3543        original_shape = data.shape
3544        data, axes = squeeze(data, axes)
3545
3546        # take slice fom any batch or index axis if needed
3547        # and convert the first channel axis and take a slice from any additional channel axes
3548        slices: Tuple[slice, ...] = ()
3549        ndim = data.ndim
3550        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3551        has_c_axis = False
3552        for i, a in enumerate(axes):
3553            s = data.shape[i]
3554            assert s > 1
3555            if (
3556                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3557                and ndim > ndim_need
3558            ):
3559                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3560                ndim -= 1
3561            elif isinstance(a, ChannelAxis):
3562                if has_c_axis:
3563                    # second channel axis
3564                    data = data[slices + (slice(0, 1),)]
3565                    ndim -= 1
3566                else:
3567                    has_c_axis = True
3568                    if s == 2:
3569                        # visualize two channels with cyan and magenta
3570                        data = np.concatenate(
3571                            [
3572                                data[slices + (slice(1, 2),)],
3573                                data[slices + (slice(0, 1),)],
3574                                (
3575                                    data[slices + (slice(0, 1),)]
3576                                    + data[slices + (slice(1, 2),)]
3577                                )
3578                                / 2,  # TODO: take maximum instead?
3579                            ],
3580                            axis=i,
3581                        )
3582                    elif data.shape[i] == 3:
3583                        pass  # visualize 3 channels as RGB
3584                    else:
3585                        # visualize first 3 channels as RGB
3586                        data = data[slices + (slice(3),)]
3587
3588                    assert data.shape[i] == 3
3589
3590            slices += (slice(None),)
3591
3592        data, axes = squeeze(data, axes)
3593        assert len(axes) == ndim
3594        # take slice from z axis if needed
3595        slices = ()
3596        if ndim > ndim_need:
3597            for i, a in enumerate(axes):
3598                s = data.shape[i]
3599                if a.id == AxisId("z"):
3600                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3601                    data, axes = squeeze(data, axes)
3602                    ndim -= 1
3603                    break
3604
3605            slices += (slice(None),)
3606
3607        # take slice from any space or time axis
3608        slices = ()
3609
3610        for i, a in enumerate(axes):
3611            if ndim <= ndim_need:
3612                break
3613
3614            s = data.shape[i]
3615            assert s > 1
3616            if isinstance(
3617                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3618            ):
3619                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3620                ndim -= 1
3621
3622            slices += (slice(None),)
3623
3624        del slices
3625        data, axes = squeeze(data, axes)
3626        assert len(axes) == ndim
3627
3628        if (has_c_axis and ndim != 3) or ndim != 2:
3629            raise ValueError(
3630                f"Failed to construct cover image from shape {original_shape}"
3631            )
3632
3633        if not has_c_axis:
3634            assert ndim == 2
3635            data = np.repeat(data[:, :, None], 3, axis=2)
3636            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3637            ndim += 1
3638
3639        assert ndim == 3
3640
3641        # transpose axis order such that longest axis comes first...
3642        axis_order: List[int] = list(np.argsort(list(data.shape)))
3643        axis_order.reverse()
3644        # ... and channel axis is last
3645        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3646        axis_order.append(axis_order.pop(c))
3647        axes = [axes[ao] for ao in axis_order]
3648        data = data.transpose(axis_order)
3649
3650        # h, w = data.shape[:2]
3651        # if h / w  in (1.0 or 2.0):
3652        #     pass
3653        # elif h / w < 2:
3654        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3655
3656        norm_along = (
3657            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3658        )
3659        # normalize the data and map to 8 bit
3660        data = normalize(data, norm_along)
3661        data = (data * 255).astype("uint8")
3662
3663        return data
3664
3665    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3666        assert im0.dtype == im1.dtype == np.uint8
3667        assert im0.shape == im1.shape
3668        assert im0.ndim == 3
3669        N, M, C = im0.shape
3670        assert C == 3
3671        out = np.ones((N, M, C), dtype="uint8")
3672        for c in range(C):
3673            outc = np.tril(im0[..., c])
3674            mask = outc == 0
3675            outc[mask] = np.triu(im1[..., c])[mask]
3676            out[..., c] = outc
3677
3678        return out
3679
3680    if not inputs:
3681        raise ValueError("Missing test input tensor for cover generation.")
3682
3683    if not outputs:
3684        raise ValueError("Missing test output tensor for cover generation.")
3685
3686    ipt_descr, ipt = inputs[0]
3687    out_descr, out = outputs[0]
3688
3689    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3690    out_img = to_2d_image(out, out_descr.axes)
3691
3692    cover_folder = Path(mkdtemp())
3693    if ipt_img.shape == out_img.shape:
3694        covers = [cover_folder / "cover.png"]
3695        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3696    else:
3697        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3698        imwrite(covers[0], ipt_img)
3699        imwrite(covers[1], out_img)
3700
3701    return covers
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):