bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from itertools import chain
  10from math import ceil
  11from pathlib import Path, PurePosixPath
  12from tempfile import mkdtemp
  13from typing import (
  14    TYPE_CHECKING,
  15    Any,
  16    Callable,
  17    ClassVar,
  18    Dict,
  19    Generic,
  20    List,
  21    Literal,
  22    Mapping,
  23    NamedTuple,
  24    Optional,
  25    Sequence,
  26    Set,
  27    Tuple,
  28    Type,
  29    TypeVar,
  30    Union,
  31    cast,
  32)
  33
  34import numpy as np
  35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  36from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  37from loguru import logger
  38from numpy.typing import NDArray
  39from pydantic import (
  40    AfterValidator,
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    SerializationInfo,
  45    SerializerFunctionWrapHandler,
  46    StrictInt,
  47    Tag,
  48    ValidationInfo,
  49    WrapSerializer,
  50    field_validator,
  51    model_serializer,
  52    model_validator,
  53)
  54from typing_extensions import Annotated, Self, assert_never, get_args
  55
  56from .._internal.common_nodes import (
  57    InvalidDescr,
  58    Node,
  59    NodeWithExplicitlySetFields,
  60)
  61from .._internal.constants import DTYPE_LIMITS
  62from .._internal.field_warning import issue_warning, warn
  63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  64from .._internal.io import FileDescr as FileDescr
  65from .._internal.io import (
  66    FileSource,
  67    WithSuffix,
  68    YamlValue,
  69    get_reader,
  70    wo_special_file_name,
  71)
  72from .._internal.io_basics import Sha256 as Sha256
  73from .._internal.io_packaging import (
  74    FileDescr_,
  75    FileSource_,
  76    package_file_descr_serializer,
  77)
  78from .._internal.io_utils import load_array
  79from .._internal.node_converter import Converter
  80from .._internal.type_guards import is_dict, is_sequence
  81from .._internal.types import (
  82    FAIR,
  83    AbsoluteTolerance,
  84    LowerCaseIdentifier,
  85    LowerCaseIdentifierAnno,
  86    MismatchedElementsPerMillion,
  87    RelativeTolerance,
  88)
  89from .._internal.types import Datetime as Datetime
  90from .._internal.types import Identifier as Identifier
  91from .._internal.types import NotEmpty as NotEmpty
  92from .._internal.types import SiUnit as SiUnit
  93from .._internal.url import HttpUrl as HttpUrl
  94from .._internal.validation_context import get_validation_context
  95from .._internal.validator_annotations import RestrictCharacters
  96from .._internal.version_type import Version as Version
  97from .._internal.warning_levels import INFO
  98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
 100from ..dataset.v0_3 import DatasetDescr as DatasetDescr
 101from ..dataset.v0_3 import DatasetId as DatasetId
 102from ..dataset.v0_3 import LinkedDataset as LinkedDataset
 103from ..dataset.v0_3 import Uploader as Uploader
 104from ..generic.v0_3 import (
 105    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
 106)
 107from ..generic.v0_3 import Author as Author
 108from ..generic.v0_3 import BadgeDescr as BadgeDescr
 109from ..generic.v0_3 import CiteEntry as CiteEntry
 110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
 111from ..generic.v0_3 import Doi as Doi
 112from ..generic.v0_3 import (
 113    FileSource_documentation,
 114    GenericModelDescrBase,
 115    LinkedResourceBase,
 116    _author_conv,  # pyright: ignore[reportPrivateUsage]
 117    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
 118)
 119from ..generic.v0_3 import LicenseId as LicenseId
 120from ..generic.v0_3 import LinkedResource as LinkedResource
 121from ..generic.v0_3 import Maintainer as Maintainer
 122from ..generic.v0_3 import OrcidId as OrcidId
 123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 124from ..generic.v0_3 import ResourceId as ResourceId
 125from .v0_4 import Author as _Author_v0_4
 126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 127from .v0_4 import CallableFromDepencency as CallableFromDepencency
 128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 130from .v0_4 import ClipDescr as _ClipDescr_v0_4
 131from .v0_4 import ClipKwargs as ClipKwargs
 132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 134from .v0_4 import KnownRunMode as KnownRunMode
 135from .v0_4 import ModelDescr as _ModelDescr_v0_4
 136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 140from .v0_4 import ProcessingKwargs as ProcessingKwargs
 141from .v0_4 import RunMode as RunMode
 142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 146from .v0_4 import TensorName as _TensorName_v0_4
 147from .v0_4 import WeightsFormat as WeightsFormat
 148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 149from .v0_4 import package_weights
 150
 151SpaceUnit = Literal[
 152    "attometer",
 153    "angstrom",
 154    "centimeter",
 155    "decimeter",
 156    "exameter",
 157    "femtometer",
 158    "foot",
 159    "gigameter",
 160    "hectometer",
 161    "inch",
 162    "kilometer",
 163    "megameter",
 164    "meter",
 165    "micrometer",
 166    "mile",
 167    "millimeter",
 168    "nanometer",
 169    "parsec",
 170    "petameter",
 171    "picometer",
 172    "terameter",
 173    "yard",
 174    "yoctometer",
 175    "yottameter",
 176    "zeptometer",
 177    "zettameter",
 178]
 179"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 180
 181TimeUnit = Literal[
 182    "attosecond",
 183    "centisecond",
 184    "day",
 185    "decisecond",
 186    "exasecond",
 187    "femtosecond",
 188    "gigasecond",
 189    "hectosecond",
 190    "hour",
 191    "kilosecond",
 192    "megasecond",
 193    "microsecond",
 194    "millisecond",
 195    "minute",
 196    "nanosecond",
 197    "petasecond",
 198    "picosecond",
 199    "second",
 200    "terasecond",
 201    "yoctosecond",
 202    "yottasecond",
 203    "zeptosecond",
 204    "zettasecond",
 205]
 206"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 207
 208AxisType = Literal["batch", "channel", "index", "time", "space"]
 209
 210_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
 211    "b": "batch",
 212    "t": "time",
 213    "i": "index",
 214    "c": "channel",
 215    "x": "space",
 216    "y": "space",
 217    "z": "space",
 218}
 219
 220_AXIS_ID_MAP = {
 221    "b": "batch",
 222    "t": "time",
 223    "i": "index",
 224    "c": "channel",
 225}
 226
 227
 228class TensorId(LowerCaseIdentifier):
 229    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 230        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 231    ]
 232
 233
 234def _normalize_axis_id(a: str):
 235    a = str(a)
 236    normalized = _AXIS_ID_MAP.get(a, a)
 237    if a != normalized:
 238        logger.opt(depth=3).warning(
 239            "Normalized axis id from '{}' to '{}'.", a, normalized
 240        )
 241    return normalized
 242
 243
 244class AxisId(LowerCaseIdentifier):
 245    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 246        Annotated[
 247            LowerCaseIdentifierAnno,
 248            MaxLen(16),
 249            AfterValidator(_normalize_axis_id),
 250        ]
 251    ]
 252
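# Minimal usage sketch: single-letter axis ids are normalized to their long
# form via `_normalize_axis_id`, so "b" becomes "batch"; spatial ids such as
# "x", "y", "z" are kept as-is. Axis ids are limited to 16 characters,
# tensor ids to 32.
assert str(AxisId("b")) == "batch"
assert str(AxisId("x")) == "x"
assert str(TensorId("input")) == "input"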
 253
 254def _is_batch(a: str) -> bool:
 255    return str(a) == "batch"
 256
 257
 258def _is_not_batch(a: str) -> bool:
 259    return not _is_batch(a)
 260
 261
 262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
 263
 264PreprocessingId = Literal[
 265    "binarize",
 266    "clip",
 267    "ensure_dtype",
 268    "fixed_zero_mean_unit_variance",
 269    "scale_linear",
 270    "scale_range",
 271    "sigmoid",
 272    "softmax",
 273]
 274PostprocessingId = Literal[
 275    "binarize",
 276    "clip",
 277    "ensure_dtype",
 278    "fixed_zero_mean_unit_variance",
 279    "scale_linear",
 280    "scale_mean_variance",
 281    "scale_range",
 282    "sigmoid",
 283    "softmax",
 284    "zero_mean_unit_variance",
 285]
 286
 287
 288SAME_AS_TYPE = "<same as type>"
 289
 290
 291ParameterizedSize_N = int
 292"""
 293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
 294"""
 295
 296
 297class ParameterizedSize(Node):
 298    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
 299
 300    - **min** and **step** are given by the model description.
 301    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
 302    - A greater blocksize parameter n results in a greater **size**.
 303      This allows adjusting the axis size more generically.
 304    """
 305
 306    N: ClassVar[Type[int]] = ParameterizedSize_N
 307    """Non-negative integer to parameterize this axis"""
 308
 309    min: Annotated[int, Gt(0)]
 310    step: Annotated[int, Gt(0)]
 311
 312    def validate_size(self, size: int) -> int:
 313        if size < self.min:
 314            raise ValueError(f"size {size} < {self.min}")
 315        if (size - self.min) % self.step != 0:
 316            raise ValueError(
 317                f"axis of size {size} is not parameterized by `min + n*step` ="
 318                + f" `{self.min} + n*{self.step}`"
 319            )
 320
 321        return size
 322
 323    def get_size(self, n: ParameterizedSize_N) -> int:
 324        return self.min + self.step * n
 325
 326    def get_n(self, s: int) -> ParameterizedSize_N:
 327        """return the smallest n parameterizing a size greater than or equal to `s`"""
 328        return ceil((s - self.min) / self.step)
 329
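# Minimal usage sketch: valid sizes are `min + n*step` for n = 0, 1, 2, ...
size = ParameterizedSize(min=32, step=16)
assert size.get_size(2) == 64        # 32 + 2*16
assert size.get_n(40) == 1           # smallest n with 32 + n*16 >= 40
assert size.validate_size(48) == 48  # 48 = 32 + 1*16
# size.validate_size(40) would raise a ValueError: 40 is not `min + n*step`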
 330
 331class DataDependentSize(Node):
 332    min: Annotated[int, Gt(0)] = 1
 333    max: Annotated[Optional[int], Gt(1)] = None
 334
 335    @model_validator(mode="after")
 336    def _validate_max_gt_min(self):
 337        if self.max is not None and self.min >= self.max:
 338            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 339
 340        return self
 341
 342    def validate_size(self, size: int) -> int:
 343        if size < self.min:
 344            raise ValueError(f"size {size} < {self.min}")
 345
 346        if self.max is not None and size > self.max:
 347            raise ValueError(f"size {size} > {self.max}")
 348
 349        return size
 350
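# Minimal usage sketch: a data dependent size only constrains the valid range;
# the concrete size is known only after inference. The bounds are illustrative.
dds = DataDependentSize(min=1, max=2048)
assert dds.validate_size(512) == 512
# dds.validate_size(4096) would raise a ValueError (size > max)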
 351
 352class SizeReference(Node):
 353    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 354
 355    `axis.size = reference.size * reference.scale / axis.scale + offset`
 356
 357    Note:
 358    1. The axis and the referenced axis need to have the same unit (or no unit).
 359    2. Batch axes may not be referenced.
 360    3. Fractions are rounded down.
 361    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 362        `concatenable` as well with the same block order.
 363
 364    Example:
 365    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
 366    Let's assume that we want to express the image height h in relation to its width w
 367    instead of only accepting input images of exactly 100*49 pixels
 368    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
 369
 370    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 371    >>> h = SpaceInputAxis(
 372    ...     id=AxisId("h"),
 373    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 374    ...     unit="millimeter",
 375    ...     scale=4,
 376    ... )
 377    >>> print(h.size.get_size(h, w))
 378    49
 379
 380    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 381    """
 382
 383    tensor_id: TensorId
 384    """tensor id of the reference axis"""
 385
 386    axis_id: AxisId
 387    """axis id of the reference axis"""
 388
 389    offset: StrictInt = 0
 390
 391    def get_size(
 392        self,
 393        axis: Union[
 394            ChannelAxis,
 395            IndexInputAxis,
 396            IndexOutputAxis,
 397            TimeInputAxis,
 398            SpaceInputAxis,
 399            TimeOutputAxis,
 400            TimeOutputAxisWithHalo,
 401            SpaceOutputAxis,
 402            SpaceOutputAxisWithHalo,
 403        ],
 404        ref_axis: Union[
 405            ChannelAxis,
 406            IndexInputAxis,
 407            IndexOutputAxis,
 408            TimeInputAxis,
 409            SpaceInputAxis,
 410            TimeOutputAxis,
 411            TimeOutputAxisWithHalo,
 412            SpaceOutputAxis,
 413            SpaceOutputAxisWithHalo,
 414        ],
 415        n: ParameterizedSize_N = 0,
 416        ref_size: Optional[int] = None,
 417    ):
 418        """Compute the concrete size for a given axis and its reference axis.
 419
 420        Args:
 421            axis: The axis this `SizeReference` is the size of.
 422            ref_axis: The reference axis to compute the size from.
 423            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 424                and no fixed **ref_size** is given,
 425                **n** is used to compute the size of the parameterized **ref_axis**.
 426            ref_size: Overwrite the reference size instead of deriving it from
 427                **ref_axis**
 428                (**ref_axis.scale** is still used; any given **n** is ignored).
 429        """
 430        assert (
 431            axis.size == self
 432        ), "Given `axis.size` is not defined by this `SizeReference`"
 433
 434        assert (
 435            ref_axis.id == self.axis_id
 436        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 437
 438        assert axis.unit == ref_axis.unit, (
 439            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 440            f" but {axis.unit}!={ref_axis.unit}"
 441        )
 442        if ref_size is None:
 443            if isinstance(ref_axis.size, (int, float)):
 444                ref_size = ref_axis.size
 445            elif isinstance(ref_axis.size, ParameterizedSize):
 446                ref_size = ref_axis.size.get_size(n)
 447            elif isinstance(ref_axis.size, DataDependentSize):
 448                raise ValueError(
 449                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 450                )
 451            elif isinstance(ref_axis.size, SizeReference):
 452                raise ValueError(
 453                    "Reference axis referenced in `SizeReference` may not be sized by a"
 454                    + " `SizeReference` itself."
 455                )
 456            else:
 457                assert_never(ref_axis.size)
 458
 459        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 460
 461    @staticmethod
 462    def _get_unit(
 463        axis: Union[
 464            ChannelAxis,
 465            IndexInputAxis,
 466            IndexOutputAxis,
 467            TimeInputAxis,
 468            SpaceInputAxis,
 469            TimeOutputAxis,
 470            TimeOutputAxisWithHalo,
 471            SpaceOutputAxis,
 472            SpaceOutputAxisWithHalo,
 473        ],
 474    ):
 475        return axis.unit
 476
 477
 478class AxisBase(NodeWithExplicitlySetFields):
 479    id: AxisId
 480    """An axis id unique across all axes of one tensor."""
 481
 482    description: Annotated[str, MaxLen(128)] = ""
 483    """A short description of this axis beyond its type and id."""
 484
 485
 486class WithHalo(Node):
 487    halo: Annotated[int, Ge(1)]
 488    """The halo should be cropped from the output tensor to avoid boundary effects.
 489    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 490    To document a halo that is already cropped by the model use `size.offset` instead."""
 491
 492    size: Annotated[
 493        SizeReference,
 494        Field(
 495            examples=[
 496                10,
 497                SizeReference(
 498                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 499                ).model_dump(mode="json"),
 500            ]
 501        ),
 502    ]
 503    """reference to another axis with an optional offset (see `SizeReference`)"""
 504
 505
 506BATCH_AXIS_ID = AxisId("batch")
 507
 508
 509class BatchAxis(AxisBase):
 510    implemented_type: ClassVar[Literal["batch"]] = "batch"
 511    if TYPE_CHECKING:
 512        type: Literal["batch"] = "batch"
 513    else:
 514        type: Literal["batch"]
 515
 516    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 517    size: Optional[Literal[1]] = None
 518    """The batch size may be fixed to 1,
 519    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
 520
 521    @property
 522    def scale(self):
 523        return 1.0
 524
 525    @property
 526    def concatenable(self):
 527        return True
 528
 529    @property
 530    def unit(self):
 531        return None
 532
 533
 534class ChannelAxis(AxisBase):
 535    implemented_type: ClassVar[Literal["channel"]] = "channel"
 536    if TYPE_CHECKING:
 537        type: Literal["channel"] = "channel"
 538    else:
 539        type: Literal["channel"]
 540
 541    id: NonBatchAxisId = AxisId("channel")
 542
 543    channel_names: NotEmpty[List[Identifier]]
 544
 545    @property
 546    def size(self) -> int:
 547        return len(self.channel_names)
 548
 549    @property
 550    def concatenable(self):
 551        return False
 552
 553    @property
 554    def scale(self) -> float:
 555        return 1.0
 556
 557    @property
 558    def unit(self):
 559        return None
 560
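# Minimal usage sketch: the size of a channel axis is implied by the number of
# channel names (the names below are illustrative).
channel = ChannelAxis(channel_names=[Identifier("dapi"), Identifier("gfp")])
assert channel.size == 2
assert not channel.concatenable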
 561
 562class IndexAxisBase(AxisBase):
 563    implemented_type: ClassVar[Literal["index"]] = "index"
 564    if TYPE_CHECKING:
 565        type: Literal["index"] = "index"
 566    else:
 567        type: Literal["index"]
 568
 569    id: NonBatchAxisId = AxisId("index")
 570
 571    @property
 572    def scale(self) -> float:
 573        return 1.0
 574
 575    @property
 576    def unit(self):
 577        return None
 578
 579
 580class _WithInputAxisSize(Node):
 581    size: Annotated[
 582        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 583        Field(
 584            examples=[
 585                10,
 586                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 587                SizeReference(
 588                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 589                ).model_dump(mode="json"),
 590            ]
 591        ),
 592    ]
 593    """The size/length of this axis can be specified as
 594    - fixed integer
 595    - parameterized series of valid sizes (`ParameterizedSize`)
 596    - reference to another axis with an optional offset (`SizeReference`)
 597    """
 598
 599
 600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 601    concatenable: bool = False
 602    """If a model has a `concatenable` input axis, it can be processed blockwise,
 603    splitting a longer sample axis into blocks matching its input tensor description.
 604    Output axes are concatenable if they have a `SizeReference` to a concatenable
 605    input axis.
 606    """
 607
 608
 609class IndexOutputAxis(IndexAxisBase):
 610    size: Annotated[
 611        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 612        Field(
 613            examples=[
 614                10,
 615                SizeReference(
 616                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 617                ).model_dump(mode="json"),
 618            ]
 619        ),
 620    ]
 621    """The size/length of this axis can be specified as
 622    - fixed integer
 623    - reference to another axis with an optional offset (`SizeReference`)
 624    - data dependent size using `DataDependentSize` (size is only known after model inference)
 625    """
 626
 627
 628class TimeAxisBase(AxisBase):
 629    implemented_type: ClassVar[Literal["time"]] = "time"
 630    if TYPE_CHECKING:
 631        type: Literal["time"] = "time"
 632    else:
 633        type: Literal["time"]
 634
 635    id: NonBatchAxisId = AxisId("time")
 636    unit: Optional[TimeUnit] = None
 637    scale: Annotated[float, Gt(0)] = 1.0
 638
 639
 640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 641    concatenable: bool = False
 642    """If a model has a `concatenable` input axis, it can be processed blockwise,
 643    splitting a longer sample axis into blocks matching its input tensor description.
 644    Output axes are concatenable if they have a `SizeReference` to a concatenable
 645    input axis.
 646    """
 647
 648
 649class SpaceAxisBase(AxisBase):
 650    implemented_type: ClassVar[Literal["space"]] = "space"
 651    if TYPE_CHECKING:
 652        type: Literal["space"] = "space"
 653    else:
 654        type: Literal["space"]
 655
 656    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 657    unit: Optional[SpaceUnit] = None
 658    scale: Annotated[float, Gt(0)] = 1.0
 659
 660
 661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 662    concatenable: bool = False
 663    """If a model has a `concatenable` input axis, it can be processed blockwise,
 664    splitting a longer sample axis into blocks matching its input tensor description.
 665    Output axes are concatenable if they have a `SizeReference` to a concatenable
 666    input axis.
 667    """
 668
 669
 670INPUT_AXIS_TYPES = (
 671    BatchAxis,
 672    ChannelAxis,
 673    IndexInputAxis,
 674    TimeInputAxis,
 675    SpaceInputAxis,
 676)
 677"""intended for isinstance comparisons in py<3.10"""
 678
 679_InputAxisUnion = Union[
 680    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 681]
 682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 683
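# Minimal usage sketch: a typical batched 2D input axis layout combining the
# axis classes above (channel names and sizes are illustrative).
example_input_axes: List[InputAxis] = [
    BatchAxis(),
    ChannelAxis(channel_names=[Identifier("raw")]),
    SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
    SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
]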
 684
 685class _WithOutputAxisSize(Node):
 686    size: Annotated[
 687        Union[Annotated[int, Gt(0)], SizeReference],
 688        Field(
 689            examples=[
 690                10,
 691                SizeReference(
 692                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 693                ).model_dump(mode="json"),
 694            ]
 695        ),
 696    ]
 697    """The size/length of this axis can be specified as
 698    - fixed integer
 699    - reference to another axis with an optional offset (see `SizeReference`)
 700    """
 701
 702
 703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 704    pass
 705
 706
 707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 708    pass
 709
 710
 711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 712    if isinstance(v, dict):
 713        return "with_halo" if "halo" in v else "wo_halo"
 714    else:
 715        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 716
 717
 718_TimeOutputAxisUnion = Annotated[
 719    Union[
 720        Annotated[TimeOutputAxis, Tag("wo_halo")],
 721        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 722    ],
 723    Discriminator(_get_halo_axis_discriminator_value),
 724]
 725
 726
 727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 728    pass
 729
 730
 731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 732    pass
 733
 734
 735_SpaceOutputAxisUnion = Annotated[
 736    Union[
 737        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 738        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 739    ],
 740    Discriminator(_get_halo_axis_discriminator_value),
 741]
 742
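# Minimal usage sketch: an output space axis with a halo; per `WithHalo` the
# halo is cropped from both sides, i.e. size_after_crop = size - 2 * halo.
# The tensor/axis ids are illustrative.
out_y = SpaceOutputAxisWithHalo(
    id=AxisId("y"),
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
    halo=8,
)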
 743
 744_OutputAxisUnion = Union[
 745    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 746]
 747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 748
 749OUTPUT_AXIS_TYPES = (
 750    BatchAxis,
 751    ChannelAxis,
 752    IndexOutputAxis,
 753    TimeOutputAxis,
 754    TimeOutputAxisWithHalo,
 755    SpaceOutputAxis,
 756    SpaceOutputAxisWithHalo,
 757)
 758"""intended for isinstance comparisons in py<3.10"""
 759
 760
 761AnyAxis = Union[InputAxis, OutputAxis]
 762
 763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
 764"""intended for isinstance comparisons in py<3.10"""
 765
 766TVs = Union[
 767    NotEmpty[List[int]],
 768    NotEmpty[List[float]],
 769    NotEmpty[List[bool]],
 770    NotEmpty[List[str]],
 771]
 772
 773
 774NominalOrOrdinalDType = Literal[
 775    "float32",
 776    "float64",
 777    "uint8",
 778    "int8",
 779    "uint16",
 780    "int16",
 781    "uint32",
 782    "int32",
 783    "uint64",
 784    "int64",
 785    "bool",
 786]
 787
 788
 789class NominalOrOrdinalDataDescr(Node):
 790    values: TVs
 791    """A fixed set of nominal or an ascending sequence of ordinal values.
 792    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
 793    String `values` are interpreted as labels for tensor values 0, ..., N.
 794    Note: as YAML 1.2 does not natively support a "set" datatype,
 795    nominal values should be given as a sequence (aka list/array) as well.
 796    """
 797
 798    type: Annotated[
 799        NominalOrOrdinalDType,
 800        Field(
 801            examples=[
 802                "float32",
 803                "uint8",
 804                "uint16",
 805                "int64",
 806                "bool",
 807            ],
 808        ),
 809    ] = "uint8"
 810
 811    @model_validator(mode="after")
 812    def _validate_values_match_type(
 813        self,
 814    ) -> Self:
 815        incompatible: List[Any] = []
 816        for v in self.values:
 817            if self.type == "bool":
 818                if not isinstance(v, bool):
 819                    incompatible.append(v)
 820            elif self.type in DTYPE_LIMITS:
 821                if (
 822                    isinstance(v, (int, float))
 823                    and (
 824                        v < DTYPE_LIMITS[self.type].min
 825                        or v > DTYPE_LIMITS[self.type].max
 826                    )
 827                    or (isinstance(v, str) and "uint" not in self.type)
 828                    or (isinstance(v, float) and "int" in self.type)
 829                ):
 830                    incompatible.append(v)
 831            else:
 832                incompatible.append(v)
 833
 834            if len(incompatible) == 5:
 835                incompatible.append("...")
 836                break
 837
 838        if incompatible:
 839            raise ValueError(
 840                f"data type '{self.type}' incompatible with values {incompatible}"
 841            )
 842
 843        return self
 844
 845    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 846
 847    @property
 848    def range(self):
 849        if isinstance(self.values[0], str):
 850            return 0, len(self.values) - 1
 851        else:
 852            return min(self.values), max(self.values)
 853
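# Minimal usage sketch: string values act as labels for tensor values 0..N and
# require an unsigned integer dtype (the labels are illustrative).
labels = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
assert labels.range == (0, 1)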
 854
 855IntervalOrRatioDType = Literal[
 856    "float32",
 857    "float64",
 858    "uint8",
 859    "int8",
 860    "uint16",
 861    "int16",
 862    "uint32",
 863    "int32",
 864    "uint64",
 865    "int64",
 866]
 867
 868
 869class IntervalOrRatioDataDescr(Node):
 870    type: Annotated[  # TODO: rename to dtype
 871        IntervalOrRatioDType,
 872        Field(
 873            examples=["float32", "float64", "uint8", "uint16"],
 874        ),
 875    ] = "float32"
 876    range: Tuple[Optional[float], Optional[float]] = (
 877        None,
 878        None,
 879    )
 880    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 881    `None` corresponds to min/max of what can be expressed by **type**."""
 882    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 883    scale: float = 1.0
 884    """Scale for data on an interval (or ratio) scale."""
 885    offset: Optional[float] = None
 886    """Offset for data on a ratio scale."""
 887
 888    @model_validator(mode="before")
 889    def _replace_inf(cls, data: Any):
 890        if is_dict(data):
 891            if "range" in data and is_sequence(data["range"]):
 892                forbidden = (
 893                    "inf",
 894                    "-inf",
 895                    ".inf",
 896                    "-.inf",
 897                    float("inf"),
 898                    float("-inf"),
 899                )
 900                if any(v in forbidden for v in data["range"]):
 901                    issue_warning("replaced 'inf' value", value=data["range"])
 902
 903                data["range"] = tuple(
 904                    (None if v in forbidden else v) for v in data["range"]
 905                )
 906
 907        return data
 908
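# Minimal usage sketch: by default the data is described as float32 with an
# unrestricted range; 'inf'/'-inf' bounds are replaced by None with a warning.
data_descr = IntervalOrRatioDataDescr()
assert data_descr.type == "float32"
assert data_descr.range == (None, None)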
 909
 910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 911
 912
 913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 914    """processing base class"""
 915
 916
 917class BinarizeKwargs(ProcessingKwargs):
 918    """keyword arguments for `BinarizeDescr`"""
 919
 920    threshold: float
 921    """The fixed threshold"""
 922
 923
 924class BinarizeAlongAxisKwargs(ProcessingKwargs):
 925    """keyword arguments for `BinarizeDescr`"""
 926
 927    threshold: NotEmpty[List[float]]
 928    """The fixed threshold values along `axis`"""
 929
 930    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 931    """The `threshold` axis"""
 932
 933
 934class BinarizeDescr(ProcessingDescrBase):
 935    """Binarize the tensor with a fixed threshold.
 936
 937    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 938    will be set to one, values below the threshold to zero.
 939
 940    Examples:
 941    - in YAML
 942        ```yaml
 943        postprocessing:
 944          - id: binarize
 945            kwargs:
 946              axis: 'channel'
 947              threshold: [0.25, 0.5, 0.75]
 948        ```
 949    - in Python:
 950        >>> postprocessing = [BinarizeDescr(
 951        ...   kwargs=BinarizeAlongAxisKwargs(
 952        ...       axis=AxisId('channel'),
 953        ...       threshold=[0.25, 0.5, 0.75],
 954        ...   )
 955        ... )]
 956    """
 957
 958    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
 959    if TYPE_CHECKING:
 960        id: Literal["binarize"] = "binarize"
 961    else:
 962        id: Literal["binarize"]
 963    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 964
 965
 966class ClipDescr(ProcessingDescrBase):
 967    """Set tensor values below min to min and above max to max.
 968
 969    See `ScaleRangeDescr` for examples.
 970    """
 971
 972    implemented_id: ClassVar[Literal["clip"]] = "clip"
 973    if TYPE_CHECKING:
 974        id: Literal["clip"] = "clip"
 975    else:
 976        id: Literal["clip"]
 977
 978    kwargs: ClipKwargs
 979
 980
 981class EnsureDtypeKwargs(ProcessingKwargs):
 982    """keyword arguments for `EnsureDtypeDescr`"""
 983
 984    dtype: Literal[
 985        "float32",
 986        "float64",
 987        "uint8",
 988        "int8",
 989        "uint16",
 990        "int16",
 991        "uint32",
 992        "int32",
 993        "uint64",
 994        "int64",
 995        "bool",
 996    ]
 997
 998
 999class EnsureDtypeDescr(ProcessingDescrBase):
1000    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1001
1002    This can for example be used to ensure the inner neural network model gets a
1003    different input tensor data type than the fully described bioimage.io model does.
1004
1005    Examples:
1006        The described bioimage.io model (incl. preprocessing) accepts any
1007        float32-compatible tensor, normalizes it with percentiles and clipping and then
1008        casts it to uint8, which is what the neural network in this example expects.
1009        - in YAML
1010            ```yaml
1011            inputs:
1012            - data:
1013                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1014              preprocessing:
1015              - id: scale_range
1016                kwargs:
1017                  axes: ['y', 'x']
1018                  max_percentile: 99.8
1019                  min_percentile: 5.0
1020              - id: clip
1021                kwargs:
1022                  min: 0.0
1023                  max: 1.0
1024              - id: ensure_dtype  # the neural network of the model requires uint8
1025                kwargs:
1026                  dtype: uint8
1027            ```
1028        - in Python:
1029            >>> preprocessing = [
1030            ...     ScaleRangeDescr(
1031            ...         kwargs=ScaleRangeKwargs(
1032            ...           axes= (AxisId('y'), AxisId('x')),
1033            ...           max_percentile= 99.8,
1034            ...           min_percentile= 5.0,
1035            ...         )
1036            ...     ),
1037            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1038            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1039            ... ]
1040    """
1041
1042    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1043    if TYPE_CHECKING:
1044        id: Literal["ensure_dtype"] = "ensure_dtype"
1045    else:
1046        id: Literal["ensure_dtype"]
1047
1048    kwargs: EnsureDtypeKwargs
1049
1050
1051class ScaleLinearKwargs(ProcessingKwargs):
1052    """Keyword arguments for `ScaleLinearDescr`"""
1053
1054    gain: float = 1.0
1055    """multiplicative factor"""
1056
1057    offset: float = 0.0
1058    """additive term"""
1059
1060    @model_validator(mode="after")
1061    def _validate(self) -> Self:
1062        if self.gain == 1.0 and self.offset == 0.0:
1063            raise ValueError(
1064                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1065                + " != 0.0."
1066            )
1067
1068        return self
1069
1070
1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1072    """Keyword arguments for `ScaleLinearDescr`"""
1073
1074    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1075    """The axis of gain and offset values."""
1076
1077    gain: Union[float, NotEmpty[List[float]]] = 1.0
1078    """multiplicative factor"""
1079
1080    offset: Union[float, NotEmpty[List[float]]] = 0.0
1081    """additive term"""
1082
1083    @model_validator(mode="after")
1084    def _validate(self) -> Self:
1085
1086        if isinstance(self.gain, list):
1087            if isinstance(self.offset, list):
1088                if len(self.gain) != len(self.offset):
1089                    raise ValueError(
1090                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1091                    )
1092            else:
1093                self.offset = [float(self.offset)] * len(self.gain)
1094        elif isinstance(self.offset, list):
1095            self.gain = [float(self.gain)] * len(self.offset)
1096        else:
1097            raise ValueError(
1098                "Do not specify an `axis` for scalar gain and offset values."
1099            )
1100
1101        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1102            raise ValueError(
1103                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1104                + " != 0.0."
1105            )
1106
1107        return self
1108
1109
1110class ScaleLinearDescr(ProcessingDescrBase):
1111    """Fixed linear scaling.
1112
1113    Examples:
1114      1. Scale with scalar gain and offset
1115        - in YAML
1116        ```yaml
1117        preprocessing:
1118          - id: scale_linear
1119            kwargs:
1120              gain: 2.0
1121              offset: 3.0
1122        ```
1123        - in Python:
1124        >>> preprocessing = [
1125        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1126        ... ]
1127
1128      2. Independent scaling along an axis
1129        - in YAML
1130        ```yaml
1131        preprocessing:
1132          - id: scale_linear
1133            kwargs:
1134              axis: 'channel'
1135              gain: [1.0, 2.0, 3.0]
1136        ```
1137        - in Python:
1138        >>> preprocessing = [
1139        ...     ScaleLinearDescr(
1140        ...         kwargs=ScaleLinearAlongAxisKwargs(
1141        ...             axis=AxisId("channel"),
1142        ...             gain=[1.0, 2.0, 3.0],
1143        ...         )
1144        ...     )
1145        ... ]
1146
1147    """
1148
1149    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1150    if TYPE_CHECKING:
1151        id: Literal["scale_linear"] = "scale_linear"
1152    else:
1153        id: Literal["scale_linear"]
1154    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1155
1156
1157class SigmoidDescr(ProcessingDescrBase):
1158    """The logistic sigmoid function, a.k.a. expit function.
1159
1160    Examples:
1161    - in YAML
1162        ```yaml
1163        postprocessing:
1164          - id: sigmoid
1165        ```
1166    - in Python:
1167        >>> postprocessing = [SigmoidDescr()]
1168    """
1169
1170    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1171    if TYPE_CHECKING:
1172        id: Literal["sigmoid"] = "sigmoid"
1173    else:
1174        id: Literal["sigmoid"]
1175
1176    @property
1177    def kwargs(self) -> ProcessingKwargs:
1178        """empty kwargs"""
1179        return ProcessingKwargs()
1180
1181
1182class SoftmaxKwargs(ProcessingKwargs):
1183    """keyword arguments for `SoftmaxDescr`"""
1184
1185    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1186    """The axis to apply the softmax function along.
1187    Note:
1188        Defaults to 'channel' axis
1189        (which may not exist, in which case
1190        a different axis id has to be specified).
1191    """
1192
1193
1194class SoftmaxDescr(ProcessingDescrBase):
1195    """The softmax function.
1196
1197    Examples:
1198    - in YAML
1199        ```yaml
1200        postprocessing:
1201          - id: softmax
1202            kwargs:
1203              axis: channel
1204        ```
1205    - in Python:
1206        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1207    """
1208
1209    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1210    if TYPE_CHECKING:
1211        id: Literal["softmax"] = "softmax"
1212    else:
1213        id: Literal["softmax"]
1214
1215    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
1216
1217
1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1219    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1220
1221    mean: float
1222    """The mean value to normalize with."""
1223
1224    std: Annotated[float, Ge(1e-6)]
1225    """The standard deviation value to normalize with."""
1226
1227
1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1229    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1230
1231    mean: NotEmpty[List[float]]
1232    """The mean value(s) to normalize with."""
1233
1234    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1235    """The standard deviation value(s) to normalize with.
1236    Size must match `mean` values."""
1237
1238    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1239    """The axis of the mean/std values to normalize each entry along that dimension
1240    separately."""
1241
1242    @model_validator(mode="after")
1243    def _mean_and_std_match(self) -> Self:
1244        if len(self.mean) != len(self.std):
1245            raise ValueError(
1246                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1247                + " must match."
1248            )
1249
1250        return self
1251
1252
1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1254    """Subtract a given mean and divide by the standard deviation.
1255
1256    Normalize with fixed, precomputed values for
1257    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
1258    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1259    axes.
1260
1261    Examples:
1262    1. scalar value for whole tensor
1263        - in YAML
1264        ```yaml
1265        preprocessing:
1266          - id: fixed_zero_mean_unit_variance
1267            kwargs:
1268              mean: 103.5
1269              std: 13.7
1270        ```
1271        - in Python
1272        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1273        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1274        ... )]
1275
1276    2. independently along an axis
1277        - in YAML
1278        ```yaml
1279        preprocessing:
1280          - id: fixed_zero_mean_unit_variance
1281            kwargs:
1282              axis: channel
1283              mean: [101.5, 102.5, 103.5]
1284              std: [11.7, 12.7, 13.7]
1285        ```
1286        - in Python
1287        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1288        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1289        ...     axis=AxisId("channel"),
1290        ...     mean=[101.5, 102.5, 103.5],
1291        ...     std=[11.7, 12.7, 13.7],
1292        ...   )
1293        ... )]
1294    """
1295
1296    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1297        "fixed_zero_mean_unit_variance"
1298    )
1299    if TYPE_CHECKING:
1300        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1301    else:
1302        id: Literal["fixed_zero_mean_unit_variance"]
1303
1304    kwargs: Union[
1305        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1306    ]
1307
1308
1309class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1310    """keyword arguments for `ZeroMeanUnitVarianceDescr`"""
1311
1312    axes: Annotated[
1313        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1314    ] = None
1315    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1316    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1317    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1318    To normalize each sample independently leave out the 'batch' axis.
1319    Default: Scale all axes jointly."""
1320
1321    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1322    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1323
1324
1325class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1326    """Subtract mean and divide by the standard deviation.
1327
1328    Examples:
1329        Normalize a tensor to zero mean and unit variance
1330        - in YAML
1331        ```yaml
1332        preprocessing:
1333          - id: zero_mean_unit_variance
1334        ```
1335        - in Python
1336        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1337    """
1338
1339    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1340        "zero_mean_unit_variance"
1341    )
1342    if TYPE_CHECKING:
1343        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1344    else:
1345        id: Literal["zero_mean_unit_variance"]
1346
1347    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1348        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1349    )
1350
1351
1352class ScaleRangeKwargs(ProcessingKwargs):
1353    """keyword arguments for `ScaleRangeDescr`
1354
1355    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1356    this processing step normalizes data to the [0, 1] interval.
1357    For other percentiles the normalized values will partially be outside the [0, 1]
1358    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1359    normalized values to a range.
1360    """
1361
1362    axes: Annotated[
1363        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1364    ] = None
1365    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1366    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1367    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1368    To normalize samples independently, leave out the "batch" axis.
1369    Default: Scale all axes jointly."""
1370
1371    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1372    """The lower percentile used to determine the value to align with zero."""
1373
1374    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1375    """The upper percentile used to determine the value to align with one.
1376    Has to be bigger than `min_percentile`.
1377    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1378    accepting percentiles specified in the range 0.0 to 1.0."""
1379
1380    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1381    """Epsilon for numeric stability.
1382    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1383    with `v_lower,v_upper` values at the respective percentiles."""
1384
1385    reference_tensor: Optional[TensorId] = None
1386    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1387    For any tensor in `inputs` only input tensor references are allowed."""
1388
1389    @field_validator("max_percentile", mode="after")
1390    @classmethod
1391    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1392        if (min_p := info.data["min_percentile"]) >= value:
1393            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1394
1395        return value
1396
1397
1398class ScaleRangeDescr(ProcessingDescrBase):
1399    """Scale with percentiles.
1400
1401    Examples:
1402    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1403        - in YAML
1404        ```yaml
1405        preprocessing:
1406          - id: scale_range
1407            kwargs:
1408              axes: ['y', 'x']
1409              max_percentile: 99.8
1410              min_percentile: 5.0
1411        ```
1412        - in Python
1413        >>> preprocessing = [ScaleRangeDescr(
1414        ...   kwargs=ScaleRangeKwargs(
1415        ...       axes= (AxisId('y'), AxisId('x')),
1416        ...       max_percentile= 99.8,
1417        ...       min_percentile= 5.0,
1418        ...   )
1419        ... )]
1420
1421    2. Combine the above scaling with additional clipping to clip values
1422       outside the range given by the percentiles.
1423        - in YAML
1424        ```yaml
1425        preprocessing:
1426          - id: scale_range
1427            kwargs:
1428              axes: ['y', 'x']
1429              max_percentile: 99.8
1430              min_percentile: 5.0
1431          - id: clip
1432            kwargs:
1433              min: 0.0
1434              max: 1.0
1435        ```
1436        - in Python
1437        >>> preprocessing = [
1438        ...     ScaleRangeDescr(
1439        ...         kwargs=ScaleRangeKwargs(
1440        ...           axes= (AxisId('y'), AxisId('x')),
1441        ...           max_percentile= 99.8,
1442        ...           min_percentile= 5.0,
1443        ...         )
1444        ...     ),
1445        ...     ClipDescr(
1446        ...         kwargs=ClipKwargs(
1447        ...             min=0.0,
1448        ...             max=1.0,
1449        ...         )
1450        ...     ),
1451        ... ]
1452
1453    """
1454
1455    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1456    if TYPE_CHECKING:
1457        id: Literal["scale_range"] = "scale_range"
1458    else:
1459        id: Literal["scale_range"]
1460    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
1461
1462
1463class ScaleMeanVarianceKwargs(ProcessingKwargs):
1464    """keyword arguments for `ScaleMeanVarianceDescr`"""
1465
1466    reference_tensor: TensorId
1467    """Name of tensor to match."""
1468
1469    axes: Annotated[
1470        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1471    ] = None
1472    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1473    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1474    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1475    To normalize samples independently, leave out the 'batch' axis.
1476    Default: Scale all axes jointly."""
1477
1478    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1479    """Epsilon for numeric stability:
1480    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1481
1482
1483class ScaleMeanVarianceDescr(ProcessingDescrBase):
1484    """Scale a tensor's data distribution to match another tensor's mean/std.
1485    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1486    """
1487
1488    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1489    if TYPE_CHECKING:
1490        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1491    else:
1492        id: Literal["scale_mean_variance"]
1493    kwargs: ScaleMeanVarianceKwargs
1494
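# Minimal usage sketch: match this tensor's mean/std to a reference tensor
# (the tensor/axis ids below are illustrative).
postprocessing = [
    ScaleMeanVarianceDescr(
        kwargs=ScaleMeanVarianceKwargs(
            reference_tensor=TensorId("raw"),
            axes=(AxisId("y"), AxisId("x")),
        )
    )
]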
1495
1496PreprocessingDescr = Annotated[
1497    Union[
1498        BinarizeDescr,
1499        ClipDescr,
1500        EnsureDtypeDescr,
1501        FixedZeroMeanUnitVarianceDescr,
1502        ScaleLinearDescr,
1503        ScaleRangeDescr,
1504        SigmoidDescr,
1505        SoftmaxDescr,
1506        ZeroMeanUnitVarianceDescr,
1507    ],
1508    Discriminator("id"),
1509]
1510PostprocessingDescr = Annotated[
1511    Union[
1512        BinarizeDescr,
1513        ClipDescr,
1514        EnsureDtypeDescr,
1515        FixedZeroMeanUnitVarianceDescr,
1516        ScaleLinearDescr,
1517        ScaleMeanVarianceDescr,
1518        ScaleRangeDescr,
1519        SigmoidDescr,
1520        SoftmaxDescr,
1521        ZeroMeanUnitVarianceDescr,
1522    ],
1523    Discriminator("id"),
1524]
1525
1526IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1527
1528
1529class TensorDescrBase(Node, Generic[IO_AxisT]):
1530    id: TensorId
1531    """Tensor id. No duplicates are allowed."""
1532
1533    description: Annotated[str, MaxLen(128)] = ""
1534    """free text description"""
1535
1536    axes: NotEmpty[Sequence[IO_AxisT]]
1537    """tensor axes"""
1538
1539    @property
1540    def shape(self):
1541        return tuple(a.size for a in self.axes)
1542
1543    @field_validator("axes", mode="after", check_fields=False)
1544    @classmethod
1545    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1546        batch_axes = [a for a in axes if a.type == "batch"]
1547        if len(batch_axes) > 1:
1548            raise ValueError(
1549                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1550            )
1551
1552        seen_ids: Set[AxisId] = set()
1553        duplicate_axes_ids: Set[AxisId] = set()
1554        for a in axes:
1555            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1556
1557        if duplicate_axes_ids:
1558            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1559
1560        return axes
1561
1562    test_tensor: FAIR[Optional[FileDescr_]] = None
1563    """An example tensor to use for testing.
1564    Using the model with the test input tensors is expected to yield the test output tensors.
1565    Each test tensor has to be an ndarray in the
1566    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1567    The file extension must be '.npy'."""
1568
1569    sample_tensor: FAIR[Optional[FileDescr_]] = None
1570    """A sample tensor to illustrate a possible input/output for the model.
1571    The sample image primarily serves to inform a human user about an example use case
1572    and is typically stored as .hdf5, .png or .tiff.
1573    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1574    (numpy's `.npy` format is not supported).
1575    The image dimensionality has to match the number of axes specified in this tensor description.
1576    """
1577
1578    @model_validator(mode="after")
1579    def _validate_sample_tensor(self) -> Self:
1580        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1581            return self
1582
1583        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1584        tensor: NDArray[Any] = imread(
1585            reader.read(),
1586            extension=PurePosixPath(reader.original_file_name).suffix,
1587        )
1588        n_dims = len(tensor.squeeze().shape)
1589        n_dims_min = n_dims_max = len(self.axes)
1590
1591        for a in self.axes:
1592            if isinstance(a, BatchAxis):
1593                n_dims_min -= 1
1594            elif isinstance(a.size, int):
1595                if a.size == 1:
1596                    n_dims_min -= 1
1597            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1598                if a.size.min == 1:
1599                    n_dims_min -= 1
1600            elif isinstance(a.size, SizeReference):
1601                if a.size.offset < 2:
1602                    # size reference may result in singleton axis
1603                    n_dims_min -= 1
1604            else:
1605                assert_never(a.size)
1606
1607        n_dims_min = max(0, n_dims_min)
1608        if n_dims < n_dims_min or n_dims > n_dims_max:
1609            raise ValueError(
1610                f"Expected sample tensor to have {n_dims_min} to"
1611                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1612            )
1613
1614        return self
1615
1616    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1617        IntervalOrRatioDataDescr()
1618    )
1619    """Description of the tensor's data values, optionally per channel.
1620    If specified per channel, the data `type` needs to match across channels."""
1621
1622    @property
1623    def dtype(
1624        self,
1625    ) -> Literal[
1626        "float32",
1627        "float64",
1628        "uint8",
1629        "int8",
1630        "uint16",
1631        "int16",
1632        "uint32",
1633        "int32",
1634        "uint64",
1635        "int64",
1636        "bool",
1637    ]:
1638        """dtype as specified under `data.type` or `data[i].type`"""
1639        if isinstance(self.data, collections.abc.Sequence):
1640            return self.data[0].type
1641        else:
1642            return self.data.type
1643
1644    @field_validator("data", mode="after")
1645    @classmethod
1646    def _check_data_type_across_channels(
1647        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1648    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1649        if not isinstance(value, list):
1650            return value
1651
1652        dtypes = {t.type for t in value}
1653        if len(dtypes) > 1:
1654            raise ValueError(
1655                "Tensor data descriptions per channel need to agree in their data"
1656                + f" `type`, but found {dtypes}."
1657            )
1658
1659        return value
1660
1661    @model_validator(mode="after")
1662    def _check_data_matches_channelaxis(self) -> Self:
1663        if not isinstance(self.data, (list, tuple)):
1664            return self
1665
1666        for a in self.axes:
1667            if isinstance(a, ChannelAxis):
1668                size = a.size
1669                assert isinstance(size, int)
1670                break
1671        else:
1672            return self
1673
1674        if len(self.data) != size:
1675            raise ValueError(
1676                f"Got tensor data descriptions for {len(self.data)} channels, but"
1677                + f" '{a.id}' axis has size {size}."
1678            )
1679
1680        return self
1681
1682    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1683        if len(array.shape) != len(self.axes):
1684            raise ValueError(
1685                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1686                + f" incompatible with {len(self.axes)} axes."
1687            )
1688        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1689
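    # Illustrative sketch (hypothetical `tensor_descr` with axes batch, channel, y, x):
    # `get_axis_sizes_for_array` simply pairs the described axes with the array's
    # shape, position by position.
    #
    #   import numpy as np
    #   arr = np.zeros((1, 3, 256, 256), dtype="float32")
    #   sizes = tensor_descr.get_axis_sizes_for_array(arr)
    #   # -> {AxisId("batch"): 1, AxisId("channel"): 3, AxisId("y"): 256, AxisId("x"): 256}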
1690
1691class InputTensorDescr(TensorDescrBase[InputAxis]):
1692    id: TensorId = TensorId("input")
1693    """Input tensor id.
1694    No duplicates are allowed across all inputs and outputs."""
1695
1696    optional: bool = False
1697    """indicates that this tensor may be `None`"""
1698
1699    preprocessing: List[PreprocessingDescr] = Field(
1700        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1701    )
1702
1703    """Description of how this input should be preprocessed.
1704
1705    notes:
1706    - If preprocessing does not start with an 'ensure_dtype' entry, one is added
1707      to ensure an input tensor's data type matches the input tensor's data description.
1708    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1709      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1710      changing the data type.
1711    """
1712
1713    @model_validator(mode="after")
1714    def _validate_preprocessing_kwargs(self) -> Self:
1715        axes_ids = [a.id for a in self.axes]
1716        for p in self.preprocessing:
1717            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1718            if kwargs_axes is None:
1719                continue
1720
1721            if not isinstance(kwargs_axes, collections.abc.Sequence):
1722                raise ValueError(
1723                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1724                )
1725
1726            if any(a not in axes_ids for a in kwargs_axes):
1727                raise ValueError(
1728                    "`preprocessing.i.kwargs.axes` needs to be a subset of axes ids"
1729                )
1730
1731        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1732            dtype = self.data.type
1733        else:
1734            dtype = self.data[0].type
1735
1736        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1737        if not self.preprocessing or not isinstance(
1738            self.preprocessing[0], EnsureDtypeDescr
1739        ):
1740            self.preprocessing.insert(
1741                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1742            )
1743
1744        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1745        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1746            self.preprocessing.append(
1747                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1748            )
1749
1750        return self
1751
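    # Illustrative sketch (hypothetical ids and file source, io checks assumed to be
    # disabled): explicit dtype handling is optional; validation wraps the given
    # steps in `ensure_dtype` entries matching the data description, e.g.
    #
    #   ipt = InputTensorDescr(
    #       id=TensorId("raw"),
    #       axes=[BatchAxis(), SpaceInputAxis(id=AxisId("x"), size=64)],
    #       test_tensor=FileDescr(source="https://example.com/test_input.npy"),
    #       data=IntervalOrRatioDataDescr(type="float32"),
    #       preprocessing=[SigmoidDescr()],
    #   )
    #   # ipt.preprocessing now starts and ends with an EnsureDtypeDescr(dtype="float32")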
1752
1753def convert_axes(
1754    axes: str,
1755    *,
1756    shape: Union[
1757        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1758    ],
1759    tensor_type: Literal["input", "output"],
1760    halo: Optional[Sequence[int]],
1761    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1762):
1763    ret: List[AnyAxis] = []
1764    for i, a in enumerate(axes):
1765        axis_type = _AXIS_TYPE_MAP.get(a, a)
1766        if axis_type == "batch":
1767            ret.append(BatchAxis())
1768            continue
1769
1770        scale = 1.0
1771        if isinstance(shape, _ParameterizedInputShape_v0_4):
1772            if shape.step[i] == 0:
1773                size = shape.min[i]
1774            else:
1775                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1776        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1777            ref_t = str(shape.reference_tensor)
1778            if ref_t.count(".") == 1:
1779                t_id, orig_a_id = ref_t.split(".")
1780            else:
1781                t_id = ref_t
1782                orig_a_id = a
1783
1784            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1785            if not (orig_scale := shape.scale[i]):
1786                # old way to insert a new axis dimension
1787                size = int(2 * shape.offset[i])
1788            else:
1789                scale = 1 / orig_scale
1790                if axis_type in ("channel", "index"):
1791                    # these axes no longer have a scale
1792                    offset_from_scale = orig_scale * size_refs.get(
1793                        _TensorName_v0_4(t_id), {}
1794                    ).get(orig_a_id, 0)
1795                else:
1796                    offset_from_scale = 0
1797                size = SizeReference(
1798                    tensor_id=TensorId(t_id),
1799                    axis_id=AxisId(a_id),
1800                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1801                )
1802        else:
1803            size = shape[i]
1804
1805        if axis_type == "time":
1806            if tensor_type == "input":
1807                ret.append(TimeInputAxis(size=size, scale=scale))
1808            else:
1809                assert not isinstance(size, ParameterizedSize)
1810                if halo is None:
1811                    ret.append(TimeOutputAxis(size=size, scale=scale))
1812                else:
1813                    assert not isinstance(size, int)
1814                    ret.append(
1815                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1816                    )
1817
1818        elif axis_type == "index":
1819            if tensor_type == "input":
1820                ret.append(IndexInputAxis(size=size))
1821            else:
1822                if isinstance(size, ParameterizedSize):
1823                    size = DataDependentSize(min=size.min)
1824
1825                ret.append(IndexOutputAxis(size=size))
1826        elif axis_type == "channel":
1827            assert not isinstance(size, ParameterizedSize)
1828            if isinstance(size, SizeReference):
1829                warnings.warn(
1830                    "Conversion of channel size from an implicit output shape may be"
1831                    + " wrong"
1832                )
1833                ret.append(
1834                    ChannelAxis(
1835                        channel_names=[
1836                            Identifier(f"channel{i}") for i in range(size.offset)
1837                        ]
1838                    )
1839                )
1840            else:
1841                ret.append(
1842                    ChannelAxis(
1843                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1844                    )
1845                )
1846        elif axis_type == "space":
1847            if tensor_type == "input":
1848                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1849            else:
1850                assert not isinstance(size, ParameterizedSize)
1851                if halo is None or halo[i] == 0:
1852                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1853                elif isinstance(size, int):
1854                    raise NotImplementedError(
1855                        f"output axis with halo and fixed size (here {size}) not allowed"
1856                    )
1857                else:
1858                    ret.append(
1859                        SpaceOutputAxisWithHalo(
1860                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1861                        )
1862                    )
1863
1864    return ret
1865
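# Illustrative sketch (hypothetical values; assumes the usual letter-to-type mapping
# b->batch, c->channel, y/x->space): converting a v0.4 axes string with a fixed
# input shape yields one v0.5 axis object per letter, e.g.
#
#   convert_axes("bcyx", shape=[1, 3, 512, 512], tensor_type="input",
#                halo=None, size_refs={})
#   # -> [BatchAxis(),
#   #     ChannelAxis(channel_names=[Identifier("channel0"), ..., Identifier("channel2")]),
#   #     SpaceInputAxis(id=AxisId("y"), size=512, scale=1.0),
#   #     SpaceInputAxis(id=AxisId("x"), size=512, scale=1.0)]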
1866
1867def _axes_letters_to_ids(
1868    axes: Optional[str],
1869) -> Optional[List[AxisId]]:
1870    if axes is None:
1871        return None
1872
1873    return [AxisId(a) for a in axes]
1874
1875
1876def _get_complement_v04_axis(
1877    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1878) -> Optional[AxisId]:
1879    if axes is None:
1880        return None
1881
1882    non_complement_axes = set(axes) | {"b"}
1883    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1884    if len(complement_axes) > 1:
1885        raise ValueError(
1886            f"Expected none or a single complement axis, but axes '{axes}' "
1887            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1888        )
1889
1890    return None if not complement_axes else AxisId(complement_axes[0])
1891
1892
1893def _convert_proc(
1894    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1895    tensor_axes: Sequence[str],
1896) -> Union[PreprocessingDescr, PostprocessingDescr]:
1897    if isinstance(p, _BinarizeDescr_v0_4):
1898        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1899    elif isinstance(p, _ClipDescr_v0_4):
1900        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1901    elif isinstance(p, _SigmoidDescr_v0_4):
1902        return SigmoidDescr()
1903    elif isinstance(p, _ScaleLinearDescr_v0_4):
1904        axes = _axes_letters_to_ids(p.kwargs.axes)
1905        if p.kwargs.axes is None:
1906            axis = None
1907        else:
1908            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1909
1910        if axis is None:
1911            assert not isinstance(p.kwargs.gain, list)
1912            assert not isinstance(p.kwargs.offset, list)
1913            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1914        else:
1915            kwargs = ScaleLinearAlongAxisKwargs(
1916                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1917            )
1918        return ScaleLinearDescr(kwargs=kwargs)
1919    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1920        return ScaleMeanVarianceDescr(
1921            kwargs=ScaleMeanVarianceKwargs(
1922                axes=_axes_letters_to_ids(p.kwargs.axes),
1923                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1924                eps=p.kwargs.eps,
1925            )
1926        )
1927    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1928        if p.kwargs.mode == "fixed":
1929            mean = p.kwargs.mean
1930            std = p.kwargs.std
1931            assert mean is not None
1932            assert std is not None
1933
1934            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1935
1936            if axis is None:
1937                return FixedZeroMeanUnitVarianceDescr(
1938                    kwargs=FixedZeroMeanUnitVarianceKwargs(
1939                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
1940                    )
1941                )
1942            else:
1943                if not isinstance(mean, list):
1944                    mean = [float(mean)]
1945                if not isinstance(std, list):
1946                    std = [float(std)]
1947
1948                return FixedZeroMeanUnitVarianceDescr(
1949                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1950                        axis=axis, mean=mean, std=std
1951                    )
1952                )
1953
1954        else:
1955            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1956            if p.kwargs.mode == "per_dataset":
1957                axes = [AxisId("batch")] + axes
1958            if not axes:
1959                axes = None
1960            return ZeroMeanUnitVarianceDescr(
1961                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1962            )
1963
1964    elif isinstance(p, _ScaleRangeDescr_v0_4):
1965        return ScaleRangeDescr(
1966            kwargs=ScaleRangeKwargs(
1967                axes=_axes_letters_to_ids(p.kwargs.axes),
1968                min_percentile=p.kwargs.min_percentile,
1969                max_percentile=p.kwargs.max_percentile,
1970                eps=p.kwargs.eps,
1971            )
1972        )
1973    else:
1974        assert_never(p)
1975
1976
1977class _InputTensorConv(
1978    Converter[
1979        _InputTensorDescr_v0_4,
1980        InputTensorDescr,
1981        FileSource_,
1982        Optional[FileSource_],
1983        Mapping[_TensorName_v0_4, Mapping[str, int]],
1984    ]
1985):
1986    def _convert(
1987        self,
1988        src: _InputTensorDescr_v0_4,
1989        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1990        test_tensor: FileSource_,
1991        sample_tensor: Optional[FileSource_],
1992        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1993    ) -> "InputTensorDescr | dict[str, Any]":
1994        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1995            src.axes,
1996            shape=src.shape,
1997            tensor_type="input",
1998            halo=None,
1999            size_refs=size_refs,
2000        )
2001        prep: List[PreprocessingDescr] = []
2002        for p in src.preprocessing:
2003            cp = _convert_proc(p, src.axes)
2004            assert not isinstance(cp, ScaleMeanVarianceDescr)
2005            prep.append(cp)
2006
2007        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
2008
2009        return tgt(
2010            axes=axes,
2011            id=TensorId(str(src.name)),
2012            test_tensor=FileDescr(source=test_tensor),
2013            sample_tensor=(
2014                None if sample_tensor is None else FileDescr(source=sample_tensor)
2015            ),
2016            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
2017            preprocessing=prep,
2018        )
2019
2020
2021_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
2022
2023
2024class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2025    id: TensorId = TensorId("output")
2026    """Output tensor id.
2027    No duplicates are allowed across all inputs and outputs."""
2028
2029    postprocessing: List[PostprocessingDescr] = Field(
2030        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2031    )
2032    """Description of how this output should be postprocessed.
2033
2034    note: `postprocessing` always ends with an 'ensure_dtype' (or 'binarize') operation.
2035          If not given, an 'ensure_dtype' step casting to this tensor's `data.type` is appended.
2036    """
2037
2038    @model_validator(mode="after")
2039    def _validate_postprocessing_kwargs(self) -> Self:
2040        axes_ids = [a.id for a in self.axes]
2041        for p in self.postprocessing:
2042            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2043            if kwargs_axes is None:
2044                continue
2045
2046            if not isinstance(kwargs_axes, collections.abc.Sequence):
2047                raise ValueError(
2048                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2049                )
2050
2051            if any(a not in axes_ids for a in kwargs_axes):
2052                raise ValueError("`kwargs.axes` needs to be a subset of axes ids")
2053
2054        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2055            dtype = self.data.type
2056        else:
2057            dtype = self.data[0].type
2058
2059        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2060        if not self.postprocessing or not isinstance(
2061            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2062        ):
2063            self.postprocessing.append(
2064                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2065            )
2066        return self
2067
2068
2069class _OutputTensorConv(
2070    Converter[
2071        _OutputTensorDescr_v0_4,
2072        OutputTensorDescr,
2073        FileSource_,
2074        Optional[FileSource_],
2075        Mapping[_TensorName_v0_4, Mapping[str, int]],
2076    ]
2077):
2078    def _convert(
2079        self,
2080        src: _OutputTensorDescr_v0_4,
2081        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2082        test_tensor: FileSource_,
2083        sample_tensor: Optional[FileSource_],
2084        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2085    ) -> "OutputTensorDescr | dict[str, Any]":
2086        # TODO: split convert_axes into convert_output_axes and convert_input_axes
2087        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
2088            src.axes,
2089            shape=src.shape,
2090            tensor_type="output",
2091            halo=src.halo,
2092            size_refs=size_refs,
2093        )
2094        data_descr: Dict[str, Any] = dict(type=src.data_type)
2095        if data_descr["type"] == "bool":
2096            data_descr["values"] = [False, True]
2097
2098        return tgt(
2099            axes=axes,
2100            id=TensorId(str(src.name)),
2101            test_tensor=FileDescr(source=test_tensor),
2102            sample_tensor=(
2103                None if sample_tensor is None else FileDescr(source=sample_tensor)
2104            ),
2105            data=data_descr,  # pyright: ignore[reportArgumentType]
2106            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2107        )
2108
2109
2110_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2111
2112
2113TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2114
2115
2116def validate_tensors(
2117    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2118    tensor_origin: Literal[
2119        "test_tensor"
2120    ],  # for more precise error messages, e.g. 'test_tensor'
2121):
2122    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2123
2124    def e_msg(d: TensorDescr):
2125        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2126
2127    for descr, array in tensors.values():
2128        if array is None:
2129            axis_sizes = {a.id: None for a in descr.axes}
2130        else:
2131            try:
2132                axis_sizes = descr.get_axis_sizes_for_array(array)
2133            except ValueError as e:
2134                raise ValueError(f"{e_msg(descr)} {e}")
2135
2136        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2137
2138    for descr, array in tensors.values():
2139        if array is None:
2140            continue
2141
2142        if descr.dtype in ("float32", "float64"):
2143            invalid_test_tensor_dtype = array.dtype.name not in (
2144                "float32",
2145                "float64",
2146                "uint8",
2147                "int8",
2148                "uint16",
2149                "int16",
2150                "uint32",
2151                "int32",
2152                "uint64",
2153                "int64",
2154            )
2155        else:
2156            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2157
2158        if invalid_test_tensor_dtype:
2159            raise ValueError(
2160                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2161                + f" match described dtype '{descr.dtype}'"
2162            )
2163
2164        if array.min() > -1e-4 and array.max() < 1e-4:
2165            raise ValueError(
2166                "Output values are too small for reliable testing."
2167                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2168            )
2169
2170        for a in descr.axes:
2171            actual_size = all_tensor_axes[descr.id][a.id][1]
2172            if actual_size is None:
2173                continue
2174
2175            if a.size is None:
2176                continue
2177
2178            if isinstance(a.size, int):
2179                if actual_size != a.size:
2180                    raise ValueError(
2181                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2182                        + f"has incompatible size {actual_size}, expected {a.size}"
2183                    )
2184            elif isinstance(a.size, ParameterizedSize):
2185                _ = a.size.validate_size(actual_size)
2186            elif isinstance(a.size, DataDependentSize):
2187                _ = a.size.validate_size(actual_size)
2188            elif isinstance(a.size, SizeReference):
2189                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2190                if ref_tensor_axes is None:
2191                    raise ValueError(
2192                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2193                        + f" reference '{a.size.tensor_id}'"
2194                    )
2195
2196                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2197                if ref_axis is None or ref_size is None:
2198                    raise ValueError(
2199                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2200                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2201                    )
2202
2203                if a.unit != ref_axis.unit:
2204                    raise ValueError(
2205                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2206                        + " axis and reference axis to have the same `unit`, but"
2207                        + f" {a.unit}!={ref_axis.unit}"
2208                    )
2209
2210                if actual_size != (
2211                    expected_size := (
2212                        ref_size * ref_axis.scale / a.scale + a.size.offset
2213                    )
2214                ):
2215                    raise ValueError(
2216                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2217                        + f" {actual_size} invalid for referenced size {ref_size};"
2218                        + f" expected {expected_size}"
2219                    )
2220            else:
2221                assert_never(a.size)
2222
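# Worked example (hypothetical values) of the `SizeReference` check above: with a
# reference axis of size 128 and scale 1.0, an axis of scale 2.0 and offset 0 is
# expected to have size 128 * 1.0 / 2.0 + 0 = 64; any other actual size raises the
# "invalid for referenced size" error.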
2223
2224FileDescr_dependencies = Annotated[
2225    FileDescr_,
2226    WithSuffix((".yaml", ".yml"), case_sensitive=True),
2227    Field(examples=[dict(source="environment.yaml")]),
2228]
2229
2230
2231class _ArchitectureCallableDescr(Node):
2232    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2233    """Identifier of the callable that returns a torch.nn.Module instance."""
2234
2235    kwargs: Dict[str, YamlValue] = Field(
2236        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2237    )
2238    """keyword arguments for the `callable`"""
2239
2240
2241class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2242    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2243    """Architecture source file"""
2244
2245    @model_serializer(mode="wrap", when_used="unless-none")
2246    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2247        return package_file_descr_serializer(self, nxt, info)
2248
2249
2250class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2251    import_from: str
2252    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2253
2254
2255class _ArchFileConv(
2256    Converter[
2257        _CallableFromFile_v0_4,
2258        ArchitectureFromFileDescr,
2259        Optional[Sha256],
2260        Dict[str, Any],
2261    ]
2262):
2263    def _convert(
2264        self,
2265        src: _CallableFromFile_v0_4,
2266        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2267        sha256: Optional[Sha256],
2268        kwargs: Dict[str, Any],
2269    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2270        if src.startswith("http") and src.count(":") == 2:
2271            http, source, callable_ = src.split(":")
2272            source = ":".join((http, source))
2273        elif not src.startswith("http") and src.count(":") == 1:
2274            source, callable_ = src.split(":")
2275        else:
2276            source = str(src)
2277            callable_ = str(src)
2278        return tgt(
2279            callable=Identifier(callable_),
2280            source=cast(FileSource_, source),
2281            sha256=sha256,
2282            kwargs=kwargs,
2283        )
2284
2285
2286_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2287
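# Illustrative sketch (hypothetical inputs): the converter splits a v0.4
# "<source>:<callable>" string into separate `source` and `callable` fields, e.g.
#
#   "https://example.com/mynet.py:MyNet" -> source="https://example.com/mynet.py", callable="MyNet"
#   "mynet.py:MyNet"                     -> source="mynet.py",                     callable="MyNet"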
2288
2289class _ArchLibConv(
2290    Converter[
2291        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2292    ]
2293):
2294    def _convert(
2295        self,
2296        src: _CallableFromDepencency_v0_4,
2297        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2298        kwargs: Dict[str, Any],
2299    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2300        *mods, callable_ = src.split(".")
2301        import_from = ".".join(mods)
2302        return tgt(
2303            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2304        )
2305
2306
2307_arch_lib_conv = _ArchLibConv(
2308    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2309)
2310
2311
2312class WeightsEntryDescrBase(FileDescr):
2313    type: ClassVar[WeightsFormat]
2314    weights_format_name: ClassVar[str]  # human readable
2315
2316    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2317    """Source of the weights file."""
2318
2319    authors: Optional[List[Author]] = None
2320    """Authors
2321    Either the person(s) that have trained this model resulting in the original weights file.
2322        (If this is the initial weights entry, i.e. it does not have a `parent`)
2323    Or the person(s) who have converted the weights to this weights format.
2324        (If this is a child weight, i.e. it has a `parent` field)
2325    """
2326
2327    parent: Annotated[
2328        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2329    ] = None
2330    """The source weights these weights were converted from.
2331    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2332    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2333    All weight entries except one (the initial set of weights resulting from training the model)
2334    need to have this field."""
2335
2336    comment: str = ""
2337    """A comment about this weights entry, for example how these weights were created."""
2338
2339    @model_validator(mode="after")
2340    def _validate(self) -> Self:
2341        if self.type == self.parent:
2342            raise ValueError("Weights entry can't be its own parent.")
2343
2344        return self
2345
2346    @model_serializer(mode="wrap", when_used="unless-none")
2347    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2348        return package_file_descr_serializer(self, nxt, info)
2349
2350
2351class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2352    type = "keras_hdf5"
2353    weights_format_name: ClassVar[str] = "Keras HDF5"
2354    tensorflow_version: Version
2355    """TensorFlow version used to create these weights."""
2356
2357
2358class OnnxWeightsDescr(WeightsEntryDescrBase):
2359    type = "onnx"
2360    weights_format_name: ClassVar[str] = "ONNX"
2361    opset_version: Annotated[int, Ge(7)]
2362    """ONNX opset version"""
2363
2364
2365class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2366    type = "pytorch_state_dict"
2367    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2368    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2369    pytorch_version: Version
2370    """Version of the PyTorch library used.
2371    If `architecture.dependencies` is specified, it has to include pytorch, and any version pinning has to be compatible.
2372    """
2373    dependencies: Optional[FileDescr_dependencies] = None
2374    """Custom dependencies beyond pytorch, described in a Conda environment file.
2375    Allows specifying custom dependencies, see conda docs:
2376    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2377    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2378
2379    The conda environment file should include pytorch, and any version pinning has to be compatible with
2380    **pytorch_version**.
2381    """
2382
2383
2384class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2385    type = "tensorflow_js"
2386    weights_format_name: ClassVar[str] = "Tensorflow.js"
2387    tensorflow_version: Version
2388    """Version of the TensorFlow library used."""
2389
2390    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2391    """The multi-file weights.
2392    All required files/folders should be packaged in a zip archive."""
2393
2394
2395class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2396    type = "tensorflow_saved_model_bundle"
2397    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2398    tensorflow_version: Version
2399    """Version of the TensorFlow library used."""
2400
2401    dependencies: Optional[FileDescr_dependencies] = None
2402    """Custom dependencies beyond tensorflow.
2403    Should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**."""
2404
2405    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2406    """The multi-file weights.
2407    All required files/folders should be packaged in a zip archive."""
2408
2409
2410class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2411    type = "torchscript"
2412    weights_format_name: ClassVar[str] = "TorchScript"
2413    pytorch_version: Version
2414    """Version of the PyTorch library used."""
2415
2416
2417class WeightsDescr(Node):
2418    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2419    onnx: Optional[OnnxWeightsDescr] = None
2420    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2421    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2422    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2423        None
2424    )
2425    torchscript: Optional[TorchscriptWeightsDescr] = None
2426
2427    @model_validator(mode="after")
2428    def check_entries(self) -> Self:
2429        entries = {wtype for wtype, entry in self if entry is not None}
2430
2431        if not entries:
2432            raise ValueError("Missing weights entry")
2433
2434        entries_wo_parent = {
2435            wtype
2436            for wtype, entry in self
2437            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2438        }
2439        if len(entries_wo_parent) != 1:
2440            issue_warning(
2441                "Exactly one weights entry must not specify the `parent` field (got"
2442                + " {value} such entries). That entry is considered the original set of model weights."
2443                + " Other weight formats are created through conversion of the original or"
2444                + " already converted weights. They have to reference the weights format"
2445                + " they were converted from as their `parent`.",
2446                value=len(entries_wo_parent),
2447                field="weights",
2448            )
2449
2450        for wtype, entry in self:
2451            if entry is None:
2452                continue
2453
2454            assert hasattr(entry, "type")
2455            assert hasattr(entry, "parent")
2456            assert wtype == entry.type
2457            if (
2458                entry.parent is not None and entry.parent not in entries
2459            ):  # self reference checked for `parent` field
2460                raise ValueError(
2461                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2462                    + f" formats: {entries}"
2463                )
2464
2465        return self
2466
2467    def __getitem__(
2468        self,
2469        key: Literal[
2470            "keras_hdf5",
2471            "onnx",
2472            "pytorch_state_dict",
2473            "tensorflow_js",
2474            "tensorflow_saved_model_bundle",
2475            "torchscript",
2476        ],
2477    ):
2478        if key == "keras_hdf5":
2479            ret = self.keras_hdf5
2480        elif key == "onnx":
2481            ret = self.onnx
2482        elif key == "pytorch_state_dict":
2483            ret = self.pytorch_state_dict
2484        elif key == "tensorflow_js":
2485            ret = self.tensorflow_js
2486        elif key == "tensorflow_saved_model_bundle":
2487            ret = self.tensorflow_saved_model_bundle
2488        elif key == "torchscript":
2489            ret = self.torchscript
2490        else:
2491            raise KeyError(key)
2492
2493        if ret is None:
2494            raise KeyError(key)
2495
2496        return ret
2497
2498    @property
2499    def available_formats(self):
2500        return {
2501            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2502            **({} if self.onnx is None else {"onnx": self.onnx}),
2503            **(
2504                {}
2505                if self.pytorch_state_dict is None
2506                else {"pytorch_state_dict": self.pytorch_state_dict}
2507            ),
2508            **(
2509                {}
2510                if self.tensorflow_js is None
2511                else {"tensorflow_js": self.tensorflow_js}
2512            ),
2513            **(
2514                {}
2515                if self.tensorflow_saved_model_bundle is None
2516                else {
2517                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2518                }
2519            ),
2520            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2521        }
2522
2523    @property
2524    def missing_formats(self):
2525        return {
2526            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2527        }
2528
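    # Usage sketch (hypothetical `weights` instance): entries are accessed by
    # weights format name; a missing format raises KeyError rather than returning None.
    #
    #   onnx_entry = weights["onnx"]   # OnnxWeightsDescr, or KeyError if absent
    #   weights.available_formats      # e.g. {"pytorch_state_dict": ..., "onnx": ...}
    #   weights.missing_formats        # the remaining WeightsFormat values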
2529
2530class ModelId(ResourceId):
2531    pass
2532
2533
2534class LinkedModel(LinkedResourceBase):
2535    """Reference to a bioimage.io model."""
2536
2537    id: ModelId
2538    """A valid model `id` from the bioimage.io collection."""
2539
2540
2541class _DataDepSize(NamedTuple):
2542    min: StrictInt
2543    max: Optional[StrictInt]
2544
2545
2546class _AxisSizes(NamedTuple):
2547    """the lengths of all axes of model inputs and outputs"""
2548
2549    inputs: Dict[Tuple[TensorId, AxisId], int]
2550    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2551
2552
2553class _TensorSizes(NamedTuple):
2554    """_AxisSizes as nested dicts"""
2555
2556    inputs: Dict[TensorId, Dict[AxisId, int]]
2557    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2558
2559
2560class ReproducibilityTolerance(Node, extra="allow"):
2561    """Describes what small numerical differences -- if any -- may be tolerated
2562    in the generated output when executing in different environments.
2563
2564    A tensor element *output* is considered mismatched to the **test_tensor** if
2565    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2566    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2567
2568    Motivation:
2569        For testing we can request the respective deep learning frameworks to be as
2570        reproducible as possible by setting seeds and choosing deterministic algorithms,
2571        but differences in operating systems, available hardware and installed drivers
2572        may still lead to numerical differences.
2573    """
2574
2575    relative_tolerance: RelativeTolerance = 1e-3
2576    """Maximum relative tolerance of reproduced test tensor."""
2577
2578    absolute_tolerance: AbsoluteTolerance = 1e-4
2579    """Maximum absolute tolerance of reproduced test tensor."""
2580
2581    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2582    """Maximum number of mismatched elements/pixels per million to tolerate."""
2583
2584    output_ids: Sequence[TensorId] = ()
2585    """Limits the output tensor IDs these reproducibility details apply to."""
2586
2587    weights_formats: Sequence[WeightsFormat] = ()
2588    """Limits the weights formats these details apply to."""
2589
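    # Worked example (hypothetical numbers) for the mismatch criterion above:
    # with relative_tolerance=1e-3 and absolute_tolerance=1e-4, a test value of 1.0
    # allows deviations up to 1e-4 + 1e-3 * 1.0 = 1.1e-3; a reproduced value of
    # 1.002 (deviation 2e-3) counts as mismatched, and up to
    # mismatched_elements_per_million=100 such elements are tolerated.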
2590
2591class BioimageioConfig(Node, extra="allow"):
2592    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2593    """Tolerances to allow when reproducing the model's test outputs
2594    from the model's test inputs.
2595    Only the first entry matching tensor id and weights format is considered.
2596    """
2597
2598
2599class Config(Node, extra="allow"):
2600    bioimageio: BioimageioConfig = Field(
2601        default_factory=BioimageioConfig.model_construct
2602    )
2603
2604
2605class ModelDescr(GenericModelDescrBase):
2606    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2607    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2608    """
2609
2610    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2611    if TYPE_CHECKING:
2612        format_version: Literal["0.5.5"] = "0.5.5"
2613    else:
2614        format_version: Literal["0.5.5"]
2615        """Version of the bioimage.io model description specification used.
2616        When creating a new model always use the latest micro/patch version described here.
2617        The `format_version` is important for any consumer software to understand how to parse the fields.
2618        """
2619
2620    implemented_type: ClassVar[Literal["model"]] = "model"
2621    if TYPE_CHECKING:
2622        type: Literal["model"] = "model"
2623    else:
2624        type: Literal["model"]
2625        """Specialized resource type 'model'"""
2626
2627    id: Optional[ModelId] = None
2628    """bioimage.io-wide unique resource identifier
2629    assigned by bioimage.io; version **un**specific."""
2630
2631    authors: FAIR[List[Author]] = Field(
2632        default_factory=cast(Callable[[], List[Author]], list)
2633    )
2634    """The authors are the creators of the model RDF and the primary points of contact."""
2635
2636    documentation: FAIR[Optional[FileSource_documentation]] = None
2637    """URL or relative path to a markdown file with additional documentation.
2638    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2639    The documentation should include a '#[#] Validation' (sub)section
2640    with details on how to quantitatively validate the model on unseen data."""
2641
2642    @field_validator("documentation", mode="after")
2643    @classmethod
2644    def _validate_documentation(
2645        cls, value: Optional[FileSource_documentation]
2646    ) -> Optional[FileSource_documentation]:
2647        if not get_validation_context().perform_io_checks or value is None:
2648            return value
2649
2650        doc_reader = get_reader(value)
2651        doc_content = doc_reader.read().decode(encoding="utf-8")
2652        if not re.search("#.*[vV]alidation", doc_content):
2653            issue_warning(
2654                "No '# Validation' (sub)section found in {value}.",
2655                value=value,
2656                field="documentation",
2657            )
2658
2659        return value
2660
2661    inputs: NotEmpty[Sequence[InputTensorDescr]]
2662    """Describes the input tensors expected by this model."""
2663
2664    @field_validator("inputs", mode="after")
2665    @classmethod
2666    def _validate_input_axes(
2667        cls, inputs: Sequence[InputTensorDescr]
2668    ) -> Sequence[InputTensorDescr]:
2669        input_size_refs = cls._get_axes_with_independent_size(inputs)
2670
2671        for i, ipt in enumerate(inputs):
2672            valid_independent_refs: Dict[
2673                Tuple[TensorId, AxisId],
2674                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2675            ] = {
2676                **{
2677                    (ipt.id, a.id): (ipt, a, a.size)
2678                    for a in ipt.axes
2679                    if not isinstance(a, BatchAxis)
2680                    and isinstance(a.size, (int, ParameterizedSize))
2681                },
2682                **input_size_refs,
2683            }
2684            for a, ax in enumerate(ipt.axes):
2685                cls._validate_axis(
2686                    "inputs",
2687                    i=i,
2688                    tensor_id=ipt.id,
2689                    a=a,
2690                    axis=ax,
2691                    valid_independent_refs=valid_independent_refs,
2692                )
2693        return inputs
2694
2695    @staticmethod
2696    def _validate_axis(
2697        field_name: str,
2698        i: int,
2699        tensor_id: TensorId,
2700        a: int,
2701        axis: AnyAxis,
2702        valid_independent_refs: Dict[
2703            Tuple[TensorId, AxisId],
2704            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2705        ],
2706    ):
2707        if isinstance(axis, BatchAxis) or isinstance(
2708            axis.size, (int, ParameterizedSize, DataDependentSize)
2709        ):
2710            return
2711        elif not isinstance(axis.size, SizeReference):
2712            assert_never(axis.size)
2713
2714        # validate axis.size SizeReference
2715        ref = (axis.size.tensor_id, axis.size.axis_id)
2716        if ref not in valid_independent_refs:
2717            raise ValueError(
2718                "Invalid tensor axis reference at"
2719                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2720            )
2721        if ref == (tensor_id, axis.id):
2722            raise ValueError(
2723                "Self-referencing not allowed for"
2724                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2725            )
2726        if axis.type == "channel":
2727            if valid_independent_refs[ref][1].type != "channel":
2728                raise ValueError(
2729                    "A channel axis' size may only reference another fixed size"
2730                    + " channel axis."
2731                )
2732            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2733                ref_size = valid_independent_refs[ref][2]
2734                assert isinstance(ref_size, int), (
2735                    "channel axis ref (another channel axis) has to specify fixed"
2736                    + " size"
2737                )
2738                generated_channel_names = [
2739                    Identifier(axis.channel_names.format(i=i))
2740                    for i in range(1, ref_size + 1)
2741                ]
2742                axis.channel_names = generated_channel_names
2743
2744        if (ax_unit := getattr(axis, "unit", None)) != (
2745            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2746        ):
2747            raise ValueError(
2748                "The units of an axis and its reference axis need to match, but"
2749                + f" '{ax_unit}' != '{ref_unit}'."
2750            )
2751        ref_axis = valid_independent_refs[ref][1]
2752        if isinstance(ref_axis, BatchAxis):
2753            raise ValueError(
2754                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2755                + " (a batch axis is not allowed as reference)."
2756            )
2757
2758        if isinstance(axis, WithHalo):
2759            min_size = axis.size.get_size(axis, ref_axis, n=0)
2760            if (min_size - 2 * axis.halo) < 1:
2761                raise ValueError(
2762                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2763                    + f" {axis.halo}."
2764                )
2765
2766            input_halo = axis.halo * axis.scale / ref_axis.scale
2767            if input_halo != int(input_halo) or input_halo % 2 == 1:
2768                raise ValueError(
2769                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2770                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2771                    + f" has to be an even integer for {tensor_id}.{axis.id}."
2772                )
2773
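    # Worked example (hypothetical numbers) for the halo check above: an output
    # axis with halo=8 and scale=2.0 referencing an input axis with scale=1.0
    # gives input_halo = 8 * 2.0 / 1.0 = 16.0, an even integer, so it passes;
    # a non-integer or odd result (e.g. halo=3 with equal scales) would raise.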
2774    @model_validator(mode="after")
2775    def _validate_test_tensors(self) -> Self:
2776        if not get_validation_context().perform_io_checks:
2777            return self
2778
2779        test_output_arrays = [
2780            None if descr.test_tensor is None else load_array(descr.test_tensor)
2781            for descr in self.outputs
2782        ]
2783        test_input_arrays = [
2784            None if descr.test_tensor is None else load_array(descr.test_tensor)
2785            for descr in self.inputs
2786        ]
2787
2788        tensors = {
2789            descr.id: (descr, array)
2790            for descr, array in zip(
2791                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2792            )
2793        }
2794        validate_tensors(tensors, tensor_origin="test_tensor")
2795
2796        output_arrays = {
2797            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2798        }
2799        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2800            if not rep_tol.absolute_tolerance:
2801                continue
2802
2803            if rep_tol.output_ids:
2804                out_arrays = {
2805                    oid: a
2806                    for oid, a in output_arrays.items()
2807                    if oid in rep_tol.output_ids
2808                }
2809            else:
2810                out_arrays = output_arrays
2811
2812            for out_id, array in out_arrays.items():
2813                if array is None:
2814                    continue
2815
2816                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2817                    raise ValueError(
2818                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2819                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2820                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2821                    )
2822
2823        return self
2824
2825    @model_validator(mode="after")
2826    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2827        ipt_refs = {t.id for t in self.inputs}
2828        out_refs = {t.id for t in self.outputs}
2829        for ipt in self.inputs:
2830            for p in ipt.preprocessing:
2831                ref = p.kwargs.get("reference_tensor")
2832                if ref is None:
2833                    continue
2834                if ref not in ipt_refs:
2835                    raise ValueError(
2836                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2837                        + f" references are: {ipt_refs}."
2838                    )
2839
2840        for out in self.outputs:
2841            for p in out.postprocessing:
2842                ref = p.kwargs.get("reference_tensor")
2843                if ref is None:
2844                    continue
2845
2846                if ref not in ipt_refs and ref not in out_refs:
2847                    raise ValueError(
2848                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2849                        + f" are: {ipt_refs | out_refs}."
2850                    )
2851
2852        return self
2853
2854    # TODO: use validate funcs in validate_test_tensors
2855    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2856
2857    name: Annotated[
2858        str,
2859        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2860        MinLen(5),
2861        MaxLen(128),
2862        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2863    ]
2864    """A human-readable name of this model.
2865    It should be no longer than 64 characters
2866    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2867    We recommend choosing a name that refers to the model's task and image modality.
2868    """
2869
2870    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2871    """Describes the output tensors."""
2872
2873    @field_validator("outputs", mode="after")
2874    @classmethod
2875    def _validate_tensor_ids(
2876        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2877    ) -> Sequence[OutputTensorDescr]:
2878        tensor_ids = [
2879            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2880        ]
2881        duplicate_tensor_ids: List[str] = []
2882        seen: Set[str] = set()
2883        for t in tensor_ids:
2884            if t in seen:
2885                duplicate_tensor_ids.append(t)
2886
2887            seen.add(t)
2888
2889        if duplicate_tensor_ids:
2890            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2891
2892        return outputs
2893
2894    @staticmethod
2895    def _get_axes_with_parameterized_size(
2896        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2897    ):
2898        return {
2899            f"{t.id}.{a.id}": (t, a, a.size)
2900            for t in io
2901            for a in t.axes
2902            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2903        }
2904
2905    @staticmethod
2906    def _get_axes_with_independent_size(
2907        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2908    ):
2909        return {
2910            (t.id, a.id): (t, a, a.size)
2911            for t in io
2912            for a in t.axes
2913            if not isinstance(a, BatchAxis)
2914            and isinstance(a.size, (int, ParameterizedSize))
2915        }
2916
2917    @field_validator("outputs", mode="after")
2918    @classmethod
2919    def _validate_output_axes(
2920        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2921    ) -> List[OutputTensorDescr]:
2922        input_size_refs = cls._get_axes_with_independent_size(
2923            info.data.get("inputs", [])
2924        )
2925        output_size_refs = cls._get_axes_with_independent_size(outputs)
2926
2927        for i, out in enumerate(outputs):
2928            valid_independent_refs: Dict[
2929                Tuple[TensorId, AxisId],
2930                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2931            ] = {
2932                **{
2933                    (out.id, a.id): (out, a, a.size)
2934                    for a in out.axes
2935                    if not isinstance(a, BatchAxis)
2936                    and isinstance(a.size, (int, ParameterizedSize))
2937                },
2938                **input_size_refs,
2939                **output_size_refs,
2940            }
2941            for a, ax in enumerate(out.axes):
2942                cls._validate_axis(
2943                    "outputs",
2944                    i,
2945                    out.id,
2946                    a,
2947                    ax,
2948                    valid_independent_refs=valid_independent_refs,
2949                )
2950
2951        return outputs
2952
2953    packaged_by: List[Author] = Field(
2954        default_factory=cast(Callable[[], List[Author]], list)
2955    )
2956    """The persons that have packaged and uploaded this model.
2957    Only required if those persons differ from the `authors`."""
2958
2959    parent: Optional[LinkedModel] = None
2960    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2961
2962    @model_validator(mode="after")
2963    def _validate_parent_is_not_self(self) -> Self:
2964        if self.parent is not None and self.parent.id == self.id:
2965            raise ValueError("A model description may not reference itself as parent.")
2966
2967        return self
2968
2969    run_mode: Annotated[
2970        Optional[RunMode],
2971        warn(None, "Run mode '{value}' has limited support across consumer software."),
2972    ] = None
2973    """Custom run mode for this model: for more complex prediction procedures like test time
2974    data augmentation that currently cannot be expressed in the specification.
2975    No standard run modes are defined yet."""
2976
2977    timestamp: Datetime = Field(default_factory=Datetime.now)
2978    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2979    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2980    (In Python a datetime object is valid, too)."""
2981
2982    training_data: Annotated[
2983        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2984        Field(union_mode="left_to_right"),
2985    ] = None
2986    """The dataset used to train this model"""
2987
2988    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2989    """The weights for this model.
2990    Weights can be given for different formats, but should otherwise be equivalent.
2991    The available weight formats determine which consumers can use this model."""
2992
2993    config: Config = Field(default_factory=Config.model_construct)
2994
2995    @model_validator(mode="after")
2996    def _add_default_cover(self) -> Self:
2997        if not get_validation_context().perform_io_checks or self.covers:
2998            return self
2999
3000        try:
3001            generated_covers = generate_covers(
3002                [
3003                    (t, load_array(t.test_tensor))
3004                    for t in self.inputs
3005                    if t.test_tensor is not None
3006                ],
3007                [
3008                    (t, load_array(t.test_tensor))
3009                    for t in self.outputs
3010                    if t.test_tensor is not None
3011                ],
3012            )
3013        except Exception as e:
3014            issue_warning(
3015                "Failed to generate cover image(s): {e}",
3016                value=self.covers,
3017                msg_context=dict(e=e),
3018                field="covers",
3019            )
3020        else:
3021            self.covers.extend(generated_covers)
3022
3023        return self
3024
3025    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3026        return self._get_test_arrays(self.inputs)
3027
3028    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3029        return self._get_test_arrays(self.outputs)
3030
3031    @staticmethod
3032    def _get_test_arrays(
3033        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3034    ):
3035        ts: List[FileDescr] = []
3036        for d in io_descr:
3037            if d.test_tensor is None:
3038                raise ValueError(
3039                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3040                )
3041            ts.append(d.test_tensor)
3042
3043        data = [load_array(t) for t in ts]
3044        assert all(isinstance(d, np.ndarray) for d in data)
3045        return data
3046
3047    @staticmethod
3048    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3049        batch_size = 1
3050        tensor_with_batchsize: Optional[TensorId] = None
3051        for tid in tensor_sizes:
3052            for aid, s in tensor_sizes[tid].items():
3053                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3054                    continue
3055
3056                if batch_size != 1:
3057                    assert tensor_with_batchsize is not None
3058                    raise ValueError(
3059                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3060                    )
3061
3062                batch_size = s
3063                tensor_with_batchsize = tid
3064
3065        return batch_size
3066
3067    def get_output_tensor_sizes(
3068        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3069    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3070        """Returns the tensor output sizes for given **input_sizes**.
3071        The returned output sizes are exact only if **input_sizes** describes a valid input shape.
3072        Otherwise they might be larger than the actual (valid) output."""
3073        batch_size = self.get_batch_size(input_sizes)
3074        ns = self.get_ns(input_sizes)
3075
3076        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3077        return tensor_sizes.outputs
3078
3079    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3080        """get parameter `n` for each parameterized axis
3081        such that the valid input size is >= the given input size"""
3082        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3083        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3084        for tid in input_sizes:
3085            for aid, s in input_sizes[tid].items():
3086                size_descr = axes[tid][aid].size
3087                if isinstance(size_descr, ParameterizedSize):
3088                    ret[(tid, aid)] = size_descr.get_n(s)
3089                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3090                    pass
3091                else:
3092                    assert_never(size_descr)
3093
3094        return ret
3095
3096    def get_tensor_sizes(
3097        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3098    ) -> _TensorSizes:
3099        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3100        return _TensorSizes(
3101            {
3102                t: {
3103                    aa: axis_sizes.inputs[(tt, aa)]
3104                    for tt, aa in axis_sizes.inputs
3105                    if tt == t
3106                }
3107                for t in {tt for tt, _ in axis_sizes.inputs}
3108            },
3109            {
3110                t: {
3111                    aa: axis_sizes.outputs[(tt, aa)]
3112                    for tt, aa in axis_sizes.outputs
3113                    if tt == t
3114                }
3115                for t in {tt for tt, _ in axis_sizes.outputs}
3116            },
3117        )
3118
3119    def get_axis_sizes(
3120        self,
3121        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3122        batch_size: Optional[int] = None,
3123        *,
3124        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3125    ) -> _AxisSizes:
3126        """Determine input and output block shape for scale factors **ns**
3127        of parameterized input sizes.
3128
3129        Args:
3130            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3131                that is parameterized as `size = min + n * step`.
3132            batch_size: The desired size of the batch dimension.
3133                If given, **batch_size** overrides any batch size present in
3134                **max_input_shape**. Defaults to 1.
3135            max_input_shape: Limits the derived block shapes.
3136                Each axis for which the input size, parameterized by `n`, is larger
3137                than **max_input_shape** is set to the minimal value `n_min` for which
3138                this is still true.
3139                Use this for small input samples or large values of **ns**.
3140                Or simply whenever you know the full input shape.
3141
3142        Returns:
3143            Resolved axis sizes for model inputs and outputs.
3144        """
3145        max_input_shape = max_input_shape or {}
3146        if batch_size is None:
3147            for (_t_id, a_id), s in max_input_shape.items():
3148                if a_id == BATCH_AXIS_ID:
3149                    batch_size = s
3150                    break
3151            else:
3152                batch_size = 1
3153
3154        all_axes = {
3155            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3156        }
3157
3158        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3159        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3160
3161        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3162            if isinstance(a, BatchAxis):
3163                if (t_descr.id, a.id) in ns:
3164                    logger.warning(
3165                        "Ignoring unexpected size increment factor (n) for batch axis"
3166                        + " of tensor '{}'.",
3167                        t_descr.id,
3168                    )
3169                return batch_size
3170            elif isinstance(a.size, int):
3171                if (t_descr.id, a.id) in ns:
3172                    logger.warning(
3173                        "Ignoring unexpected size increment factor (n) for fixed size"
3174                        + " axis '{}' of tensor '{}'.",
3175                        a.id,
3176                        t_descr.id,
3177                    )
3178                return a.size
3179            elif isinstance(a.size, ParameterizedSize):
3180                if (t_descr.id, a.id) not in ns:
3181                    raise ValueError(
3182                        "Size increment factor (n) missing for parametrized axis"
3183                        + f" '{a.id}' of tensor '{t_descr.id}'."
3184                    )
3185                n = ns[(t_descr.id, a.id)]
3186                s_max = max_input_shape.get((t_descr.id, a.id))
3187                if s_max is not None:
3188                    n = min(n, a.size.get_n(s_max))
3189
3190                return a.size.get_size(n)
3191
3192            elif isinstance(a.size, SizeReference):
3193                if (t_descr.id, a.id) in ns:
3194                    logger.warning(
3195                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3196                        + " of tensor '{}' with size reference.",
3197                        a.id,
3198                        t_descr.id,
3199                    )
3200                assert not isinstance(a, BatchAxis)
3201                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3202                assert not isinstance(ref_axis, BatchAxis)
3203                ref_key = (a.size.tensor_id, a.size.axis_id)
3204                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3205                assert ref_size is not None, ref_key
3206                assert not isinstance(ref_size, _DataDepSize), ref_key
3207                return a.size.get_size(
3208                    axis=a,
3209                    ref_axis=ref_axis,
3210                    ref_size=ref_size,
3211                )
3212            elif isinstance(a.size, DataDependentSize):
3213                if (t_descr.id, a.id) in ns:
3214                    logger.warning(
3215                        "Ignoring unexpected increment factor (n) for data dependent"
3216                        + " size axis '{}' of tensor '{}'.",
3217                        a.id,
3218                        t_descr.id,
3219                    )
3220                return _DataDepSize(a.size.min, a.size.max)
3221            else:
3222                assert_never(a.size)
3223
3224        # first resolve all input sizes except those given by a `SizeReference`
3225        for t_descr in self.inputs:
3226            for a in t_descr.axes:
3227                if not isinstance(a.size, SizeReference):
3228                    s = get_axis_size(a)
3229                    assert not isinstance(s, _DataDepSize)
3230                    inputs[t_descr.id, a.id] = s
3231
3232        # resolve all other input axis sizes
3233        for t_descr in self.inputs:
3234            for a in t_descr.axes:
3235                if isinstance(a.size, SizeReference):
3236                    s = get_axis_size(a)
3237                    assert not isinstance(s, _DataDepSize)
3238                    inputs[t_descr.id, a.id] = s
3239
3240        # resolve all output axis sizes
3241        for t_descr in self.outputs:
3242            for a in t_descr.axes:
3243                assert not isinstance(a.size, ParameterizedSize)
3244                s = get_axis_size(a)
3245                outputs[t_descr.id, a.id] = s
3246
3247        return _AxisSizes(inputs=inputs, outputs=outputs)
3248
3249    @model_validator(mode="before")
3250    @classmethod
3251    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3252        cls.convert_from_old_format_wo_validation(data)
3253        return data
3254
3255    @classmethod
3256    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3257        """Convert metadata following an older format version to this class's format
3258        without validating the result.
3259        """
3260        if (
3261            data.get("type") == "model"
3262            and isinstance(fv := data.get("format_version"), str)
3263            and fv.count(".") == 2
3264        ):
3265            fv_parts = fv.split(".")
3266            if any(not p.isdigit() for p in fv_parts):
3267                return
3268
3269            fv_tuple = tuple(map(int, fv_parts))
3270
3271            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3272            if fv_tuple[:2] in ((0, 3), (0, 4)):
3273                m04 = _ModelDescr_v0_4.load(data)
3274                if isinstance(m04, InvalidDescr):
3275                    try:
3276                        updated = _model_conv.convert_as_dict(
3277                            m04  # pyright: ignore[reportArgumentType]
3278                        )
3279                    except Exception as e:
3280                        logger.error(
3281                            "Failed to convert from invalid model 0.4 description."
3282                            + f"\nerror: {e}"
3283                            + "\nProceeding with model 0.5 validation without conversion."
3284                        )
3285                        updated = None
3286                else:
3287                    updated = _model_conv.convert_as_dict(m04)
3288
3289                if updated is not None:
3290                    data.clear()
3291                    data.update(updated)
3292
3293            elif fv_tuple[:2] == (0, 5):
3294                # bump patch version
3295                data["format_version"] = cls.implemented_format_version
3296
3297
3298class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3299    def _convert(
3300        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3301    ) -> "ModelDescr | dict[str, Any]":
3302        name = "".join(
3303            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3304            for c in src.name
3305        )
3306
3307        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3308            conv = (
3309                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3310            )
3311            return None if auths is None else [conv(a) for a in auths]
3312
3313        if TYPE_CHECKING:
3314            arch_file_conv = _arch_file_conv.convert
3315            arch_lib_conv = _arch_lib_conv.convert
3316        else:
3317            arch_file_conv = _arch_file_conv.convert_as_dict
3318            arch_lib_conv = _arch_lib_conv.convert_as_dict
3319
3320        input_size_refs = {
3321            ipt.name: {
3322                a: s
3323                for a, s in zip(
3324                    ipt.axes,
3325                    (
3326                        ipt.shape.min
3327                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3328                        else ipt.shape
3329                    ),
3330                )
3331            }
3332            for ipt in src.inputs
3333            if ipt.shape
3334        }
3335        output_size_refs = {
3336            **{
3337                out.name: {a: s for a, s in zip(out.axes, out.shape)}
3338                for out in src.outputs
3339                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3340            },
3341            **input_size_refs,
3342        }
3343
3344        return tgt(
3345            attachments=(
3346                []
3347                if src.attachments is None
3348                else [FileDescr(source=f) for f in src.attachments.files]
3349            ),
3350            authors=[
3351                _author_conv.convert_as_dict(a) for a in src.authors
3352            ],  # pyright: ignore[reportArgumentType]
3353            cite=[
3354                {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3355            ],  # pyright: ignore[reportArgumentType]
3356            config=src.config,  # pyright: ignore[reportArgumentType]
3357            covers=src.covers,
3358            description=src.description,
3359            documentation=src.documentation,
3360            format_version="0.5.5",
3361            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3362            icon=src.icon,
3363            id=None if src.id is None else ModelId(src.id),
3364            id_emoji=src.id_emoji,
3365            license=src.license,  # type: ignore
3366            links=src.links,
3367            maintainers=[
3368                _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3369            ],  # pyright: ignore[reportArgumentType]
3370            name=name,
3371            tags=src.tags,
3372            type=src.type,
3373            uploader=src.uploader,
3374            version=src.version,
3375            inputs=[  # pyright: ignore[reportArgumentType]
3376                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3377                for ipt, tt, st, in zip(
3378                    src.inputs,
3379                    src.test_inputs,
3380                    src.sample_inputs or [None] * len(src.test_inputs),
3381                )
3382            ],
3383            outputs=[  # pyright: ignore[reportArgumentType]
3384                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3385                for out, tt, st, in zip(
3386                    src.outputs,
3387                    src.test_outputs,
3388                    src.sample_outputs or [None] * len(src.test_outputs),
3389                )
3390            ],
3391            parent=(
3392                None
3393                if src.parent is None
3394                else LinkedModel(
3395                    id=ModelId(
3396                        str(src.parent.id)
3397                        + (
3398                            ""
3399                            if src.parent.version_number is None
3400                            else f"/{src.parent.version_number}"
3401                        )
3402                    )
3403                )
3404            ),
3405            training_data=(
3406                None
3407                if src.training_data is None
3408                else (
3409                    LinkedDataset(
3410                        id=DatasetId(
3411                            str(src.training_data.id)
3412                            + (
3413                                ""
3414                                if src.training_data.version_number is None
3415                                else f"/{src.training_data.version_number}"
3416                            )
3417                        )
3418                    )
3419                    if isinstance(src.training_data, LinkedDataset02)
3420                    else src.training_data
3421                )
3422            ),
3423            packaged_by=[
3424                _author_conv.convert_as_dict(a) for a in src.packaged_by
3425            ],  # pyright: ignore[reportArgumentType]
3426            run_mode=src.run_mode,
3427            timestamp=src.timestamp,
3428            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3429                keras_hdf5=(w := src.weights.keras_hdf5)
3430                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3431                    authors=conv_authors(w.authors),
3432                    source=w.source,
3433                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3434                    parent=w.parent,
3435                ),
3436                onnx=(w := src.weights.onnx)
3437                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3438                    source=w.source,
3439                    authors=conv_authors(w.authors),
3440                    parent=w.parent,
3441                    opset_version=w.opset_version or 15,
3442                ),
3443                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3444                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3445                    source=w.source,
3446                    authors=conv_authors(w.authors),
3447                    parent=w.parent,
3448                    architecture=(
3449                        arch_file_conv(
3450                            w.architecture,
3451                            w.architecture_sha256,
3452                            w.kwargs,
3453                        )
3454                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3455                        else arch_lib_conv(w.architecture, w.kwargs)
3456                    ),
3457                    pytorch_version=w.pytorch_version or Version("1.10"),
3458                    dependencies=(
3459                        None
3460                        if w.dependencies is None
3461                        else (FileDescr if TYPE_CHECKING else dict)(
3462                            source=cast(
3463                                FileSource,
3464                                str(deps := w.dependencies)[
3465                                    (
3466                                        len("conda:")
3467                                        if str(deps).startswith("conda:")
3468                                        else 0
3469                                    ) :
3470                                ],
3471                            )
3472                        )
3473                    ),
3474                ),
3475                tensorflow_js=(w := src.weights.tensorflow_js)
3476                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3477                    source=w.source,
3478                    authors=conv_authors(w.authors),
3479                    parent=w.parent,
3480                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3481                ),
3482                tensorflow_saved_model_bundle=(
3483                    w := src.weights.tensorflow_saved_model_bundle
3484                )
3485                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3486                    authors=conv_authors(w.authors),
3487                    parent=w.parent,
3488                    source=w.source,
3489                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3490                    dependencies=(
3491                        None
3492                        if w.dependencies is None
3493                        else (FileDescr if TYPE_CHECKING else dict)(
3494                            source=cast(
3495                                FileSource,
3496                                (
3497                                    str(w.dependencies)[len("conda:") :]
3498                                    if str(w.dependencies).startswith("conda:")
3499                                    else str(w.dependencies)
3500                                ),
3501                            )
3502                        )
3503                    ),
3504                ),
3505                torchscript=(w := src.weights.torchscript)
3506                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3507                    source=w.source,
3508                    authors=conv_authors(w.authors),
3509                    parent=w.parent,
3510                    pytorch_version=w.pytorch_version or Version("1.10"),
3511                ),
3512            ),
3513        )
3514
3515
3516_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3517
3518
3519# create better cover images for 3d data and non-image outputs
3520def generate_covers(
3521    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3522    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3523) -> List[Path]:
3524    def squeeze(
3525        data: NDArray[Any], axes: Sequence[AnyAxis]
3526    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3527        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3528        if data.ndim != len(axes):
3529            raise ValueError(
3530                f"tensor shape {data.shape} does not match described axes"
3531                + f" {[a.id for a in axes]}"
3532            )
3533
3534        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3535        return data.squeeze(), axes
3536
3537    def normalize(
3538        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3539    ) -> NDArray[np.float32]:
3540        data = data.astype("float32")
3541        data -= data.min(axis=axis, keepdims=True)
3542        data /= data.max(axis=axis, keepdims=True) + eps
3543        return data
3544
3545    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3546        original_shape = data.shape
3547        data, axes = squeeze(data, axes)
3548
3549        # take slice from any batch or index axis if needed
3550        # and convert the first channel axis and take a slice from any additional channel axes
3551        slices: Tuple[slice, ...] = ()
3552        ndim = data.ndim
3553        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3554        has_c_axis = False
3555        for i, a in enumerate(axes):
3556            s = data.shape[i]
3557            assert s > 1
3558            if (
3559                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3560                and ndim > ndim_need
3561            ):
3562                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3563                ndim -= 1
3564            elif isinstance(a, ChannelAxis):
3565                if has_c_axis:
3566                    # second channel axis
3567                    data = data[slices + (slice(0, 1),)]
3568                    ndim -= 1
3569                else:
3570                    has_c_axis = True
3571                    if s == 2:
3572                        # visualize two channels with cyan and magenta
3573                        data = np.concatenate(
3574                            [
3575                                data[slices + (slice(1, 2),)],
3576                                data[slices + (slice(0, 1),)],
3577                                (
3578                                    data[slices + (slice(0, 1),)]
3579                                    + data[slices + (slice(1, 2),)]
3580                                )
3581                                / 2,  # TODO: take maximum instead?
3582                            ],
3583                            axis=i,
3584                        )
3585                    elif data.shape[i] == 3:
3586                        pass  # visualize 3 channels as RGB
3587                    else:
3588                        # visualize first 3 channels as RGB
3589                        data = data[slices + (slice(3),)]
3590
3591                    assert data.shape[i] == 3
3592
3593            slices += (slice(None),)
3594
3595        data, axes = squeeze(data, axes)
3596        assert len(axes) == ndim
3597        # take slice from z axis if needed
3598        slices = ()
3599        if ndim > ndim_need:
3600            for i, a in enumerate(axes):
3601                s = data.shape[i]
3602                if a.id == AxisId("z"):
3603                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3604                    data, axes = squeeze(data, axes)
3605                    ndim -= 1
3606                    break
3607
3608            slices += (slice(None),)
3609
3610        # take slice from any space or time axis
3611        slices = ()
3612
3613        for i, a in enumerate(axes):
3614            if ndim <= ndim_need:
3615                break
3616
3617            s = data.shape[i]
3618            assert s > 1
3619            if isinstance(
3620                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3621            ):
3622                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3623                ndim -= 1
3624
3625            slices += (slice(None),)
3626
3627        del slices
3628        data, axes = squeeze(data, axes)
3629        assert len(axes) == ndim
3630
3631        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3632            raise ValueError(
3633                f"Failed to construct cover image from shape {original_shape}"
3634            )
3635
3636        if not has_c_axis:
3637            assert ndim == 2
3638            data = np.repeat(data[:, :, None], 3, axis=2)
3639            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3640            ndim += 1
3641
3642        assert ndim == 3
3643
3644        # transpose axis order such that longest axis comes first...
3645        axis_order: List[int] = list(np.argsort(list(data.shape)))
3646        axis_order.reverse()
3647        # ... and channel axis is last
3648        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3649        axis_order.append(axis_order.pop(c))
3650        axes = [axes[ao] for ao in axis_order]
3651        data = data.transpose(axis_order)
3652
3653        # h, w = data.shape[:2]
3654        # if h / w  in (1.0 or 2.0):
3655        #     pass
3656        # elif h / w < 2:
3657        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3658
3659        norm_along = (
3660            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3661        )
3662        # normalize the data and map to 8 bit
3663        data = normalize(data, norm_along)
3664        data = (data * 255).astype("uint8")
3665
3666        return data
3667
3668    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3669        assert im0.dtype == im1.dtype == np.uint8
3670        assert im0.shape == im1.shape
3671        assert im0.ndim == 3
3672        N, M, C = im0.shape
3673        assert C == 3
3674        out = np.ones((N, M, C), dtype="uint8")
3675        for c in range(C):
3676            outc = np.tril(im0[..., c])
3677            mask = outc == 0
3678            outc[mask] = np.triu(im1[..., c])[mask]
3679            out[..., c] = outc
3680
3681        return out
3682
3683    if not inputs:
3684        raise ValueError("Missing test input tensor for cover generation.")
3685
3686    if not outputs:
3687        raise ValueError("Missing test output tensor for cover generation.")
3688
3689    ipt_descr, ipt = inputs[0]
3690    out_descr, out = outputs[0]
3691
3692    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3693    out_img = to_2d_image(out, out_descr.axes)
3694
3695    cover_folder = Path(mkdtemp())
3696    if ipt_img.shape == out_img.shape:
3697        covers = [cover_folder / "cover.png"]
3698        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3699    else:
3700        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3701        imwrite(covers[0], ipt_img)
3702        imwrite(covers[1], out_img)
3703
3704    return covers
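The static helpers on `ModelDescr` shown above operate on plain mappings of axis sizes. As a minimal, hedged sketch of `get_batch_size` (assuming two input tensors that agree on their batch size; `TensorId` and `AxisId` are documented below):

>>> from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId
>>> sizes = {
...     TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256},
...     TensorId("mask"): {AxisId("batch"): 4, AxisId("y"): 128},
... }
>>> ModelDescr.get_batch_size(sizes)
4

Mismatched batch sizes across tensors raise a `ValueError` instead.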
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible with the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible with the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
229class TensorId(LowerCaseIdentifier):
230    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
231        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
232    ]

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

245class AxisId(LowerCaseIdentifier):
246    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
247        Annotated[
248            LowerCaseIdentifierAnno,
249            MaxLen(16),
250            AfterValidator(_normalize_axis_id),
251        ]
252    ]

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

the pydantic root model to validate the string
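A minimal usage sketch (assuming these identifier types validate on construction and otherwise behave like plain strings, e.g. as keys of the `Mapping[TensorId, Mapping[AxisId, int]]` mappings used throughout this module):

>>> from bioimageio.spec.model.v0_5 import AxisId, TensorId
>>> sizes = {TensorId("input"): {AxisId("x"): 256}}
>>> sizes[TensorId("input")][AxisId("x")]
256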

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_range', 'sigmoid', 'softmax']
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'softmax', 'zero_mean_unit_variance']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>

Annotates an integer to calculate a concrete axis size from a ParameterizedSize.

class ParameterizedSize(bioimageio.spec._internal.node.Node):
298class ParameterizedSize(Node):
299    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
300
301    - **min** and **step** are given by the model description.
302    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
303    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
304      This allows adjusting the axis size more generically.
305    """
306
307    N: ClassVar[Type[int]] = ParameterizedSize_N
308    """Non-negative integer to parameterize this axis"""
309
310    min: Annotated[int, Gt(0)]
311    step: Annotated[int, Gt(0)]
312
313    def validate_size(self, size: int) -> int:
314        if size < self.min:
315            raise ValueError(f"size {size} < {self.min}")
316        if (size - self.min) % self.step != 0:
317            raise ValueError(
318                f"axis of size {size} is not parameterized by `min + n*step` ="
319                + f" `{self.min} + n*{self.step}`"
320            )
321
322        return size
323
324    def get_size(self, n: ParameterizedSize_N) -> int:
325        return self.min + self.step * n
326
327    def get_n(self, s: int) -> ParameterizedSize_N:
328        """return the smallest n parameterizing a size greater than or equal to `s`"""
329        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

  • min and step are given by the model description.
  • All blocksize parameters n = 0,1,2,... yield a valid size.
  • A greater blocksize parameter n = 0,1,2,... results in a greater size. This allows adjusting the axis size more generically.
N: ClassVar[Type[int]] = <class 'int'>

Non-negative integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
313    def validate_size(self, size: int) -> int:
314        if size < self.min:
315            raise ValueError(f"size {size} < {self.min}")
316        if (size - self.min) % self.step != 0:
317            raise ValueError(
318                f"axis of size {size} is not parameterized by `min + n*step` ="
319                + f" `{self.min} + n*{self.step}`"
320            )
321
322        return size
def get_size(self, n: int) -> int:
324    def get_size(self, n: ParameterizedSize_N) -> int:
325        return self.min + self.step * n
def get_n(self, s: int) -> int:
327    def get_n(self, s: int) -> ParameterizedSize_N:
328        """return the smallest n parameterizing a size greater than or equal to `s`"""
329        return ceil((s - self.min) / self.step)

return the smallest n parameterizing a size greater than or equal to s

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
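A short usage sketch of the size arithmetic above (`size = min + n*step`):

>>> from bioimageio.spec.model.v0_5 import ParameterizedSize
>>> p = ParameterizedSize(min=32, step=16)
>>> p.get_size(2)            # 32 + 2*16
64
>>> p.get_n(70)              # smallest n with size >= 70
3
>>> p.get_size(p.get_n(70))
80
>>> p.validate_size(64)
64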

class DataDependentSize(bioimageio.spec._internal.node.Node):
332class DataDependentSize(Node):
333    min: Annotated[int, Gt(0)] = 1
334    max: Annotated[Optional[int], Gt(1)] = None
335
336    @model_validator(mode="after")
337    def _validate_max_gt_min(self):
338        if self.max is not None and self.min >= self.max:
339            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
340
341        return self
342
343    def validate_size(self, size: int) -> int:
344        if size < self.min:
345            raise ValueError(f"size {size} < {self.min}")
346
347        if self.max is not None and size > self.max:
348            raise ValueError(f"size {size} > {self.max}")
349
350        return size
min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
343    def validate_size(self, size: int) -> int:
344        if size < self.min:
345            raise ValueError(f"size {size} < {self.min}")
346
347        if self.max is not None and size > self.max:
348            raise ValueError(f"size {size} > {self.max}")
349
350        return size
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
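A minimal sketch of the bounds check; sizes below `min` or above `max` raise a `ValueError`:

>>> from bioimageio.spec.model.v0_5 import DataDependentSize
>>> d = DataDependentSize(min=2, max=512)
>>> d.validate_size(256)
256
>>> DataDependentSize().min, DataDependentSize().max  # defaults
(1, None)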

class SizeReference(bioimageio.spec._internal.node.Node):
353class SizeReference(Node):
354    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
355
356    `axis.size = reference.size * reference.scale / axis.scale + offset`
357
358    Note:
359    1. The axis and the referenced axis need to have the same unit (or no unit).
360    2. Batch axes may not be referenced.
361    3. Fractions are rounded down.
362    4. If the reference axis is `concatenable` the referencing axis is assumed to be
363        `concatenable` as well with the same block order.
364
365    Example:
366    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
367    Let's assume that we want to express the image height h in relation to its width w
368    instead of only accepting input images of exactly 100*49 pixels
369    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
370
371    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
372    >>> h = SpaceInputAxis(
373    ...     id=AxisId("h"),
374    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
375    ...     unit="millimeter",
376    ...     scale=4,
377    ... )
378    >>> print(h.size.get_size(h, w))
379    49
380
381    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
382    """
383
384    tensor_id: TensorId
385    """tensor id of the reference axis"""
386
387    axis_id: AxisId
388    """axis id of the reference axis"""
389
390    offset: StrictInt = 0
391
392    def get_size(
393        self,
394        axis: Union[
395            ChannelAxis,
396            IndexInputAxis,
397            IndexOutputAxis,
398            TimeInputAxis,
399            SpaceInputAxis,
400            TimeOutputAxis,
401            TimeOutputAxisWithHalo,
402            SpaceOutputAxis,
403            SpaceOutputAxisWithHalo,
404        ],
405        ref_axis: Union[
406            ChannelAxis,
407            IndexInputAxis,
408            IndexOutputAxis,
409            TimeInputAxis,
410            SpaceInputAxis,
411            TimeOutputAxis,
412            TimeOutputAxisWithHalo,
413            SpaceOutputAxis,
414            SpaceOutputAxisWithHalo,
415        ],
416        n: ParameterizedSize_N = 0,
417        ref_size: Optional[int] = None,
418    ):
419        """Compute the concrete size for a given axis and its reference axis.
420
421        Args:
422            axis: The axis this `SizeReference` is the size of.
423            ref_axis: The reference axis to compute the size from.
424            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
425                and no fixed **ref_size** is given,
426                **n** is used to compute the size of the parameterized **ref_axis**.
427            ref_size: Overwrite the reference size instead of deriving it from
428                **ref_axis**
429                (**ref_axis.scale** is still used; any given **n** is ignored).
430        """
431        assert (
432            axis.size == self
433        ), "Given `axis.size` is not defined by this `SizeReference`"
434
435        assert (
436            ref_axis.id == self.axis_id
437        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
438
439        assert axis.unit == ref_axis.unit, (
440            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
441            f" but {axis.unit}!={ref_axis.unit}"
442        )
443        if ref_size is None:
444            if isinstance(ref_axis.size, (int, float)):
445                ref_size = ref_axis.size
446            elif isinstance(ref_axis.size, ParameterizedSize):
447                ref_size = ref_axis.size.get_size(n)
448            elif isinstance(ref_axis.size, DataDependentSize):
449                raise ValueError(
450                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
451                )
452            elif isinstance(ref_axis.size, SizeReference):
453                raise ValueError(
454                    "Reference axis referenced in `SizeReference` may not be sized by a"
455                    + " `SizeReference` itself."
456                )
457            else:
458                assert_never(ref_axis.size)
459
460        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
461
462    @staticmethod
463    def _get_unit(
464        axis: Union[
465            ChannelAxis,
466            IndexInputAxis,
467            IndexOutputAxis,
468            TimeInputAxis,
469            SpaceInputAxis,
470            TimeOutputAxis,
471            TimeOutputAxisWithHalo,
472            SpaceOutputAxis,
473            SpaceOutputAxisWithHalo,
474        ],
475    ):
476        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 100*49 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: Annotated[int, Strict(strict=True)]
392    def get_size(
393        self,
394        axis: Union[
395            ChannelAxis,
396            IndexInputAxis,
397            IndexOutputAxis,
398            TimeInputAxis,
399            SpaceInputAxis,
400            TimeOutputAxis,
401            TimeOutputAxisWithHalo,
402            SpaceOutputAxis,
403            SpaceOutputAxisWithHalo,
404        ],
405        ref_axis: Union[
406            ChannelAxis,
407            IndexInputAxis,
408            IndexOutputAxis,
409            TimeInputAxis,
410            SpaceInputAxis,
411            TimeOutputAxis,
412            TimeOutputAxisWithHalo,
413            SpaceOutputAxis,
414            SpaceOutputAxisWithHalo,
415        ],
416        n: ParameterizedSize_N = 0,
417        ref_size: Optional[int] = None,
418    ):
419        """Compute the concrete size for a given axis and its reference axis.
420
421        Args:
422            axis: The axis this `SizeReference` is the size of.
423            ref_axis: The reference axis to compute the size from.
424            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
425                and no fixed **ref_size** is given,
426                **n** is used to compute the size of the parameterized **ref_axis**.
427            ref_size: Overwrite the reference size instead of deriving it from
428                **ref_axis**
429                (**ref_axis.scale** is still used; any given **n** is ignored).
430        """
431        assert (
432            axis.size == self
433        ), "Given `axis.size` is not defined by this `SizeReference`"
434
435        assert (
436            ref_axis.id == self.axis_id
437        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
438
439        assert axis.unit == ref_axis.unit, (
440            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
441            f" but {axis.unit}!={ref_axis.unit}"
442        )
443        if ref_size is None:
444            if isinstance(ref_axis.size, (int, float)):
445                ref_size = ref_axis.size
446            elif isinstance(ref_axis.size, ParameterizedSize):
447                ref_size = ref_axis.size.get_size(n)
448            elif isinstance(ref_axis.size, DataDependentSize):
449                raise ValueError(
450                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
451                )
452            elif isinstance(ref_axis.size, SizeReference):
453                raise ValueError(
454                    "Reference axis referenced in `SizeReference` may not be sized by a"
455                    + " `SizeReference` itself."
456                )
457            else:
458                assert_never(ref_axis.size)
459
460        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
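Building on the docstring example above, a hedged sketch of a `SizeReference` whose reference axis is itself parameterized; the blocksize parameter `n` then determines the reference size before scaling:

>>> from bioimageio.spec.model.v0_5 import (
...     AxisId, ParameterizedSize, SizeReference, SpaceInputAxis, TensorId
... )
>>> w = SpaceInputAxis(
...     id=AxisId("w"), size=ParameterizedSize(min=32, step=16), unit="millimeter", scale=1
... )
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w")),
...     unit="millimeter",
...     scale=2,
... )
>>> h.size.get_size(h, w, n=2)  # int(64 * 1 / 2 + 0)
32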

479class AxisBase(NodeWithExplicitlySetFields):
480    id: AxisId
481    """An axis id unique across all axes of one tensor."""
482
483    description: Annotated[str, MaxLen(128)] = ""
484    """A short description of this axis beyond its type and id."""
id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]

A short description of this axis beyond its type and id.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WithHalo(bioimageio.spec._internal.node.Node):
487class WithHalo(Node):
488    halo: Annotated[int, Ge(1)]
489    """The halo should be cropped from the output tensor to avoid boundary effects.
490    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
491    To document a halo that is already cropped by the model use `size.offset` instead."""
492
493    size: Annotated[
494        SizeReference,
495        Field(
496            examples=[
497                10,
498                SizeReference(
499                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
500                ).model_dump(mode="json"),
501            ]
502        ),
503    ]
504    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
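As a quick arithmetic illustration of the cropping rule above (`size_after_crop = size - 2 * halo`):

>>> size, halo = 128, 16
>>> size - 2 * halo
96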

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
510class BatchAxis(AxisBase):
511    implemented_type: ClassVar[Literal["batch"]] = "batch"
512    if TYPE_CHECKING:
513        type: Literal["batch"] = "batch"
514    else:
515        type: Literal["batch"]
516
517    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
518    size: Optional[Literal[1]] = None
519    """The batch size may be fixed to 1,
520    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
521
522    @property
523    def scale(self):
524        return 1.0
525
526    @property
527    def concatenable(self):
528        return True
529
530    @property
531    def unit(self):
532        return None
implemented_type: ClassVar[Literal['batch']] = 'batch'
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
522    @property
523    def scale(self):
524        return 1.0
concatenable
526    @property
527    def concatenable(self):
528        return True
unit
530    @property
531    def unit(self):
532        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['batch']
class ChannelAxis(AxisBase):
535class ChannelAxis(AxisBase):
536    implemented_type: ClassVar[Literal["channel"]] = "channel"
537    if TYPE_CHECKING:
538        type: Literal["channel"] = "channel"
539    else:
540        type: Literal["channel"]
541
542    id: NonBatchAxisId = AxisId("channel")
543
544    channel_names: NotEmpty[List[Identifier]]
545
546    @property
547    def size(self) -> int:
548        return len(self.channel_names)
549
550    @property
551    def concatenable(self):
552        return False
553
554    @property
555    def scale(self) -> float:
556        return 1.0
557
558    @property
559    def unit(self):
560        return None
implemented_type: ClassVar[Literal['channel']] = 'channel'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
546    @property
547    def size(self) -> int:
548        return len(self.channel_names)
concatenable
550    @property
551    def concatenable(self):
552        return False
scale: float
554    @property
555    def scale(self) -> float:
556        return 1.0
unit
558    @property
559    def unit(self):
560        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['channel']
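A hedged construction sketch (assuming, as in the `SpaceInputAxis` docstring example further up, that the `type` field is filled in automatically and only `channel_names` needs to be given); the axis size is derived from the channel names:

>>> from bioimageio.spec.model.v0_5 import ChannelAxis, Identifier
>>> ax = ChannelAxis(channel_names=[Identifier("r"), Identifier("g"), Identifier("b")])
>>> ax.size
3
>>> ax.concatenable
False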
class IndexAxisBase(AxisBase):
563class IndexAxisBase(AxisBase):
564    implemented_type: ClassVar[Literal["index"]] = "index"
565    if TYPE_CHECKING:
566        type: Literal["index"] = "index"
567    else:
568        type: Literal["index"]
569
570    id: NonBatchAxisId = AxisId("index")
571
572    @property
573    def scale(self) -> float:
574        return 1.0
575
576    @property
577    def unit(self):
578        return None
implemented_type: ClassVar[Literal['index']] = 'index'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
572    @property
573    def scale(self) -> float:
574        return 1.0
unit
576    @property
577    def unit(self):
578        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
601class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
602    concatenable: bool = False
603    """If a model has a `concatenable` input axis, it can be processed blockwise,
604    splitting a longer sample axis into blocks matching its input tensor description.
605    Output axes are concatenable if they have a `SizeReference` to a concatenable
606    input axis.
607    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexOutputAxis(IndexAxisBase):
610class IndexOutputAxis(IndexAxisBase):
611    size: Annotated[
612        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
613        Field(
614            examples=[
615                10,
616                SizeReference(
617                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
618                ).model_dump(mode="json"),
619            ]
620        ),
621    ]
622    """The size/length of this axis can be specified as
623    - fixed integer
624    - reference to another axis with an optional offset (`SizeReference`)
625    - data dependent size using `DataDependentSize` (size is only known after model inference)
626    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
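
A minimal sketch of the three ways to specify the size of an index output axis (the axis id is left at its assumed default; DataDependentSize() is shown with its default settings, which are documented elsewhere in this module):

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, DataDependentSize, IndexOutputAxis, SizeReference, TensorId
    ... )
    >>> fixed = IndexOutputAxis(size=10)
    >>> referenced = IndexOutputAxis(
    ...     size=SizeReference(tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5)
    ... )
    >>> data_dependent = IndexOutputAxis(size=DataDependentSize())
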
class TimeAxisBase(AxisBase):
629class TimeAxisBase(AxisBase):
630    implemented_type: ClassVar[Literal["time"]] = "time"
631    if TYPE_CHECKING:
632        type: Literal["time"] = "time"
633    else:
634        type: Literal["time"]
635
636    id: NonBatchAxisId = AxisId("time")
637    unit: Optional[TimeUnit] = None
638    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['time']] = 'time'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
641class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
642    concatenable: bool = False
643    """If a model has a `concatenable` input axis, it can be processed blockwise,
644    splitting a longer sample axis into blocks matching its input tensor description.
645    Output axes are concatenable if they have a `SizeReference` to a concatenable
646    input axis.
647    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceAxisBase(AxisBase):
650class SpaceAxisBase(AxisBase):
651    implemented_type: ClassVar[Literal["space"]] = "space"
652    if TYPE_CHECKING:
653        type: Literal["space"] = "space"
654    else:
655        type: Literal["space"]
656
657    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
658    unit: Optional[SpaceUnit] = None
659    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['space']] = 'space'
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
662class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
663    concatenable: bool = False
664    """If a model has a `concatenable` input axis, it can be processed blockwise,
665    splitting a longer sample axis into blocks matching its input tensor description.
666    Output axes are concatenable if they have a `SizeReference` to a concatenable
667    input axis.
668    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
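
For example, a 2D spatial input with a physical pixel size of 0.5 micrometer (a minimal sketch; the size is given as a plain positive integer here, one of the options accepted for input axis sizes):

    >>> from bioimageio.spec.model.v0_5 import AxisId, SpaceInputAxis
    >>> space_axes = [
    ...     SpaceInputAxis(id=AxisId("y"), size=512, scale=0.5, unit="micrometer"),
    ...     SpaceInputAxis(id=AxisId("x"), size=512, scale=0.5, unit="micrometer"),
    ... ]
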
INPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>)

Intended for isinstance comparisons in Python < 3.10.

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
704class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
705    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
708class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
709    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
728class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
729    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
732class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
733    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
OUTPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

Intended for isinstance comparisons in Python < 3.10.

AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
ANY_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>, <class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

Intended for isinstance comparisons in Python < 3.10.
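
A minimal sketch contrasting the two spatial output axis variants; it assumes (as defined earlier in this module, but not shown in this excerpt) that WithHalo adds a required halo field and expects the axis size to be given as a SizeReference to an input axis. The tensor id "raw" is hypothetical and used only for illustration.

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, SizeReference, SpaceOutputAxis, SpaceOutputAxisWithHalo, TensorId
    ... )
    >>> plain = SpaceOutputAxis(id=AxisId("x"), size=256)
    >>> cropped = SpaceOutputAxisWithHalo(
    ...     id=AxisId("x"),
    ...     halo=8,
    ...     size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"), offset=0),
    ... )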

TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
790class NominalOrOrdinalDataDescr(Node):
791    values: TVs
792    """A fixed set of nominal or an ascending sequence of ordinal values.
793    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
794    String `values` are interpreted as labels for tensor values 0, ..., N.
795    Note: as YAML 1.2 does not natively support a "set" datatype,
796    nominal values should be given as a sequence (aka list/array) as well.
797    """
798
799    type: Annotated[
800        NominalOrOrdinalDType,
801        Field(
802            examples=[
803                "float32",
804                "uint8",
805                "uint16",
806                "int64",
807                "bool",
808            ],
809        ),
810    ] = "uint8"
811
812    @model_validator(mode="after")
813    def _validate_values_match_type(
814        self,
815    ) -> Self:
816        incompatible: List[Any] = []
817        for v in self.values:
818            if self.type == "bool":
819                if not isinstance(v, bool):
820                    incompatible.append(v)
821            elif self.type in DTYPE_LIMITS:
822                if (
823                    isinstance(v, (int, float))
824                    and (
825                        v < DTYPE_LIMITS[self.type].min
826                        or v > DTYPE_LIMITS[self.type].max
827                    )
828                    or (isinstance(v, str) and "uint" not in self.type)
829                    or (isinstance(v, float) and "int" in self.type)
830                ):
831                    incompatible.append(v)
832            else:
833                incompatible.append(v)
834
835            if len(incompatible) == 5:
836                incompatible.append("...")
837                break
838
839        if incompatible:
840            raise ValueError(
841                f"data type '{self.type}' incompatible with values {incompatible}"
842            )
843
844        return self
845
846    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
847
848    @property
849    def range(self):
850        if isinstance(self.values[0], str):
851            return 0, len(self.values) - 1
852        else:
853            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data.type is required to be an unsigned integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
848    @property
849    def range(self):
850        if isinstance(self.values[0], str):
851            return 0, len(self.values) - 1
852        else:
853            return min(self.values), max(self.values)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
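
For example, nominal labels stored as uint8 values 0, 1, 2 (a minimal sketch; the range property is computed from the values as shown in the source above):

    >>> from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr
    >>> labels = NominalOrOrdinalDataDescr(
    ...     values=["background", "cell", "membrane"], type="uint8"
    ... )
    >>> labels.range
    (0, 2)
    >>> ordinal = NominalOrOrdinalDataDescr(values=[0, 1, 2, 4], type="uint8")
    >>> ordinal.range
    (0, 4)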

IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
870class IntervalOrRatioDataDescr(Node):
871    type: Annotated[  # TODO: rename to dtype
872        IntervalOrRatioDType,
873        Field(
874            examples=["float32", "float64", "uint8", "uint16"],
875        ),
876    ] = "float32"
877    range: Tuple[Optional[float], Optional[float]] = (
878        None,
879        None,
880    )
881    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
882    `None` corresponds to min/max of what can be expressed by **type**."""
883    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
884    scale: float = 1.0
885    """Scale for data on an interval (or ratio) scale."""
886    offset: Optional[float] = None
887    """Offset for data on a ratio scale."""
888
889    @model_validator(mode="before")
890    def _replace_inf(cls, data: Any):
891        if is_dict(data):
892            if "range" in data and is_sequence(data["range"]):
893                forbidden = (
894                    "inf",
895                    "-inf",
896                    ".inf",
897                    "-.inf",
898                    float("inf"),
899                    float("-inf"),
900                )
901                if any(v in forbidden for v in data["range"]):
902                    issue_warning("replaced 'inf' value", value=data["range"])
903
904                data["range"] = tuple(
905                    (None if v in forbidden else v) for v in data["range"]
906                )
907
908        return data
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
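
For example (a minimal sketch): a float32 tensor expected to lie in [0, 1], and a uint8 tensor using the defaults (full dtype range, arbitrary unit):

    >>> from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr
    >>> normalized = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))
    >>> normalized.unit
    'arbitrary unit'
    >>> raw = IntervalOrRatioDataDescr(type="uint8")
    >>> raw.range
    (None, None)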

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
914class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
915    """processing base class"""

processing base class

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
918class BinarizeKwargs(ProcessingKwargs):
919    """key word arguments for `BinarizeDescr`"""
920
921    threshold: float
922    """The fixed threshold"""

Keyword arguments for BinarizeDescr

threshold: float

The fixed threshold

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
925class BinarizeAlongAxisKwargs(ProcessingKwargs):
926    """key word arguments for `BinarizeDescr`"""
927
928    threshold: NotEmpty[List[float]]
929    """The fixed threshold values along `axis`"""
930
931    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
932    """The `threshold` axis"""

Keyword arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeDescr(ProcessingDescrBase):
935class BinarizeDescr(ProcessingDescrBase):
936    """Binarize the tensor with a fixed threshold.
937
938    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
939    will be set to one, values below the threshold to zero.
940
941    Examples:
942    - in YAML
943        ```yaml
944        postprocessing:
945          - id: binarize
946            kwargs:
947              axis: 'channel'
948              threshold: [0.25, 0.5, 0.75]
949        ```
950    - in Python:
951        >>> postprocessing = [BinarizeDescr(
952        ...   kwargs=BinarizeAlongAxisKwargs(
953        ...       axis=AxisId('channel'),
954        ...       threshold=[0.25, 0.5, 0.75],
955        ...   )
956        ... )]
957    """
958
959    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
960    if TYPE_CHECKING:
961        id: Literal["binarize"] = "binarize"
962    else:
963        id: Literal["binarize"]
964    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
implemented_id: ClassVar[Literal['binarize']] = 'binarize'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['binarize']
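
For a single threshold applied to the whole tensor, BinarizeKwargs can be used instead of BinarizeAlongAxisKwargs (a minimal sketch):

    postprocessing:
      - id: binarize
        kwargs:
          threshold: 0.5

    >>> from bioimageio.spec.model.v0_5 import BinarizeDescr, BinarizeKwargs
    >>> postprocessing = [BinarizeDescr(kwargs=BinarizeKwargs(threshold=0.5))]
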
class ClipDescr(ProcessingDescrBase):
967class ClipDescr(ProcessingDescrBase):
968    """Set tensor values below min to min and above max to max.
969
970    See `ScaleRangeDescr` for examples.
971    """
972
973    implemented_id: ClassVar[Literal["clip"]] = "clip"
974    if TYPE_CHECKING:
975        id: Literal["clip"] = "clip"
976    else:
977        id: Literal["clip"]
978
979    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

implemented_id: ClassVar[Literal['clip']] = 'clip'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['clip']
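
For example, clipping postprocessed values to [0, 1] (a minimal sketch; ClipKwargs with its min and max fields is also used in the EnsureDtypeDescr and ScaleRangeDescr examples below):

    >>> from bioimageio.spec.model.v0_5 import ClipDescr, ClipKwargs
    >>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
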
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
982class EnsureDtypeKwargs(ProcessingKwargs):
983    """key word arguments for `EnsureDtypeDescr`"""
984
985    dtype: Literal[
986        "float32",
987        "float64",
988        "uint8",
989        "int8",
990        "uint16",
991        "int16",
992        "uint32",
993        "int32",
994        "uint64",
995        "int64",
996        "bool",
997    ]

Keyword arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class EnsureDtypeDescr(ProcessingDescrBase):
1000class EnsureDtypeDescr(ProcessingDescrBase):
1001    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1002
1003    This can for example be used to ensure the inner neural network model gets a
1004    different input tensor data type than the fully described bioimage.io model does.
1005
1006    Examples:
1007        The described bioimage.io model (incl. preprocessing) accepts any
1008        float32-compatible tensor, normalizes it with percentiles and clipping and then
1009        casts it to uint8, which is what the neural network in this example expects.
1010        - in YAML
1011            ```yaml
1012            inputs:
1013            - data:
1014                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1015              preprocessing:
1016              - id: scale_range
1017                kwargs:
1018                  axes: ['y', 'x']
1019                  max_percentile: 99.8
1020                  min_percentile: 5.0
1021              - id: clip
1022                kwargs:
1023                  min: 0.0
1024                  max: 1.0
1025              - id: ensure_dtype  # the neural network of the model requires uint8
1026                kwargs:
1027                  dtype: uint8
1028            ```
1029        - in Python:
1030            >>> preprocessing = [
1031            ...     ScaleRangeDescr(
1032            ...         kwargs=ScaleRangeKwargs(
1033            ...           axes= (AxisId('y'), AxisId('x')),
1034            ...           max_percentile= 99.8,
1035            ...           min_percentile= 5.0,
1036            ...         )
1037            ...     ),
1038            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1039            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1040            ... ]
1041    """
1042
1043    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1044    if TYPE_CHECKING:
1045        id: Literal["ensure_dtype"] = "ensure_dtype"
1046    else:
1047        id: Literal["ensure_dtype"]
1048
1049    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
    kwargs:
      dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
implemented_id: ClassVar[Literal['ensure_dtype']] = 'ensure_dtype'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1052class ScaleLinearKwargs(ProcessingKwargs):
1053    """Key word arguments for `ScaleLinearDescr`"""
1054
1055    gain: float = 1.0
1056    """multiplicative factor"""
1057
1058    offset: float = 0.0
1059    """additive term"""
1060
1061    @model_validator(mode="after")
1062    def _validate(self) -> Self:
1063        if self.gain == 1.0 and self.offset == 0.0:
1064            raise ValueError(
1065                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1066                + " != 0.0."
1067            )
1068
1069        return self

Keyword arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1072class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1073    """Key word arguments for `ScaleLinearDescr`"""
1074
1075    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1076    """The axis of gain and offset values."""
1077
1078    gain: Union[float, NotEmpty[List[float]]] = 1.0
1079    """multiplicative factor"""
1080
1081    offset: Union[float, NotEmpty[List[float]]] = 0.0
1082    """additive term"""
1083
1084    @model_validator(mode="after")
1085    def _validate(self) -> Self:
1086
1087        if isinstance(self.gain, list):
1088            if isinstance(self.offset, list):
1089                if len(self.gain) != len(self.offset):
1090                    raise ValueError(
1091                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1092                    )
1093            else:
1094                self.offset = [float(self.offset)] * len(self.gain)
1095        elif isinstance(self.offset, list):
1096            self.gain = [float(self.gain)] * len(self.offset)
1097        else:
1098            raise ValueError(
1099                "Do not specify an `axis` for scalar gain and offset values."
1100            )
1101
1102        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1103            raise ValueError(
1104                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1105                + " != 0.0."
1106            )
1107
1108        return self

Keyword arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of gain and offset values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
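
A minimal sketch of the broadcasting performed by the validator above: a scalar offset is expanded to match a per-channel gain list (and vice versa):

    >>> from bioimageio.spec.model.v0_5 import AxisId, ScaleLinearAlongAxisKwargs
    >>> kwargs = ScaleLinearAlongAxisKwargs(
    ...     axis=AxisId("channel"), gain=[1.0, 2.0, 3.0], offset=0.5
    ... )
    >>> kwargs.offset
    [0.5, 0.5, 0.5]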

class ScaleLinearDescr(ProcessingDescrBase):
1111class ScaleLinearDescr(ProcessingDescrBase):
1112    """Fixed linear scaling.
1113
1114    Examples:
1115      1. Scale with scalar gain and offset
1116        - in YAML
1117        ```yaml
1118        preprocessing:
1119          - id: scale_linear
1120            kwargs:
1121              gain: 2.0
1122              offset: 3.0
1123        ```
1124        - in Python:
1125        >>> preprocessing = [
1126        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1127        ... ]
1128
1129      2. Independent scaling along an axis
1130        - in YAML
1131        ```yaml
1132        preprocessing:
1133          - id: scale_linear
1134            kwargs:
1135              axis: 'channel'
1136              gain: [1.0, 2.0, 3.0]
1137        ```
1138        - in Python:
1139        >>> preprocessing = [
1140        ...     ScaleLinearDescr(
1141        ...         kwargs=ScaleLinearAlongAxisKwargs(
1142        ...             axis=AxisId("channel"),
1143        ...             gain=[1.0, 2.0, 3.0],
1144        ...         )
1145        ...     )
1146        ... ]
1147
1148    """
1149
1150    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1151    if TYPE_CHECKING:
1152        id: Literal["scale_linear"] = "scale_linear"
1153    else:
1154        id: Literal["scale_linear"]
1155    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
implemented_id: ClassVar[Literal['scale_linear']] = 'scale_linear'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
1158class SigmoidDescr(ProcessingDescrBase):
1159    """The logistic sigmoid function, a.k.a. expit function.
1160
1161    Examples:
1162    - in YAML
1163        ```yaml
1164        postprocessing:
1165          - id: sigmoid
1166        ```
1167    - in Python:
1168        >>> postprocessing = [SigmoidDescr()]
1169    """
1170
1171    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1172    if TYPE_CHECKING:
1173        id: Literal["sigmoid"] = "sigmoid"
1174    else:
1175        id: Literal["sigmoid"]
1176
1177    @property
1178    def kwargs(self) -> ProcessingKwargs:
1179        """empty kwargs"""
1180        return ProcessingKwargs()

The logistic sigmoid function, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
implemented_id: ClassVar[Literal['sigmoid']] = 'sigmoid'
1177    @property
1178    def kwargs(self) -> ProcessingKwargs:
1179        """empty kwargs"""
1180        return ProcessingKwargs()

empty kwargs

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['sigmoid']
class SoftmaxKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1183class SoftmaxKwargs(ProcessingKwargs):
1184    """key word arguments for `SoftmaxDescr`"""
1185
1186    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1187    """The axis to apply the softmax function along.
1188    Note:
1189        Defaults to 'channel' axis
1190        (which may not exist, in which case
1191        a different axis id has to be specified).
1192    """

Keyword arguments for SoftmaxDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis to apply the softmax function along.

Note:

Defaults to 'channel' axis (which may not exist, in which case a different axis id has to be specified).

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class SoftmaxDescr(ProcessingDescrBase):
1195class SoftmaxDescr(ProcessingDescrBase):
1196    """The softmax function.
1197
1198    Examples:
1199    - in YAML
1200        ```yaml
1201        postprocessing:
1202          - id: softmax
1203            kwargs:
1204              axis: channel
1205        ```
1206    - in Python:
1207        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1208    """
1209
1210    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1211    if TYPE_CHECKING:
1212        id: Literal["softmax"] = "softmax"
1213    else:
1214        id: Literal["softmax"]
1215
1216    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

The softmax function.

Examples:

  • in YAML
postprocessing:
  - id: softmax
    kwargs:
      axis: channel
  • in Python:
    >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    
implemented_id: ClassVar[Literal['softmax']] = 'softmax'
kwargs: SoftmaxKwargs
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['softmax']
class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1219class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1220    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1221
1222    mean: float
1223    """The mean value to normalize with."""
1224
1225    std: Annotated[float, Ge(1e-6)]
1226    """The standard deviation value to normalize with."""

Keyword arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1229class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1230    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1231
1232    mean: NotEmpty[List[float]]
1233    """The mean value(s) to normalize with."""
1234
1235    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1236    """The standard deviation value(s) to normalize with.
1237    Size must match `mean` values."""
1238
1239    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1240    """The axis of the mean/std values to normalize each entry along that dimension
1241    separately."""
1242
1243    @model_validator(mode="after")
1244    def _mean_and_std_match(self) -> Self:
1245        if len(self.mean) != len(self.std):
1246            raise ValueError(
1247                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1248                + " must match."
1249            )
1250
1251        return self

Keyword arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1254class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1255    """Subtract a given mean and divide by the standard deviation.
1256
1257    Normalize with fixed, precomputed values for
1258    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1259    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1260    axes.
1261
1262    Examples:
1263    1. scalar value for whole tensor
1264        - in YAML
1265        ```yaml
1266        preprocessing:
1267          - id: fixed_zero_mean_unit_variance
1268            kwargs:
1269              mean: 103.5
1270              std: 13.7
1271        ```
1272        - in Python
1273        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1274        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1275        ... )]
1276
1277    2. independently along an axis
1278        - in YAML
1279        ```yaml
1280        preprocessing:
1281          - id: fixed_zero_mean_unit_variance
1282            kwargs:
1283              axis: channel
1284              mean: [101.5, 102.5, 103.5]
1285              std: [11.7, 12.7, 13.7]
1286        ```
1287        - in Python
1288        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1289        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1290        ...     axis=AxisId("channel"),
1291        ...     mean=[101.5, 102.5, 103.5],
1292        ...     std=[11.7, 12.7, 13.7],
1293        ...   )
1294        ... )]
1295    """
1296
1297    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1298        "fixed_zero_mean_unit_variance"
1299    )
1300    if TYPE_CHECKING:
1301        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1302    else:
1303        id: Literal["fixed_zero_mean_unit_variance"]
1304
1305    kwargs: Union[
1306        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1307    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std. Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor
    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          mean: 103.5
          std: 13.7

    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
      ... )]

  2. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['fixed_zero_mean_unit_variance']] = 'fixed_zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1310class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1311    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1312
1313    axes: Annotated[
1314        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1315    ] = None
1316    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1317    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1318    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1319    To normalize each sample independently leave out the 'batch' axis.
1320    Default: Scale all axes jointly."""
1321
1322    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1323    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

Keyword arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
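
For example, to normalize a ('batch', 'channel', 'y', 'x') tensor per channel, reduce over the remaining axes (a minimal sketch):

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, ZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceKwargs
    ... )
    >>> preprocessing = [ZeroMeanUnitVarianceDescr(
    ...     kwargs=ZeroMeanUnitVarianceKwargs(
    ...         axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
    ...     )
    ... )]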

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1326class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1327    """Subtract mean and divide by variance.
1328
1329    Examples:
1330        Subtract tensor mean and variance
1331        - in YAML
1332        ```yaml
1333        preprocessing:
1334          - id: zero_mean_unit_variance
1335        ```
1336        - in Python
1337        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1338    """
1339
1340    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1341        "zero_mean_unit_variance"
1342    )
1343    if TYPE_CHECKING:
1344        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1345    else:
1346        id: Literal["zero_mean_unit_variance"]
1347
1348    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1349        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1350    )

Subtract mean and divide by the standard deviation.

Examples:

Subtract the tensor mean and divide by the standard deviation

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
implemented_id: ClassVar[Literal['zero_mean_unit_variance']] = 'zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['zero_mean_unit_variance']
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1353class ScaleRangeKwargs(ProcessingKwargs):
1354    """key word arguments for `ScaleRangeDescr`
1355
1356    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1357    this processing step normalizes data to the [0, 1] interval.
1358    For other percentiles the normalized values will partially be outside the [0, 1]
1359    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1360    normalized values to a range.
1361    """
1362
1363    axes: Annotated[
1364        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1365    ] = None
1366    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1367    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1368    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1369    To normalize samples independently, leave out the "batch" axis.
1370    Default: Scale all axes jointly."""
1371
1372    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1373    """The lower percentile used to determine the value to align with zero."""
1374
1375    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1376    """The upper percentile used to determine the value to align with one.
1377    Has to be bigger than `min_percentile`.
1378    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1379    accepting percentiles specified in the range 0.0 to 1.0."""
1380
1381    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1382    """Epsilon for numeric stability.
1383    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1384    with `v_lower,v_upper` values at the respective percentiles."""
1385
1386    reference_tensor: Optional[TensorId] = None
1387    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1388    For any tensor in `inputs` only input tensor references are allowed."""
1389
1390    @field_validator("max_percentile", mode="after")
1391    @classmethod
1392    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1393        if (min_p := info.data["min_percentile"]) >= value:
1394            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1395
1396        return value

Keyword arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use ScaleRange followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1390    @field_validator("max_percentile", mode="after")
1391    @classmethod
1392    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1393        if (min_p := info.data["min_percentile"]) >= value:
1394            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1395
1396        return value
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
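
For example, percentile normalization computed per channel from a hypothetical input tensor "raw" instead of the tensor being processed (a minimal sketch; the tensor id is illustrative only):

    >>> from bioimageio.spec.model.v0_5 import AxisId, ScaleRangeKwargs, TensorId
    >>> kwargs = ScaleRangeKwargs(
    ...     axes=(AxisId("batch"), AxisId("y"), AxisId("x")),
    ...     min_percentile=1.0,
    ...     max_percentile=99.0,
    ...     reference_tensor=TensorId("raw"),
    ... )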

class ScaleRangeDescr(ProcessingDescrBase):
1399class ScaleRangeDescr(ProcessingDescrBase):
1400    """Scale with percentiles.
1401
1402    Examples:
1403    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1404        - in YAML
1405        ```yaml
1406        preprocessing:
1407          - id: scale_range
1408            kwargs:
1409              axes: ['y', 'x']
1410              max_percentile: 99.8
1411              min_percentile: 5.0
1412        ```
1413        - in Python
1414        >>> preprocessing = [ScaleRangeDescr(
1415        ...   kwargs=ScaleRangeKwargs(
1416        ...       axes= (AxisId('y'), AxisId('x')),
1417        ...       max_percentile= 99.8,
1418        ...       min_percentile= 5.0,
1419        ...   )
1420        ... )]
1421
1422      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1423        - in YAML
1424        ```yaml
1425        preprocessing:
1426          - id: scale_range
1427            kwargs:
1428              axes: ['y', 'x']
1429              max_percentile: 99.8
1430              min_percentile: 5.0
1431          - id: clip
1432            kwargs:
1433              min: 0.0
1434              max: 1.0
1435        ```
1436        - in Python
1437        >>> preprocessing = [
1438        ...     ScaleRangeDescr(
1439        ...         kwargs=ScaleRangeKwargs(
1440        ...           axes= (AxisId('y'), AxisId('x')),
1441        ...           max_percentile= 99.8,
1442        ...           min_percentile= 5.0,
1443        ...         )
1444        ...     ),
1445        ...     ClipDescr(
1446        ...         kwargs=ClipKwargs(
1447        ...             min=0.0,
1448        ...             max=1.0,
1449        ...         )
1450        ...     ),
1451        ... ]
1452
1453
1454    """
1455
1456    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1457    if TYPE_CHECKING:
1458        id: Literal["scale_range"] = "scale_range"
1459    else:
1460        id: Literal["scale_range"]
1461    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0

    • in YAML

      preprocessing:
        - id: scale_range
          kwargs:
            axes: ['y', 'x']
            max_percentile: 99.8
            min_percentile: 5.0

    • in Python

      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]

  2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML

      preprocessing:
        - id: scale_range
          kwargs:
            axes: ['y', 'x']
            max_percentile: 99.8
            min_percentile: 5.0
        - id: clip
          kwargs:
            min: 0.0
            max: 1.0

    • in Python

      >>> preprocessing = [
      ...     ScaleRangeDescr(
      ...         kwargs=ScaleRangeKwargs(
      ...           axes= (AxisId('y'), AxisId('x')),
      ...           max_percentile= 99.8,
      ...           min_percentile= 5.0,
      ...         )
      ...     ),
      ...     ClipDescr(
      ...         kwargs=ClipKwargs(
      ...             min=0.0,
      ...             max=1.0,
      ...         )
      ...     ),
      ... ]

implemented_id: ClassVar[Literal['scale_range']] = 'scale_range'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1464class ScaleMeanVarianceKwargs(ProcessingKwargs):
1465    """key word arguments for `ScaleMeanVarianceKwargs`"""
1466
1467    reference_tensor: TensorId
1468    """Name of tensor to match."""
1469
1470    axes: Annotated[
1471        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1472    ] = None
1473    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1474    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1475    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1476    To normalize samples independently, leave out the 'batch' axis.
1477    Default: Scale all axes jointly."""
1478
1479    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1480    """Epsilon for numeric stability:
1481    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

Keyword arguments for ScaleMeanVarianceDescr

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
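
For illustration, a minimal NumPy sketch of the formula above; the reduction axes mirror the axes kwarg and all values are made up:

```python
import numpy as np

eps = 1e-6
reduce_axes = (0, 2, 3)  # e.g. axes=('batch', 'y', 'x') -> per-channel statistics

tensor = np.random.rand(1, 3, 64, 64).astype("float32")     # tensor to adjust
reference = np.random.rand(1, 3, 64, 64).astype("float32")  # data of `reference_tensor`

mean = tensor.mean(axis=reduce_axes, keepdims=True)
std = tensor.std(axis=reduce_axes, keepdims=True)
ref_mean = reference.mean(axis=reduce_axes, keepdims=True)
ref_std = reference.std(axis=reduce_axes, keepdims=True)

out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean
```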

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1484class ScaleMeanVarianceDescr(ProcessingDescrBase):
1485    """Scale a tensor's data distribution to match another tensor's mean/std.
1486    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1487    """
1488
1489    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1490    if TYPE_CHECKING:
1491        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1492    else:
1493        id: Literal["scale_mean_variance"]
1494    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

implemented_id: ClassVar[Literal['scale_mean_variance']] = 'scale_mean_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_mean_variance']
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleMeanVarianceDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1530class TensorDescrBase(Node, Generic[IO_AxisT]):
1531    id: TensorId
1532    """Tensor id. No duplicates are allowed."""
1533
1534    description: Annotated[str, MaxLen(128)] = ""
1535    """free text description"""
1536
1537    axes: NotEmpty[Sequence[IO_AxisT]]
1538    """tensor axes"""
1539
1540    @property
1541    def shape(self):
1542        return tuple(a.size for a in self.axes)
1543
1544    @field_validator("axes", mode="after", check_fields=False)
1545    @classmethod
1546    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1547        batch_axes = [a for a in axes if a.type == "batch"]
1548        if len(batch_axes) > 1:
1549            raise ValueError(
1550                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1551            )
1552
1553        seen_ids: Set[AxisId] = set()
1554        duplicate_axes_ids: Set[AxisId] = set()
1555        for a in axes:
1556            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1557
1558        if duplicate_axes_ids:
1559            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1560
1561        return axes
1562
1563    test_tensor: FAIR[Optional[FileDescr_]] = None
1564    """An example tensor to use for testing.
1565    Using the model with the test input tensors is expected to yield the test output tensors.
1566    Each test tensor has to be an ndarray in the
1567    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1568    The file extension must be '.npy'."""
1569
1570    sample_tensor: FAIR[Optional[FileDescr_]] = None
1571    """A sample tensor to illustrate a possible input/output for the model.
1572    The sample image primarily serves to inform a human user about an example use case
1573    and is typically stored as .hdf5, .png or .tiff.
1574    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1575    (numpy's `.npy` format is not supported).
1576    The image dimensionality has to match the number of axes specified in this tensor description.
1577    """
1578
1579    @model_validator(mode="after")
1580    def _validate_sample_tensor(self) -> Self:
1581        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1582            return self
1583
1584        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1585        tensor: NDArray[Any] = imread(
1586            reader.read(),
1587            extension=PurePosixPath(reader.original_file_name).suffix,
1588        )
1589        n_dims = len(tensor.squeeze().shape)
1590        n_dims_min = n_dims_max = len(self.axes)
1591
1592        for a in self.axes:
1593            if isinstance(a, BatchAxis):
1594                n_dims_min -= 1
1595            elif isinstance(a.size, int):
1596                if a.size == 1:
1597                    n_dims_min -= 1
1598            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1599                if a.size.min == 1:
1600                    n_dims_min -= 1
1601            elif isinstance(a.size, SizeReference):
1602                if a.size.offset < 2:
1603                    # size reference may result in singleton axis
1604                    n_dims_min -= 1
1605            else:
1606                assert_never(a.size)
1607
1608        n_dims_min = max(0, n_dims_min)
1609        if n_dims < n_dims_min or n_dims > n_dims_max:
1610            raise ValueError(
1611                f"Expected sample tensor to have {n_dims_min} to"
1612                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1613            )
1614
1615        return self
1616
1617    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1618        IntervalOrRatioDataDescr()
1619    )
1620    """Description of the tensor's data values, optionally per channel.
1621    If specified per channel, the data `type` needs to match across channels."""
1622
1623    @property
1624    def dtype(
1625        self,
1626    ) -> Literal[
1627        "float32",
1628        "float64",
1629        "uint8",
1630        "int8",
1631        "uint16",
1632        "int16",
1633        "uint32",
1634        "int32",
1635        "uint64",
1636        "int64",
1637        "bool",
1638    ]:
1639        """dtype as specified under `data.type` or `data[i].type`"""
1640        if isinstance(self.data, collections.abc.Sequence):
1641            return self.data[0].type
1642        else:
1643            return self.data.type
1644
1645    @field_validator("data", mode="after")
1646    @classmethod
1647    def _check_data_type_across_channels(
1648        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1649    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1650        if not isinstance(value, list):
1651            return value
1652
1653        dtypes = {t.type for t in value}
1654        if len(dtypes) > 1:
1655            raise ValueError(
1656                "Tensor data descriptions per channel need to agree in their data"
1657                + f" `type`, but found {dtypes}."
1658            )
1659
1660        return value
1661
1662    @model_validator(mode="after")
1663    def _check_data_matches_channelaxis(self) -> Self:
1664        if not isinstance(self.data, (list, tuple)):
1665            return self
1666
1667        for a in self.axes:
1668            if isinstance(a, ChannelAxis):
1669                size = a.size
1670                assert isinstance(size, int)
1671                break
1672        else:
1673            return self
1674
1675        if len(self.data) != size:
1676            raise ValueError(
1677                f"Got tensor data descriptions for {len(self.data)} channels, but"
1678                + f" '{a.id}' axis has size {size}."
1679            )
1680
1681        return self
1682
1683    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1684        if len(array.shape) != len(self.axes):
1685            raise ValueError(
1686                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1687                + f" incompatible with {len(self.axes)} axes."
1688            )
1689        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1540    @property
1541    def shape(self):
1542        return tuple(a.size for a in self.axes)
test_tensor: Annotated[Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f83b7cb71a0>, return_type=PydanticUndefined, when_used='unless-none')]], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f83b7cd1080>, severity=35, msg=None, context=None)]

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Annotated[Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f83b7cb71a0>, return_type=PydanticUndefined, when_used='unless-none')]], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f83b7cd1080>, severity=35, msg=None, context=None)]

A sample tensor to illustrate a possible input/output for the model. The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1623    @property
1624    def dtype(
1625        self,
1626    ) -> Literal[
1627        "float32",
1628        "float64",
1629        "uint8",
1630        "int8",
1631        "uint16",
1632        "int16",
1633        "uint32",
1634        "int32",
1635        "uint64",
1636        "int64",
1637        "bool",
1638    ]:
1639        """dtype as specified under `data.type` or `data[i].type`"""
1640        if isinstance(self.data, collections.abc.Sequence):
1641            return self.data[0].type
1642        else:
1643            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1683    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1684        if len(array.shape) != len(self.axes):
1685            raise ValueError(
1686                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1687                + f" incompatible with {len(self.axes)} axes."
1688            )
1689        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
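
For illustration, a minimal sketch of describing a small input tensor and mapping an array's shape to axis ids via get_axis_sizes_for_array; all ids and sizes are made up, and in the default validation context a warning about the missing test_tensor may be emitted:

```python
import numpy as np

from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    Identifier,
    InputTensorDescr,
    SpaceInputAxis,
    TensorId,
)

descr = InputTensorDescr(
    id=TensorId("raw"),
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("c0"), Identifier("c1")]),
        SpaceInputAxis(id=AxisId("y"), size=64),
        SpaceInputAxis(id=AxisId("x"), size=64),
    ],
)

array = np.zeros((1, 2, 64, 64), dtype="float32")
print(descr.get_axis_sizes_for_array(array))
# e.g. {'batch': 1, 'channel': 2, 'y': 64, 'x': 64} (keys are AxisId instances)
print(descr.shape)  # per-axis sizes as described (the batch axis size may be unspecified/None)
```
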
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1692class InputTensorDescr(TensorDescrBase[InputAxis]):
1693    id: TensorId = TensorId("input")
1694    """Input tensor id.
1695    No duplicates are allowed across all inputs and outputs."""
1696
1697    optional: bool = False
1698    """indicates that this tensor may be `None`"""
1699
1700    preprocessing: List[PreprocessingDescr] = Field(
1701        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1702    )
1703
1704    """Description of how this input should be preprocessed.
1705
1706    notes:
1707    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1708      to ensure an input tensor's data type matches the input tensor's data description.
1709    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1710      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1711      changing the data type.
1712    """
1713
1714    @model_validator(mode="after")
1715    def _validate_preprocessing_kwargs(self) -> Self:
1716        axes_ids = [a.id for a in self.axes]
1717        for p in self.preprocessing:
1718            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1719            if kwargs_axes is None:
1720                continue
1721
1722            if not isinstance(kwargs_axes, collections.abc.Sequence):
1723                raise ValueError(
1724                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1725                )
1726
1727            if any(a not in axes_ids for a in kwargs_axes):
1728                raise ValueError(
1729                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1730                )
1731
1732        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1733            dtype = self.data.type
1734        else:
1735            dtype = self.data[0].type
1736
1737        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1738        if not self.preprocessing or not isinstance(
1739            self.preprocessing[0], EnsureDtypeDescr
1740        ):
1741            self.preprocessing.insert(
1742                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1743            )
1744
1745        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1746        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1747            self.preprocessing.append(
1748                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1749            )
1750
1751        return self
id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type (see the sketch below).
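
A minimal sketch of the notes above: a description created without any explicit preprocessing ends up with a single 'ensure_dtype' step matching its (here default, float32) data description; ids and sizes are made up:

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    InputTensorDescr,
    SpaceInputAxis,
    TensorId,
)

descr = InputTensorDescr(
    id=TensorId("raw"),
    axes=[BatchAxis(), SpaceInputAxis(id=AxisId("x"), size=128)],
)
print([p.id for p in descr.preprocessing])  # expected: ['ensure_dtype']
print(descr.preprocessing[0].kwargs.dtype)  # expected to match descr.dtype ('float32' by default)
```
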
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1754def convert_axes(
1755    axes: str,
1756    *,
1757    shape: Union[
1758        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1759    ],
1760    tensor_type: Literal["input", "output"],
1761    halo: Optional[Sequence[int]],
1762    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1763):
1764    ret: List[AnyAxis] = []
1765    for i, a in enumerate(axes):
1766        axis_type = _AXIS_TYPE_MAP.get(a, a)
1767        if axis_type == "batch":
1768            ret.append(BatchAxis())
1769            continue
1770
1771        scale = 1.0
1772        if isinstance(shape, _ParameterizedInputShape_v0_4):
1773            if shape.step[i] == 0:
1774                size = shape.min[i]
1775            else:
1776                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1777        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1778            ref_t = str(shape.reference_tensor)
1779            if ref_t.count(".") == 1:
1780                t_id, orig_a_id = ref_t.split(".")
1781            else:
1782                t_id = ref_t
1783                orig_a_id = a
1784
1785            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1786            if not (orig_scale := shape.scale[i]):
1787                # old way to insert a new axis dimension
1788                size = int(2 * shape.offset[i])
1789            else:
1790                scale = 1 / orig_scale
1791                if axis_type in ("channel", "index"):
1792                    # these axes no longer have a scale
1793                    offset_from_scale = orig_scale * size_refs.get(
1794                        _TensorName_v0_4(t_id), {}
1795                    ).get(orig_a_id, 0)
1796                else:
1797                    offset_from_scale = 0
1798                size = SizeReference(
1799                    tensor_id=TensorId(t_id),
1800                    axis_id=AxisId(a_id),
1801                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1802                )
1803        else:
1804            size = shape[i]
1805
1806        if axis_type == "time":
1807            if tensor_type == "input":
1808                ret.append(TimeInputAxis(size=size, scale=scale))
1809            else:
1810                assert not isinstance(size, ParameterizedSize)
1811                if halo is None:
1812                    ret.append(TimeOutputAxis(size=size, scale=scale))
1813                else:
1814                    assert not isinstance(size, int)
1815                    ret.append(
1816                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1817                    )
1818
1819        elif axis_type == "index":
1820            if tensor_type == "input":
1821                ret.append(IndexInputAxis(size=size))
1822            else:
1823                if isinstance(size, ParameterizedSize):
1824                    size = DataDependentSize(min=size.min)
1825
1826                ret.append(IndexOutputAxis(size=size))
1827        elif axis_type == "channel":
1828            assert not isinstance(size, ParameterizedSize)
1829            if isinstance(size, SizeReference):
1830                warnings.warn(
1831                    "Conversion of channel size from an implicit output shape may be"
1832                    + " wrong"
1833                )
1834                ret.append(
1835                    ChannelAxis(
1836                        channel_names=[
1837                            Identifier(f"channel{i}") for i in range(size.offset)
1838                        ]
1839                    )
1840                )
1841            else:
1842                ret.append(
1843                    ChannelAxis(
1844                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1845                    )
1846                )
1847        elif axis_type == "space":
1848            if tensor_type == "input":
1849                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1850            else:
1851                assert not isinstance(size, ParameterizedSize)
1852                if halo is None or halo[i] == 0:
1853                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1854                elif isinstance(size, int):
1855                    raise NotImplementedError(
1856                        f"output axis with halo and fixed size (here {size}) not allowed"
1857                    )
1858                else:
1859                    ret.append(
1860                        SpaceOutputAxisWithHalo(
1861                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1862                        )
1863                    )
1864
1865    return ret
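
For illustration, a usage sketch of convert_axes with a plain, fixed v0.4-style shape (this takes the simple shape[i] branch above); the axes string and shape are made up:

```python
from bioimageio.spec.model.v0_5 import convert_axes

axes = convert_axes(
    "bcyx",                  # v0.4 axes string
    shape=[1, 3, 512, 512],  # fixed shape, one entry per axis
    tensor_type="input",
    halo=None,
    size_refs={},
)
print([type(a).__name__ for a in axes])
# expected: ['BatchAxis', 'ChannelAxis', 'SpaceInputAxis', 'SpaceInputAxis']
```
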
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
2025class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2026    id: TensorId = TensorId("output")
2027    """Output tensor id.
2028    No duplicates are allowed across all inputs and outputs."""
2029
2030    postprocessing: List[PostprocessingDescr] = Field(
2031        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2032    )
2033    """Description of how this output should be postprocessed.
2034
2035    note: `postprocessing` always ends with an 'ensure_dtype' operation.
2036          If not given this is added to cast to this tensor's `data.type`.
2037    """
2038
2039    @model_validator(mode="after")
2040    def _validate_postprocessing_kwargs(self) -> Self:
2041        axes_ids = [a.id for a in self.axes]
2042        for p in self.postprocessing:
2043            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2044            if kwargs_axes is None:
2045                continue
2046
2047            if not isinstance(kwargs_axes, collections.abc.Sequence):
2048                raise ValueError(
2049                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2050                )
2051
2052            if any(a not in axes_ids for a in kwargs_axes):
2053                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2054
2055        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2056            dtype = self.data.type
2057        else:
2058            dtype = self.data[0].type
2059
2060        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2061        if not self.postprocessing or not isinstance(
2062            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2063        ):
2064            self.postprocessing.append(
2065                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2066            )
2067        return self
id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, FixedZeroMeanUnitVarianceDescr, ScaleLinearDescr, ScaleMeanVarianceDescr, ScaleRangeDescr, SigmoidDescr, SoftmaxDescr, ZeroMeanUnitVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If no such step is given, one is appended to cast to this tensor's data.type.
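
A minimal sketch of that note, analogous to the input case; ids and sizes are made up:

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    OutputTensorDescr,
    SpaceOutputAxis,
    TensorId,
)

descr = OutputTensorDescr(
    id=TensorId("prediction"),
    axes=[BatchAxis(), SpaceOutputAxis(id=AxisId("x"), size=128)],
)
print([p.id for p in descr.postprocessing])  # expected: ['ensure_dtype']
```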

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], Optional[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]], tensor_origin: Literal['test_tensor']):
2117def validate_tensors(
2118    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2119    tensor_origin: Literal[
2120        "test_tensor"
2121    ],  # for more precise error messages, e.g. 'test_tensor'
2122):
2123    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2124
2125    def e_msg(d: TensorDescr):
2126        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2127
2128    for descr, array in tensors.values():
2129        if array is None:
2130            axis_sizes = {a.id: None for a in descr.axes}
2131        else:
2132            try:
2133                axis_sizes = descr.get_axis_sizes_for_array(array)
2134            except ValueError as e:
2135                raise ValueError(f"{e_msg(descr)} {e}")
2136
2137        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2138
2139    for descr, array in tensors.values():
2140        if array is None:
2141            continue
2142
2143        if descr.dtype in ("float32", "float64"):
2144            invalid_test_tensor_dtype = array.dtype.name not in (
2145                "float32",
2146                "float64",
2147                "uint8",
2148                "int8",
2149                "uint16",
2150                "int16",
2151                "uint32",
2152                "int32",
2153                "uint64",
2154                "int64",
2155            )
2156        else:
2157            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2158
2159        if invalid_test_tensor_dtype:
2160            raise ValueError(
2161                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2162                + f" match described dtype '{descr.dtype}'"
2163            )
2164
2165        if array.min() > -1e-4 and array.max() < 1e-4:
2166            raise ValueError(
2167                "Output values are too small for reliable testing."
2168                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2169            )
2170
2171        for a in descr.axes:
2172            actual_size = all_tensor_axes[descr.id][a.id][1]
2173            if actual_size is None:
2174                continue
2175
2176            if a.size is None:
2177                continue
2178
2179            if isinstance(a.size, int):
2180                if actual_size != a.size:
2181                    raise ValueError(
2182                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2183                        + f"has incompatible size {actual_size}, expected {a.size}"
2184                    )
2185            elif isinstance(a.size, ParameterizedSize):
2186                _ = a.size.validate_size(actual_size)
2187            elif isinstance(a.size, DataDependentSize):
2188                _ = a.size.validate_size(actual_size)
2189            elif isinstance(a.size, SizeReference):
2190                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2191                if ref_tensor_axes is None:
2192                    raise ValueError(
2193                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2194                        + f" reference '{a.size.tensor_id}'"
2195                    )
2196
2197                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2198                if ref_axis is None or ref_size is None:
2199                    raise ValueError(
2200                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2201                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2202                    )
2203
2204                if a.unit != ref_axis.unit:
2205                    raise ValueError(
2206                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2207                        + " axis and reference axis to have the same `unit`, but"
2208                        + f" {a.unit}!={ref_axis.unit}"
2209                    )
2210
2211                if actual_size != (
2212                    expected_size := (
2213                        ref_size * ref_axis.scale / a.scale + a.size.offset
2214                    )
2215                ):
2216                    raise ValueError(
2217                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2218                        + f" {actual_size} invalid for referenced size {ref_size};"
2219                        + f" expected {expected_size}"
2220                    )
2221            else:
2222                assert_never(a.size)
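
For illustration, a minimal, self-contained sketch of calling validate_tensors with a single synthetic input tensor; real code would pair each described tensor with the array loaded from its test_tensor file:

```python
import numpy as np

from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    InputTensorDescr,
    SpaceInputAxis,
    TensorId,
    validate_tensors,
)

descr = InputTensorDescr(
    id=TensorId("raw"),
    axes=[BatchAxis(), SpaceInputAxis(id=AxisId("x"), size=8)],
)
array = np.linspace(0, 1, 8, dtype="float32").reshape(1, 8)

# raises ValueError on dtype or axis size mismatches
validate_tensors({descr.id: (descr, array)}, tensor_origin="test_tensor")
```
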
FileDescr_dependencies = typing.Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2242class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2243    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2244    """Architecture source file"""
2245
2246    @model_serializer(mode="wrap", when_used="unless-none")
2247    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2248        return package_file_descr_serializer(self, nxt, info)

A file description

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>)]

Architecture source file

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2251class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2252    import_from: str
2253    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>
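
For illustration, a sketch of pointing to a network class importable as `from mymodule.networks import UNet2d`; the module path and keyword arguments are made up, and the callable/kwargs fields are assumed from _ArchitectureCallableDescr:

```python
from bioimageio.spec.model.v0_5 import ArchitectureFromLibraryDescr, Identifier

arch = ArchitectureFromLibraryDescr(
    callable=Identifier("UNet2d"),       # hypothetical network class name
    import_from="mymodule.networks",     # hypothetical module path
    kwargs={"depth": 4, "in_channels": 1},
)
```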

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2313class WeightsEntryDescrBase(FileDescr):
2314    type: ClassVar[WeightsFormat]
2315    weights_format_name: ClassVar[str]  # human readable
2316
2317    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2318    """Source of the weights file."""
2319
2320    authors: Optional[List[Author]] = None
2321    """Authors
2322    Either the person(s) that have trained this model resulting in the original weights file.
2323        (If this is the initial weights entry, i.e. it does not have a `parent`)
2324    Or the person(s) who have converted the weights to this weights format.
2325        (If this is a child weight, i.e. it has a `parent` field)
2326    """
2327
2328    parent: Annotated[
2329        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2330    ] = None
2331    """The source weights these weights were converted from.
2332    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2333    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2334    All weight entries except one (the initial set of weights resulting from training the model),
2335    need to have this field."""
2336
2337    comment: str = ""
2338    """A comment about this weights entry, for example how these weights were created."""
2339
2340    @model_validator(mode="after")
2341    def _validate(self) -> Self:
2342        if self.type == self.parent:
2343            raise ValueError("Weights entry can't be its own parent.")
2344
2345        return self
2346
2347    @model_serializer(mode="wrap", when_used="unless-none")
2348    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2349        return package_file_descr_serializer(self, nxt, info)

A file description

type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>)]

Source of the weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors: either the person(s) that trained this model, resulting in the original weights file (if this is the initial weights entry, i.e. it does not have a parent), or the person(s) who converted the weights to this weights format (if this is a child weight, i.e. it has a parent field).

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.

comment: str

A comment about this weights entry, for example how these weights were created.
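
For illustration, a sketch of a converted weights entry referencing its parent format; the file name is hypothetical, and io checks are disabled here so the sketch does not try to read it:

```python
from pathlib import Path

from bioimageio.spec import ValidationContext
from bioimageio.spec.model.v0_5 import TorchscriptWeightsDescr, Version

with ValidationContext(perform_io_checks=False):
    converted = TorchscriptWeightsDescr(
        source=Path("weights_torchscript.pt"),  # hypothetical local file
        pytorch_version=Version("2.1.0"),
        parent="pytorch_state_dict",
        comment="converted from the pytorch_state_dict weights via torch.jit.trace",
    )
```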

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2352class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2353    type = "keras_hdf5"
2354    weights_format_name: ClassVar[str] = "Keras HDF5"
2355    tensorflow_version: Version
2356    """TensorFlow version used to create these weights."""

A file description

type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'

tensorflow_version: Version

TensorFlow version used to create these weights.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class OnnxWeightsDescr(WeightsEntryDescrBase):
2359class OnnxWeightsDescr(WeightsEntryDescrBase):
2360    type = "onnx"
2361    weights_format_name: ClassVar[str] = "ONNX"
2362    opset_version: Annotated[int, Ge(7)]
2363    """ONNX opset version"""

A file description

type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2366class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2367    type = "pytorch_state_dict"
2368    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2369    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2370    pytorch_version: Version
2371    """Version of the PyTorch library used.
2372    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
2373    """
2374    dependencies: Optional[FileDescr_dependencies] = None
2375    """Custom dependencies beyond pytorch described in a Conda environment file.
2376    Allows to specify custom dependencies, see conda docs:
2377    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2378    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2379
2380    The conda environment file should include pytorch and any version pinning has to be compatible with
2381    **pytorch_version**.
2382    """

A file description

type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'

pytorch_version: Version

Version of the PyTorch library used. If architecture.dependencies is specified, it has to include pytorch and any version pinning has to be compatible.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f83b7cb71a0>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond pytorch described in a Conda environment file. Allows to specify custom dependencies, see conda docs:

  • Exporting an environment file across platforms: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms
  • Creating an environment file manually: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually

The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2385class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2386    type = "tensorflow_js"
2387    weights_format_name: ClassVar[str] = "Tensorflow.js"
2388    tensorflow_version: Version
2389    """Version of the TensorFlow library used."""
2390
2391    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2392    """The multi-file weights.
2393    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'

tensorflow_version: Version

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>)]

The multi-file weights. All required files/folders should be a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2396class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2397    type = "tensorflow_saved_model_bundle"
2398    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2399    tensorflow_version: Version
2400    """Version of the TensorFlow library used."""
2401
2402    dependencies: Optional[FileDescr_dependencies] = None
2403    """Custom dependencies beyond tensorflow.
2404    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2405
2406    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2407    """The multi-file weights.
2408    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'

tensorflow_version: Version

Version of the TensorFlow library used.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f83b7cb71a0>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond tensorflow. The environment file should include tensorflow, and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>)]

The multi-file weights. All required files/folders should be a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2411class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2412    type = "torchscript"
2413    weights_format_name: ClassVar[str] = "TorchScript"
2414    pytorch_version: Version
2415    """Version of the PyTorch library used."""

A file description

type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'

pytorch_version: Version

Version of the PyTorch library used.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WeightsDescr(bioimageio.spec._internal.node.Node):
2418class WeightsDescr(Node):
2419    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2420    onnx: Optional[OnnxWeightsDescr] = None
2421    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2422    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2423    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2424        None
2425    )
2426    torchscript: Optional[TorchscriptWeightsDescr] = None
2427
2428    @model_validator(mode="after")
2429    def check_entries(self) -> Self:
2430        entries = {wtype for wtype, entry in self if entry is not None}
2431
2432        if not entries:
2433            raise ValueError("Missing weights entry")
2434
2435        entries_wo_parent = {
2436            wtype
2437            for wtype, entry in self
2438            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2439        }
2440        if len(entries_wo_parent) != 1:
2441            issue_warning(
2442                "Exactly one weights entry may not specify the `parent` field (got"
2443                + " {value}). That entry is considered the original set of model weights."
2444                + " Other weight formats are created through conversion of the original or"
2445                + " already converted weights. They have to reference the weights format"
2446                + " they were converted from as their `parent`.",
2447                value=len(entries_wo_parent),
2448                field="weights",
2449            )
2450
2451        for wtype, entry in self:
2452            if entry is None:
2453                continue
2454
2455            assert hasattr(entry, "type")
2456            assert hasattr(entry, "parent")
2457            assert wtype == entry.type
2458            if (
2459                entry.parent is not None and entry.parent not in entries
2460            ):  # self reference checked for `parent` field
2461                raise ValueError(
2462                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2463                    + f" formats: {entries}"
2464                )
2465
2466        return self
2467
2468    def __getitem__(
2469        self,
2470        key: Literal[
2471            "keras_hdf5",
2472            "onnx",
2473            "pytorch_state_dict",
2474            "tensorflow_js",
2475            "tensorflow_saved_model_bundle",
2476            "torchscript",
2477        ],
2478    ):
2479        if key == "keras_hdf5":
2480            ret = self.keras_hdf5
2481        elif key == "onnx":
2482            ret = self.onnx
2483        elif key == "pytorch_state_dict":
2484            ret = self.pytorch_state_dict
2485        elif key == "tensorflow_js":
2486            ret = self.tensorflow_js
2487        elif key == "tensorflow_saved_model_bundle":
2488            ret = self.tensorflow_saved_model_bundle
2489        elif key == "torchscript":
2490            ret = self.torchscript
2491        else:
2492            raise KeyError(key)
2493
2494        if ret is None:
2495            raise KeyError(key)
2496
2497        return ret
2498
2499    @property
2500    def available_formats(self):
2501        return {
2502            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2503            **({} if self.onnx is None else {"onnx": self.onnx}),
2504            **(
2505                {}
2506                if self.pytorch_state_dict is None
2507                else {"pytorch_state_dict": self.pytorch_state_dict}
2508            ),
2509            **(
2510                {}
2511                if self.tensorflow_js is None
2512                else {"tensorflow_js": self.tensorflow_js}
2513            ),
2514            **(
2515                {}
2516                if self.tensorflow_saved_model_bundle is None
2517                else {
2518                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2519                }
2520            ),
2521            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2522        }
2523
2524    @property
2525    def missing_formats(self):
2526        return {
2527            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2528        }
keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
2428    @model_validator(mode="after")
2429    def check_entries(self) -> Self:
2430        entries = {wtype for wtype, entry in self if entry is not None}
2431
2432        if not entries:
2433            raise ValueError("Missing weights entry")
2434
2435        entries_wo_parent = {
2436            wtype
2437            for wtype, entry in self
2438            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2439        }
2440        if len(entries_wo_parent) != 1:
2441            issue_warning(
2442                "Exactly one weights entry may not specify the `parent` field (got"
2443                + " {value}). That entry is considered the original set of model weights."
2444                + " Other weight formats are created through conversion of the original or"
2445                + " already converted weights. They have to reference the weights format"
2446                + " they were converted from as their `parent`.",
2447                value=len(entries_wo_parent),
2448                field="weights",
2449            )
2450
2451        for wtype, entry in self:
2452            if entry is None:
2453                continue
2454
2455            assert hasattr(entry, "type")
2456            assert hasattr(entry, "parent")
2457            assert wtype == entry.type
2458            if (
2459                entry.parent is not None and entry.parent not in entries
2460            ):  # self reference checked for `parent` field
2461                raise ValueError(
2462                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2463                    + f" formats: {entries}"
2464                )
2465
2466        return self
available_formats
2499    @property
2500    def available_formats(self):
2501        return {
2502            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2503            **({} if self.onnx is None else {"onnx": self.onnx}),
2504            **(
2505                {}
2506                if self.pytorch_state_dict is None
2507                else {"pytorch_state_dict": self.pytorch_state_dict}
2508            ),
2509            **(
2510                {}
2511                if self.tensorflow_js is None
2512                else {"tensorflow_js": self.tensorflow_js}
2513            ),
2514            **(
2515                {}
2516                if self.tensorflow_saved_model_bundle is None
2517                else {
2518                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2519                }
2520            ),
2521            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2522        }
missing_formats
2524    @property
2525    def missing_formats(self):
2526        return {
2527            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2528        }
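
Usage sketch (assuming a successfully validated model description; `load_description` and the file path are used for illustration only):

    from bioimageio.spec import load_description

    model = load_description("path/to/rdf.yaml")  # hypothetical path to a model RDF
    print(set(model.weights.available_formats))   # weight formats present in this description
    print(model.weights.missing_formats)          # formats that could still be added by conversion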
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2531class ModelId(ResourceId):
2532    pass

class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2535class LinkedModel(LinkedResourceBase):
2536    """Reference to a bioimage.io model."""
2537
2538    id: ModelId
2539    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
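
For illustration, a `parent` model reference might be constructed like this (the id below is a placeholder, not a guaranteed collection entry):

    from bioimageio.spec.model.v0_5 import LinkedModel, ModelId

    parent = LinkedModel(id=ModelId("affable-shark"))  # hypothetical model id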

class ReproducibilityTolerance(bioimageio.spec._internal.node.Node):
2561class ReproducibilityTolerance(Node, extra="allow"):
2562    """Describes what small numerical differences -- if any -- may be tolerated
2563    in the generated output when executing in different environments.
2564
2565    A tensor element *output* is considered mismatched to the **test_tensor** if
2566    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2567    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2568
2569    Motivation:
2570        For testing we can request the respective deep learning frameworks to be as
2571        reproducible as possible by setting seeds and choosing deterministic algorithms,
2572        but differences in operating systems, available hardware and installed drivers
2573        may still lead to numerical differences.
2574    """
2575
2576    relative_tolerance: RelativeTolerance = 1e-3
2577    """Maximum relative tolerance of reproduced test tensor."""
2578
2579    absolute_tolerance: AbsoluteTolerance = 1e-4
2580    """Maximum absolute tolerance of reproduced test tensor."""
2581
2582    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2583    """Maximum number of mismatched elements/pixels per million to tolerate."""
2584
2585    output_ids: Sequence[TensorId] = ()
2586    """Limits the output tensor IDs these reproducibility details apply to."""
2587
2588    weights_formats: Sequence[WeightsFormat] = ()
2589    """Limits the weights formats these details apply to."""

Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.

A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)

Motivation:

For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.

relative_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=0.01)]

Maximum relative tolerance of reproduced test tensor.

absolute_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=None)]

Maximum absolute tolerance of reproduced test tensor.

mismatched_elements_per_million: Annotated[int, Interval(gt=None, ge=0, lt=None, le=1000)]

Maximum number of mismatched elements/pixels per million to tolerate.

output_ids: Sequence[TensorId]

Limits the output tensor IDs these reproducibility details apply to.

weights_formats: Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]

Limits the weights formats these details apply to.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
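
The mismatch criterion can be reproduced with plain numpy; a sketch using the default tolerances stated above and made-up arrays:

    import numpy as np

    atol, rtol, mepm = 1e-4, 1e-3, 100  # defaults: absolute/relative tolerance, mismatched elements per million
    output = np.asarray([1.00008, 2.1])  # hypothetical reproduced output
    expected = np.asarray([1.0, 2.0])    # hypothetical test tensor
    mismatched = np.abs(output - expected) > atol + rtol * np.abs(expected)
    within_tolerance = mismatched.sum() * 1e6 / mismatched.size <= mepm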

class BioimageioConfig(bioimageio.spec._internal.node.Node):
2592class BioimageioConfig(Node, extra="allow"):
2593    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2594    """Tolerances to allow when reproducing the model's test outputs
2595    from the model's test inputs.
2596    Only the first entry matching tensor id and weights format is considered.
2597    """
reproducibility_tolerance: Sequence[ReproducibilityTolerance]

Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Config(bioimageio.spec._internal.node.Node):
2600class Config(Node, extra="allow"):
2601    bioimageio: BioimageioConfig = Field(
2602        default_factory=BioimageioConfig.model_construct
2603    )
bioimageio: BioimageioConfig
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
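
A sketch of how a custom tolerance might be attached to a model description's `config` (tensor id and values are illustrative):

    from bioimageio.spec.model.v0_5 import (
        BioimageioConfig,
        Config,
        ReproducibilityTolerance,
        TensorId,
    )

    config = Config(
        bioimageio=BioimageioConfig(
            reproducibility_tolerance=[
                ReproducibilityTolerance(
                    relative_tolerance=1e-2,
                    output_ids=[TensorId("mask")],      # hypothetical output tensor id
                    weights_formats=["tensorflow_js"],  # only applies to this weights format
                )
            ]
        )
    )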

class ModelDescr(bioimageio.spec.generic.v0_3.GenericModelDescrBase):
2606class ModelDescr(GenericModelDescrBase):
2607    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2608    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2609    """
2610
2611    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2612    if TYPE_CHECKING:
2613        format_version: Literal["0.5.5"] = "0.5.5"
2614    else:
2615        format_version: Literal["0.5.5"]
2616        """Version of the bioimage.io model description specification used.
2617        When creating a new model always use the latest micro/patch version described here.
2618        The `format_version` is important for any consumer software to understand how to parse the fields.
2619        """
2620
2621    implemented_type: ClassVar[Literal["model"]] = "model"
2622    if TYPE_CHECKING:
2623        type: Literal["model"] = "model"
2624    else:
2625        type: Literal["model"]
2626        """Specialized resource type 'model'"""
2627
2628    id: Optional[ModelId] = None
2629    """bioimage.io-wide unique resource identifier
2630    assigned by bioimage.io; version **un**specific."""
2631
2632    authors: FAIR[List[Author]] = Field(
2633        default_factory=cast(Callable[[], List[Author]], list)
2634    )
2635    """The authors are the creators of the model RDF and the primary points of contact."""
2636
2637    documentation: FAIR[Optional[FileSource_documentation]] = None
2638    """URL or relative path to a markdown file with additional documentation.
2639    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2640    The documentation should include a '#[#] Validation' (sub)section
2641    with details on how to quantitatively validate the model on unseen data."""
2642
2643    @field_validator("documentation", mode="after")
2644    @classmethod
2645    def _validate_documentation(
2646        cls, value: Optional[FileSource_documentation]
2647    ) -> Optional[FileSource_documentation]:
2648        if not get_validation_context().perform_io_checks or value is None:
2649            return value
2650
2651        doc_reader = get_reader(value)
2652        doc_content = doc_reader.read().decode(encoding="utf-8")
2653        if not re.search("#.*[vV]alidation", doc_content):
2654            issue_warning(
2655                "No '# Validation' (sub)section found in {value}.",
2656                value=value,
2657                field="documentation",
2658            )
2659
2660        return value
2661
2662    inputs: NotEmpty[Sequence[InputTensorDescr]]
2663    """Describes the input tensors expected by this model."""
2664
2665    @field_validator("inputs", mode="after")
2666    @classmethod
2667    def _validate_input_axes(
2668        cls, inputs: Sequence[InputTensorDescr]
2669    ) -> Sequence[InputTensorDescr]:
2670        input_size_refs = cls._get_axes_with_independent_size(inputs)
2671
2672        for i, ipt in enumerate(inputs):
2673            valid_independent_refs: Dict[
2674                Tuple[TensorId, AxisId],
2675                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2676            ] = {
2677                **{
2678                    (ipt.id, a.id): (ipt, a, a.size)
2679                    for a in ipt.axes
2680                    if not isinstance(a, BatchAxis)
2681                    and isinstance(a.size, (int, ParameterizedSize))
2682                },
2683                **input_size_refs,
2684            }
2685            for a, ax in enumerate(ipt.axes):
2686                cls._validate_axis(
2687                    "inputs",
2688                    i=i,
2689                    tensor_id=ipt.id,
2690                    a=a,
2691                    axis=ax,
2692                    valid_independent_refs=valid_independent_refs,
2693                )
2694        return inputs
2695
2696    @staticmethod
2697    def _validate_axis(
2698        field_name: str,
2699        i: int,
2700        tensor_id: TensorId,
2701        a: int,
2702        axis: AnyAxis,
2703        valid_independent_refs: Dict[
2704            Tuple[TensorId, AxisId],
2705            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2706        ],
2707    ):
2708        if isinstance(axis, BatchAxis) or isinstance(
2709            axis.size, (int, ParameterizedSize, DataDependentSize)
2710        ):
2711            return
2712        elif not isinstance(axis.size, SizeReference):
2713            assert_never(axis.size)
2714
2715        # validate axis.size SizeReference
2716        ref = (axis.size.tensor_id, axis.size.axis_id)
2717        if ref not in valid_independent_refs:
2718            raise ValueError(
2719                "Invalid tensor axis reference at"
2720                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2721            )
2722        if ref == (tensor_id, axis.id):
2723            raise ValueError(
2724                "Self-referencing not allowed for"
2725                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2726            )
2727        if axis.type == "channel":
2728            if valid_independent_refs[ref][1].type != "channel":
2729                raise ValueError(
2730                    "A channel axis' size may only reference another fixed size"
2731                    + " channel axis."
2732                )
2733            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2734                ref_size = valid_independent_refs[ref][2]
2735                assert isinstance(ref_size, int), (
2736                    "channel axis ref (another channel axis) has to specify fixed"
2737                    + " size"
2738                )
2739                generated_channel_names = [
2740                    Identifier(axis.channel_names.format(i=i))
2741                    for i in range(1, ref_size + 1)
2742                ]
2743                axis.channel_names = generated_channel_names
2744
2745        if (ax_unit := getattr(axis, "unit", None)) != (
2746            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2747        ):
2748            raise ValueError(
2749                "The units of an axis and its reference axis need to match, but"
2750                + f" '{ax_unit}' != '{ref_unit}'."
2751            )
2752        ref_axis = valid_independent_refs[ref][1]
2753        if isinstance(ref_axis, BatchAxis):
2754            raise ValueError(
2755                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2756                + " (a batch axis is not allowed as reference)."
2757            )
2758
2759        if isinstance(axis, WithHalo):
2760            min_size = axis.size.get_size(axis, ref_axis, n=0)
2761            if (min_size - 2 * axis.halo) < 1:
2762                raise ValueError(
2763                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2764                    + f" {axis.halo}."
2765                )
2766
2767            input_halo = axis.halo * axis.scale / ref_axis.scale
2768            if input_halo != int(input_halo) or input_halo % 2 == 1:
2769                raise ValueError(
2770                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2771                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2772                    + f"     {tensor_id}.{axis.id}."
2773                )
2774
2775    @model_validator(mode="after")
2776    def _validate_test_tensors(self) -> Self:
2777        if not get_validation_context().perform_io_checks:
2778            return self
2779
2780        test_output_arrays = [
2781            None if descr.test_tensor is None else load_array(descr.test_tensor)
2782            for descr in self.outputs
2783        ]
2784        test_input_arrays = [
2785            None if descr.test_tensor is None else load_array(descr.test_tensor)
2786            for descr in self.inputs
2787        ]
2788
2789        tensors = {
2790            descr.id: (descr, array)
2791            for descr, array in zip(
2792                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2793            )
2794        }
2795        validate_tensors(tensors, tensor_origin="test_tensor")
2796
2797        output_arrays = {
2798            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2799        }
2800        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2801            if not rep_tol.absolute_tolerance:
2802                continue
2803
2804            if rep_tol.output_ids:
2805                out_arrays = {
2806                    oid: a
2807                    for oid, a in output_arrays.items()
2808                    if oid in rep_tol.output_ids
2809                }
2810            else:
2811                out_arrays = output_arrays
2812
2813            for out_id, array in out_arrays.items():
2814                if array is None:
2815                    continue
2816
2817                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2818                    raise ValueError(
2819                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2820                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2821                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2822                    )
2823
2824        return self
2825
2826    @model_validator(mode="after")
2827    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2828        ipt_refs = {t.id for t in self.inputs}
2829        out_refs = {t.id for t in self.outputs}
2830        for ipt in self.inputs:
2831            for p in ipt.preprocessing:
2832                ref = p.kwargs.get("reference_tensor")
2833                if ref is None:
2834                    continue
2835                if ref not in ipt_refs:
2836                    raise ValueError(
2837                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2838                        + f" references are: {ipt_refs}."
2839                    )
2840
2841        for out in self.outputs:
2842            for p in out.postprocessing:
2843                ref = p.kwargs.get("reference_tensor")
2844                if ref is None:
2845                    continue
2846
2847                if ref not in ipt_refs and ref not in out_refs:
2848                    raise ValueError(
2849                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2850                        + f" are: {ipt_refs | out_refs}."
2851                    )
2852
2853        return self
2854
2855    # TODO: use validate funcs in validate_test_tensors
2856    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2857
2858    name: Annotated[
2859        str,
2860        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2861        MinLen(5),
2862        MaxLen(128),
2863        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2864    ]
2865    """A human-readable name of this model.
2866    It should be no longer than 64 characters
2867    and may only contain letters, numbers, underscores, plus and minus signs, parentheses and spaces.
2868    We recommend choosing a name that refers to the model's task and image modality.
2869    """
2870
2871    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2872    """Describes the output tensors."""
2873
2874    @field_validator("outputs", mode="after")
2875    @classmethod
2876    def _validate_tensor_ids(
2877        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2878    ) -> Sequence[OutputTensorDescr]:
2879        tensor_ids = [
2880            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2881        ]
2882        duplicate_tensor_ids: List[str] = []
2883        seen: Set[str] = set()
2884        for t in tensor_ids:
2885            if t in seen:
2886                duplicate_tensor_ids.append(t)
2887
2888            seen.add(t)
2889
2890        if duplicate_tensor_ids:
2891            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2892
2893        return outputs
2894
2895    @staticmethod
2896    def _get_axes_with_parameterized_size(
2897        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2898    ):
2899        return {
2900            f"{t.id}.{a.id}": (t, a, a.size)
2901            for t in io
2902            for a in t.axes
2903            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2904        }
2905
2906    @staticmethod
2907    def _get_axes_with_independent_size(
2908        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2909    ):
2910        return {
2911            (t.id, a.id): (t, a, a.size)
2912            for t in io
2913            for a in t.axes
2914            if not isinstance(a, BatchAxis)
2915            and isinstance(a.size, (int, ParameterizedSize))
2916        }
2917
2918    @field_validator("outputs", mode="after")
2919    @classmethod
2920    def _validate_output_axes(
2921        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2922    ) -> List[OutputTensorDescr]:
2923        input_size_refs = cls._get_axes_with_independent_size(
2924            info.data.get("inputs", [])
2925        )
2926        output_size_refs = cls._get_axes_with_independent_size(outputs)
2927
2928        for i, out in enumerate(outputs):
2929            valid_independent_refs: Dict[
2930                Tuple[TensorId, AxisId],
2931                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2932            ] = {
2933                **{
2934                    (out.id, a.id): (out, a, a.size)
2935                    for a in out.axes
2936                    if not isinstance(a, BatchAxis)
2937                    and isinstance(a.size, (int, ParameterizedSize))
2938                },
2939                **input_size_refs,
2940                **output_size_refs,
2941            }
2942            for a, ax in enumerate(out.axes):
2943                cls._validate_axis(
2944                    "outputs",
2945                    i,
2946                    out.id,
2947                    a,
2948                    ax,
2949                    valid_independent_refs=valid_independent_refs,
2950                )
2951
2952        return outputs
2953
2954    packaged_by: List[Author] = Field(
2955        default_factory=cast(Callable[[], List[Author]], list)
2956    )
2957    """The persons that have packaged and uploaded this model.
2958    Only required if those persons differ from the `authors`."""
2959
2960    parent: Optional[LinkedModel] = None
2961    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2962
2963    @model_validator(mode="after")
2964    def _validate_parent_is_not_self(self) -> Self:
2965        if self.parent is not None and self.parent.id == self.id:
2966            raise ValueError("A model description may not reference itself as parent.")
2967
2968        return self
2969
2970    run_mode: Annotated[
2971        Optional[RunMode],
2972        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2973    ] = None
2974    """Custom run mode for this model: for more complex prediction procedures like test time
2975    data augmentation that currently cannot be expressed in the specification.
2976    No standard run modes are defined yet."""
2977
2978    timestamp: Datetime = Field(default_factory=Datetime.now)
2979    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2980    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2981    (In Python a datetime object is valid, too)."""
2982
2983    training_data: Annotated[
2984        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2985        Field(union_mode="left_to_right"),
2986    ] = None
2987    """The dataset used to train this model"""
2988
2989    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2990    """The weights for this model.
2991    Weights can be given for different formats, but should otherwise be equivalent.
2992    The available weight formats determine which consumers can use this model."""
2993
2994    config: Config = Field(default_factory=Config.model_construct)
2995
2996    @model_validator(mode="after")
2997    def _add_default_cover(self) -> Self:
2998        if not get_validation_context().perform_io_checks or self.covers:
2999            return self
3000
3001        try:
3002            generated_covers = generate_covers(
3003                [
3004                    (t, load_array(t.test_tensor))
3005                    for t in self.inputs
3006                    if t.test_tensor is not None
3007                ],
3008                [
3009                    (t, load_array(t.test_tensor))
3010                    for t in self.outputs
3011                    if t.test_tensor is not None
3012                ],
3013            )
3014        except Exception as e:
3015            issue_warning(
3016                "Failed to generate cover image(s): {e}",
3017                value=self.covers,
3018                msg_context=dict(e=e),
3019                field="covers",
3020            )
3021        else:
3022            self.covers.extend(generated_covers)
3023
3024        return self
3025
3026    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3027        return self._get_test_arrays(self.inputs)
3028
3029    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3030        return self._get_test_arrays(self.outputs)
3031
3032    @staticmethod
3033    def _get_test_arrays(
3034        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3035    ):
3036        ts: List[FileDescr] = []
3037        for d in io_descr:
3038            if d.test_tensor is None:
3039                raise ValueError(
3040                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3041                )
3042            ts.append(d.test_tensor)
3043
3044        data = [load_array(t) for t in ts]
3045        assert all(isinstance(d, np.ndarray) for d in data)
3046        return data
3047
3048    @staticmethod
3049    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3050        batch_size = 1
3051        tensor_with_batchsize: Optional[TensorId] = None
3052        for tid in tensor_sizes:
3053            for aid, s in tensor_sizes[tid].items():
3054                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3055                    continue
3056
3057                if batch_size != 1:
3058                    assert tensor_with_batchsize is not None
3059                    raise ValueError(
3060                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3061                    )
3062
3063                batch_size = s
3064                tensor_with_batchsize = tid
3065
3066        return batch_size
3067
3068    def get_output_tensor_sizes(
3069        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3070    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3071        """Returns the tensor output sizes for given **input_sizes**.
3072        The tensor output size is exact only if **input_sizes** has a valid input shape.
3073        Otherwise it might be larger than the actual (valid) output."""
3074        batch_size = self.get_batch_size(input_sizes)
3075        ns = self.get_ns(input_sizes)
3076
3077        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3078        return tensor_sizes.outputs
3079
3080    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3081        """get parameter `n` for each parameterized axis
3082        such that the valid input size is >= the given input size"""
3083        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3084        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3085        for tid in input_sizes:
3086            for aid, s in input_sizes[tid].items():
3087                size_descr = axes[tid][aid].size
3088                if isinstance(size_descr, ParameterizedSize):
3089                    ret[(tid, aid)] = size_descr.get_n(s)
3090                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3091                    pass
3092                else:
3093                    assert_never(size_descr)
3094
3095        return ret
3096
3097    def get_tensor_sizes(
3098        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3099    ) -> _TensorSizes:
3100        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3101        return _TensorSizes(
3102            {
3103                t: {
3104                    aa: axis_sizes.inputs[(tt, aa)]
3105                    for tt, aa in axis_sizes.inputs
3106                    if tt == t
3107                }
3108                for t in {tt for tt, _ in axis_sizes.inputs}
3109            },
3110            {
3111                t: {
3112                    aa: axis_sizes.outputs[(tt, aa)]
3113                    for tt, aa in axis_sizes.outputs
3114                    if tt == t
3115                }
3116                for t in {tt for tt, _ in axis_sizes.outputs}
3117            },
3118        )
3119
3120    def get_axis_sizes(
3121        self,
3122        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3123        batch_size: Optional[int] = None,
3124        *,
3125        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3126    ) -> _AxisSizes:
3127        """Determine input and output block shape for scale factors **ns**
3128        of parameterized input sizes.
3129
3130        Args:
3131            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3132                that is parameterized as `size = min + n * step`.
3133            batch_size: The desired size of the batch dimension.
3134                If given **batch_size** overwrites any batch size present in
3135                **max_input_shape**. Default 1.
3136            max_input_shape: Limits the derived block shapes.
3137                Each axis for which the input size, parameterized by `n`, is larger
3138                than **max_input_shape** is set to the minimal value `n_min` for which
3139                this is still true.
3140                Use this for small input samples or large values of **ns**.
3141                Or simply whenever you know the full input shape.
3142
3143        Returns:
3144            Resolved axis sizes for model inputs and outputs.
3145        """
3146        max_input_shape = max_input_shape or {}
3147        if batch_size is None:
3148            for (_t_id, a_id), s in max_input_shape.items():
3149                if a_id == BATCH_AXIS_ID:
3150                    batch_size = s
3151                    break
3152            else:
3153                batch_size = 1
3154
3155        all_axes = {
3156            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3157        }
3158
3159        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3160        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3161
3162        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3163            if isinstance(a, BatchAxis):
3164                if (t_descr.id, a.id) in ns:
3165                    logger.warning(
3166                        "Ignoring unexpected size increment factor (n) for batch axis"
3167                        + " of tensor '{}'.",
3168                        t_descr.id,
3169                    )
3170                return batch_size
3171            elif isinstance(a.size, int):
3172                if (t_descr.id, a.id) in ns:
3173                    logger.warning(
3174                        "Ignoring unexpected size increment factor (n) for fixed size"
3175                        + " axis '{}' of tensor '{}'.",
3176                        a.id,
3177                        t_descr.id,
3178                    )
3179                return a.size
3180            elif isinstance(a.size, ParameterizedSize):
3181                if (t_descr.id, a.id) not in ns:
3182                    raise ValueError(
3183                        "Size increment factor (n) missing for parametrized axis"
3184                        + f" '{a.id}' of tensor '{t_descr.id}'."
3185                    )
3186                n = ns[(t_descr.id, a.id)]
3187                s_max = max_input_shape.get((t_descr.id, a.id))
3188                if s_max is not None:
3189                    n = min(n, a.size.get_n(s_max))
3190
3191                return a.size.get_size(n)
3192
3193            elif isinstance(a.size, SizeReference):
3194                if (t_descr.id, a.id) in ns:
3195                    logger.warning(
3196                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3197                        + " of tensor '{}' with size reference.",
3198                        a.id,
3199                        t_descr.id,
3200                    )
3201                assert not isinstance(a, BatchAxis)
3202                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3203                assert not isinstance(ref_axis, BatchAxis)
3204                ref_key = (a.size.tensor_id, a.size.axis_id)
3205                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3206                assert ref_size is not None, ref_key
3207                assert not isinstance(ref_size, _DataDepSize), ref_key
3208                return a.size.get_size(
3209                    axis=a,
3210                    ref_axis=ref_axis,
3211                    ref_size=ref_size,
3212                )
3213            elif isinstance(a.size, DataDependentSize):
3214                if (t_descr.id, a.id) in ns:
3215                    logger.warning(
3216                        "Ignoring unexpected increment factor (n) for data dependent"
3217                        + " size axis '{}' of tensor '{}'.",
3218                        a.id,
3219                        t_descr.id,
3220                    )
3221                return _DataDepSize(a.size.min, a.size.max)
3222            else:
3223                assert_never(a.size)
3224
3225        # first resolve all but the `SizeReference` input sizes
3226        for t_descr in self.inputs:
3227            for a in t_descr.axes:
3228                if not isinstance(a.size, SizeReference):
3229                    s = get_axis_size(a)
3230                    assert not isinstance(s, _DataDepSize)
3231                    inputs[t_descr.id, a.id] = s
3232
3233        # resolve all other input axis sizes
3234        for t_descr in self.inputs:
3235            for a in t_descr.axes:
3236                if isinstance(a.size, SizeReference):
3237                    s = get_axis_size(a)
3238                    assert not isinstance(s, _DataDepSize)
3239                    inputs[t_descr.id, a.id] = s
3240
3241        # resolve all output axis sizes
3242        for t_descr in self.outputs:
3243            for a in t_descr.axes:
3244                assert not isinstance(a.size, ParameterizedSize)
3245                s = get_axis_size(a)
3246                outputs[t_descr.id, a.id] = s
3247
3248        return _AxisSizes(inputs=inputs, outputs=outputs)
3249
3250    @model_validator(mode="before")
3251    @classmethod
3252    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3253        cls.convert_from_old_format_wo_validation(data)
3254        return data
3255
3256    @classmethod
3257    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3258        """Convert metadata following an older format version to this class's format
3259        without validating the result.
3260        """
3261        if (
3262            data.get("type") == "model"
3263            and isinstance(fv := data.get("format_version"), str)
3264            and fv.count(".") == 2
3265        ):
3266            fv_parts = fv.split(".")
3267            if any(not p.isdigit() for p in fv_parts):
3268                return
3269
3270            fv_tuple = tuple(map(int, fv_parts))
3271
3272            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3273            if fv_tuple[:2] in ((0, 3), (0, 4)):
3274                m04 = _ModelDescr_v0_4.load(data)
3275                if isinstance(m04, InvalidDescr):
3276                    try:
3277                        updated = _model_conv.convert_as_dict(
3278                            m04  # pyright: ignore[reportArgumentType]
3279                        )
3280                    except Exception as e:
3281                        logger.error(
3282                            "Failed to convert from invalid model 0.4 description."
3283                            + f"\nerror: {e}"
3284                            + "\nProceeding with model 0.5 validation without conversion."
3285                        )
3286                        updated = None
3287                else:
3288                    updated = _model_conv.convert_as_dict(m04)
3289
3290                if updated is not None:
3291                    data.clear()
3292                    data.update(updated)
3293
3294            elif fv_tuple[:2] == (0, 5):
3295                # bump patch version
3296                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.5']] = '0.5.5'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f83b7cd1080>, severity=35, msg=None, context=None)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Optional[Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f83b7c0fec0>), PlainSerializer(func=<function _package_serializer at 0x7f83b7cb7100>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f83b7cd1080>, severity=35, msg=None, context=None)]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f83b486fd80>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, plus and minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author]

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f83b486fe20>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: Datetime

Timestamp in ISO 8601 format with a few restrictions listed here. (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights at 0x7f83b790a5c0>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: Config
def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
3026    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3027        return self._get_test_arrays(self.inputs)
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
3029    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3030        return self._get_test_arrays(self.outputs)
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3048    @staticmethod
3049    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3050        batch_size = 1
3051        tensor_with_batchsize: Optional[TensorId] = None
3052        for tid in tensor_sizes:
3053            for aid, s in tensor_sizes[tid].items():
3054                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3055                    continue
3056
3057                if batch_size != 1:
3058                    assert tensor_with_batchsize is not None
3059                    raise ValueError(
3060                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3061                    )
3062
3063                batch_size = s
3064                tensor_with_batchsize = tid
3065
3066        return batch_size
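
Usage sketch (tensor and axis ids are hypothetical; the batch axis id used by this spec version is `batch`):

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    sizes = {
        TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 512},
        TensorId("mask"): {AxisId("batch"): 2, AxisId("y"): 512},
    }
    assert ModelDescr.get_batch_size(sizes) == 2  # inconsistent batch sizes would raise a ValueError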
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
3068    def get_output_tensor_sizes(
3069        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3070    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3071        """Returns the tensor output sizes for given **input_sizes**.
3072        The tensor output size is exact only if **input_sizes** has a valid input shape.
3073        Otherwise it might be larger than the actual (valid) output."""
3074        batch_size = self.get_batch_size(input_sizes)
3075        ns = self.get_ns(input_sizes)
3076
3077        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3078        return tensor_sizes.outputs

Returns the tensor output sizes for given input_sizes. The tensor output size is exact only if input_sizes has a valid input shape. Otherwise it might be larger than the actual (valid) output.
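
Call sketch (assuming `model` is a validated ModelDescr; tensor/axis ids and sizes are illustrative):

    from bioimageio.spec.model.v0_5 import AxisId, TensorId

    out_sizes = model.get_output_tensor_sizes(
        {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 256, AxisId("y"): 256}}
    )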

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3080    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3081        """get parameter `n` for each parameterized axis
3082        such that the valid input size is >= the given input size"""
3083        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3084        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3085        for tid in input_sizes:
3086            for aid, s in input_sizes[tid].items():
3087                size_descr = axes[tid][aid].size
3088                if isinstance(size_descr, ParameterizedSize):
3089                    ret[(tid, aid)] = size_descr.get_n(s)
3090                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3091                    pass
3092                else:
3093                    assert_never(size_descr)
3094
3095        return ret

Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
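
For a single parameterized axis this boils down to `ParameterizedSize.get_n`; a sketch:

    from bioimageio.spec.model.v0_5 import ParameterizedSize

    size = ParameterizedSize(min=64, step=16)
    n = size.get_n(70)            # smallest n such that min + n*step >= 70
    assert size.get_size(n) >= 70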

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
3097    def get_tensor_sizes(
3098        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3099    ) -> _TensorSizes:
3100        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3101        return _TensorSizes(
3102            {
3103                t: {
3104                    aa: axis_sizes.inputs[(tt, aa)]
3105                    for tt, aa in axis_sizes.inputs
3106                    if tt == t
3107                }
3108                for t in {tt for tt, _ in axis_sizes.inputs}
3109            },
3110            {
3111                t: {
3112                    aa: axis_sizes.outputs[(tt, aa)]
3113                    for tt, aa in axis_sizes.outputs
3114                    if tt == t
3115                }
3116                for t in {tt for tt, _ in axis_sizes.outputs}
3117            },
3118        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
3120    def get_axis_sizes(
3121        self,
3122        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3123        batch_size: Optional[int] = None,
3124        *,
3125        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3126    ) -> _AxisSizes:
3127        """Determine input and output block shape for scale factors **ns**
3128        of parameterized input sizes.
3129
3130        Args:
3131            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3132                that is parameterized as `size = min + n * step`.
3133            batch_size: The desired size of the batch dimension.
3134                If given **batch_size** overwrites any batch size present in
3135                **max_input_shape**. Default 1.
3136            max_input_shape: Limits the derived block shapes.
3137                Each axis for which the input size, parameterized by `n`, is larger
3138                than **max_input_shape** is set to the minimal value `n_min` for which
3139                this is still true.
3140                Use this for small input samples or large values of **ns**.
3141                Or simply whenever you know the full input shape.
3142
3143        Returns:
3144            Resolved axis sizes for model inputs and outputs.
3145        """
3146        max_input_shape = max_input_shape or {}
3147        if batch_size is None:
3148            for (_t_id, a_id), s in max_input_shape.items():
3149                if a_id == BATCH_AXIS_ID:
3150                    batch_size = s
3151                    break
3152            else:
3153                batch_size = 1
3154
3155        all_axes = {
3156            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3157        }
3158
3159        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3160        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3161
3162        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3163            if isinstance(a, BatchAxis):
3164                if (t_descr.id, a.id) in ns:
3165                    logger.warning(
3166                        "Ignoring unexpected size increment factor (n) for batch axis"
3167                        + " of tensor '{}'.",
3168                        t_descr.id,
3169                    )
3170                return batch_size
3171            elif isinstance(a.size, int):
3172                if (t_descr.id, a.id) in ns:
3173                    logger.warning(
3174                        "Ignoring unexpected size increment factor (n) for fixed size"
3175                        + " axis '{}' of tensor '{}'.",
3176                        a.id,
3177                        t_descr.id,
3178                    )
3179                return a.size
3180            elif isinstance(a.size, ParameterizedSize):
3181                if (t_descr.id, a.id) not in ns:
3182                    raise ValueError(
3183                        "Size increment factor (n) missing for parametrized axis"
3184                        + f" '{a.id}' of tensor '{t_descr.id}'."
3185                    )
3186                n = ns[(t_descr.id, a.id)]
3187                s_max = max_input_shape.get((t_descr.id, a.id))
3188                if s_max is not None:
3189                    n = min(n, a.size.get_n(s_max))
3190
3191                return a.size.get_size(n)
3192
3193            elif isinstance(a.size, SizeReference):
3194                if (t_descr.id, a.id) in ns:
3195                    logger.warning(
3196                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3197                        + " of tensor '{}' with size reference.",
3198                        a.id,
3199                        t_descr.id,
3200                    )
3201                assert not isinstance(a, BatchAxis)
3202                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3203                assert not isinstance(ref_axis, BatchAxis)
3204                ref_key = (a.size.tensor_id, a.size.axis_id)
3205                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3206                assert ref_size is not None, ref_key
3207                assert not isinstance(ref_size, _DataDepSize), ref_key
3208                return a.size.get_size(
3209                    axis=a,
3210                    ref_axis=ref_axis,
3211                    ref_size=ref_size,
3212                )
3213            elif isinstance(a.size, DataDependentSize):
3214                if (t_descr.id, a.id) in ns:
3215                    logger.warning(
3216                        "Ignoring unexpected increment factor (n) for data dependent"
3217                        + " size axis '{}' of tensor '{}'.",
3218                        a.id,
3219                        t_descr.id,
3220                    )
3221                return _DataDepSize(a.size.min, a.size.max)
3222            else:
3223                assert_never(a.size)
3224
3225        # first resolve all but the `SizeReference` input sizes
3226        for t_descr in self.inputs:
3227            for a in t_descr.axes:
3228                if not isinstance(a.size, SizeReference):
3229                    s = get_axis_size(a)
3230                    assert not isinstance(s, _DataDepSize)
3231                    inputs[t_descr.id, a.id] = s
3232
3233        # resolve all other input axis sizes
3234        for t_descr in self.inputs:
3235            for a in t_descr.axes:
3236                if isinstance(a.size, SizeReference):
3237                    s = get_axis_size(a)
3238                    assert not isinstance(s, _DataDepSize)
3239                    inputs[t_descr.id, a.id] = s
3240
3241        # resolve all output axis sizes
3242        for t_descr in self.outputs:
3243            for a in t_descr.axes:
3244                assert not isinstance(a.size, ParameterizedSize)
3245                s = get_axis_size(a)
3246                outputs[t_descr.id, a.id] = s
3247
3248        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
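
Call sketch (assuming `model` is a ModelDescr whose only parameterized axis is the `x` axis of input `raw`; ids are illustrative):

    from bioimageio.spec.model.v0_5 import AxisId, TensorId

    axis_sizes = model.get_axis_sizes(
        ns={(TensorId("raw"), AxisId("x")): 2},  # n=2 -> size = min + 2*step for that axis
        batch_size=1,
    )
    print(axis_sizes.inputs)   # {(tensor_id, axis_id): int}
    print(axis_sizes.outputs)  # may contain _DataDepSize entries for data-dependent axes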

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3256    @classmethod
3257    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3258        """Convert metadata following an older format version to this class's format
3259        without validating the result.
3260        """
3261        if (
3262            data.get("type") == "model"
3263            and isinstance(fv := data.get("format_version"), str)
3264            and fv.count(".") == 2
3265        ):
3266            fv_parts = fv.split(".")
3267            if any(not p.isdigit() for p in fv_parts):
3268                return
3269
3270            fv_tuple = tuple(map(int, fv_parts))
3271
3272            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3273            if fv_tuple[:2] in ((0, 3), (0, 4)):
3274                m04 = _ModelDescr_v0_4.load(data)
3275                if isinstance(m04, InvalidDescr):
3276                    try:
3277                        updated = _model_conv.convert_as_dict(
3278                            m04  # pyright: ignore[reportArgumentType]
3279                        )
3280                    except Exception as e:
3281                        logger.error(
3282                            "Failed to convert from invalid model 0.4 description."
3283                            + f"\nerror: {e}"
3284                            + "\nProceeding with model 0.5 validation without conversion."
3285                        )
3286                        updated = None
3287                else:
3288                    updated = _model_conv.convert_as_dict(m04)
3289
3290                if updated is not None:
3291                    data.clear()
3292                    data.update(updated)
3293
3294            elif fv_tuple[:2] == (0, 5):
3295                # bump patch version
3296                data["format_version"] = cls.implemented_format_version

Convert metadata following an older format version to this class's format without validating the result.
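
A minimal sketch of the in-place update (the dict is intentionally incomplete; a 0.5.x input only gets its patch version bumped, while 0.3/0.4 data would be converted field by field):

    from bioimageio.spec.model.v0_5 import ModelDescr

    data = {"type": "model", "format_version": "0.5.0"}  # hypothetical, incomplete RDF content
    ModelDescr.convert_from_old_format_wo_validation(data)
    assert data["format_version"] == "0.5.5"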

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 5)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
337def init_private_attributes(self: BaseModel, context: Any, /) -> None:
338    """This function is meant to behave like a BaseModel method to initialise private attributes.
339
340    It takes context as an argument since that's what pydantic-core passes when calling it.
341
342    Args:
343        self: The BaseModel instance.
344        context: The context.
345    """
346    if getattr(self, '__pydantic_private__', None) is None:
347        pydantic_private = {}
348        for name, private_attr in self.__private_attributes__.items():
349            default = private_attr.get_default()
350            if default is not PydanticUndefined:
351                pydantic_private[name] = default
352        object_setattr(self, '__pydantic_private__', pydantic_private)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
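
The hook above comes from pydantic itself. The following generic pydantic example (not specific to bioimageio.spec) shows its effect: private attributes declared with PrivateAttr receive their defaults right after model creation:

from pydantic import BaseModel, PrivateAttr

class Example(BaseModel):
    name: str
    _loaded: bool = PrivateAttr(default=False)

e = Example(name="demo")
# model_post_init has run: the private attribute holds its declared default
assert e._loaded is False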
def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3521def generate_covers(
3522    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3523    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3524) -> List[Path]:
3525    def squeeze(
3526        data: NDArray[Any], axes: Sequence[AnyAxis]
3527    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3528        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3529        if data.ndim != len(axes):
3530            raise ValueError(
3531                f"tensor shape {data.shape} does not match described axes"
3532                + f" {[a.id for a in axes]}"
3533            )
3534
3535        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3536        return data.squeeze(), axes
3537
3538    def normalize(
3539        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3540    ) -> NDArray[np.float32]:
3541        data = data.astype("float32")
3542        data -= data.min(axis=axis, keepdims=True)
3543        data /= data.max(axis=axis, keepdims=True) + eps
3544        return data
3545
3546    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3547        original_shape = data.shape
3548        data, axes = squeeze(data, axes)
3549
3550        # take slice from any batch or index axis if needed
3551        # and convert the first channel axis and take a slice from any additional channel axes
3552        slices: Tuple[slice, ...] = ()
3553        ndim = data.ndim
3554        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3555        has_c_axis = False
3556        for i, a in enumerate(axes):
3557            s = data.shape[i]
3558            assert s > 1
3559            if (
3560                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3561                and ndim > ndim_need
3562            ):
3563                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3564                ndim -= 1
3565            elif isinstance(a, ChannelAxis):
3566                if has_c_axis:
3567                    # second channel axis
3568                    data = data[slices + (slice(0, 1),)]
3569                    ndim -= 1
3570                else:
3571                    has_c_axis = True
3572                    if s == 2:
3573                        # visualize two channels with cyan and magenta
3574                        data = np.concatenate(
3575                            [
3576                                data[slices + (slice(1, 2),)],
3577                                data[slices + (slice(0, 1),)],
3578                                (
3579                                    data[slices + (slice(0, 1),)]
3580                                    + data[slices + (slice(1, 2),)]
3581                                )
3582                                / 2,  # TODO: take maximum instead?
3583                            ],
3584                            axis=i,
3585                        )
3586                    elif data.shape[i] == 3:
3587                        pass  # visualize 3 channels as RGB
3588                    else:
3589                        # visualize first 3 channels as RGB
3590                        data = data[slices + (slice(3),)]
3591
3592                    assert data.shape[i] == 3
3593
3594            slices += (slice(None),)
3595
3596        data, axes = squeeze(data, axes)
3597        assert len(axes) == ndim
3598        # take slice from z axis if needed
3599        slices = ()
3600        if ndim > ndim_need:
3601            for i, a in enumerate(axes):
3602                s = data.shape[i]
3603                if a.id == AxisId("z"):
3604                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3605                    data, axes = squeeze(data, axes)
3606                    ndim -= 1
3607                    break
3608
3609            slices += (slice(None),)
3610
3611        # take slice from any space or time axis
3612        slices = ()
3613
3614        for i, a in enumerate(axes):
3615            if ndim <= ndim_need:
3616                break
3617
3618            s = data.shape[i]
3619            assert s > 1
3620            if isinstance(
3621                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3622            ):
3623                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3624                ndim -= 1
3625
3626            slices += (slice(None),)
3627
3628        del slices
3629        data, axes = squeeze(data, axes)
3630        assert len(axes) == ndim
3631
3632        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3633            raise ValueError(
3634                f"Failed to construct cover image from shape {original_shape}"
3635            )
3636
3637        if not has_c_axis:
3638            assert ndim == 2
3639            data = np.repeat(data[:, :, None], 3, axis=2)
3640            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3641            ndim += 1
3642
3643        assert ndim == 3
3644
3645        # transpose axis order such that longest axis comes first...
3646        axis_order: List[int] = list(np.argsort(list(data.shape)))
3647        axis_order.reverse()
3648        # ... and channel axis is last
3649        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3650        axis_order.append(axis_order.pop(c))
3651        axes = [axes[ao] for ao in axis_order]
3652        data = data.transpose(axis_order)
3653
3654        # h, w = data.shape[:2]
3655        # if h / w  in (1.0 or 2.0):
3656        #     pass
3657        # elif h / w < 2:
3658        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3659
3660        norm_along = (
3661            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3662        )
3663        # normalize the data and map to 8 bit
3664        data = normalize(data, norm_along)
3665        data = (data * 255).astype("uint8")
3666
3667        return data
3668
3669    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3670        assert im0.dtype == im1.dtype == np.uint8
3671        assert im0.shape == im1.shape
3672        assert im0.ndim == 3
3673        N, M, C = im0.shape
3674        assert C == 3
3675        out = np.ones((N, M, C), dtype="uint8")
3676        for c in range(C):
3677            outc = np.tril(im0[..., c])
3678            mask = outc == 0
3679            outc[mask] = np.triu(im1[..., c])[mask]
3680            out[..., c] = outc
3681
3682        return out
3683
3684    if not inputs:
3685        raise ValueError("Missing test input tensor for cover generation.")
3686
3687    if not outputs:
3688        raise ValueError("Missing test output tensor for cover generation.")
3689
3690    ipt_descr, ipt = inputs[0]
3691    out_descr, out = outputs[0]
3692
3693    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3694    out_img = to_2d_image(out, out_descr.axes)
3695
3696    cover_folder = Path(mkdtemp())
3697    if ipt_img.shape == out_img.shape:
3698        covers = [cover_folder / "cover.png"]
3699        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3700    else:
3701        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3702        imwrite(covers[0], ipt_img)
3703        imwrite(covers[1], out_img)
3704
3705    return covers
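
A hedged usage sketch for generate_covers: it pairs each tensor description with a matching test array and writes PNG cover image(s) into a temporary folder; model_descr, input_array, and output_array are assumed to exist already and to match the declared axes:

from bioimageio.spec.model.v0_5 import generate_covers

covers = generate_covers(
    inputs=[(model_descr.inputs[0], input_array)],
    outputs=[(model_descr.outputs[0], output_array)],
)
# either a single combined cover.png (diagonal split of input and output)
# or separate input.png / output.png when the 2d projections differ in shape
for path in covers:
    print(path)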
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):