bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from itertools import chain
  10from math import ceil
  11from pathlib import Path, PurePosixPath
  12from tempfile import mkdtemp
  13from typing import (
  14    TYPE_CHECKING,
  15    Any,
  16    Callable,
  17    ClassVar,
  18    Dict,
  19    Generic,
  20    List,
  21    Literal,
  22    Mapping,
  23    NamedTuple,
  24    Optional,
  25    Sequence,
  26    Set,
  27    Tuple,
  28    Type,
  29    TypeVar,
  30    Union,
  31    cast,
  32)
  33
  34import numpy as np
  35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  36from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  37from loguru import logger
  38from numpy.typing import NDArray
  39from pydantic import (
  40    AfterValidator,
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    SerializationInfo,
  45    SerializerFunctionWrapHandler,
  46    StrictInt,
  47    Tag,
  48    ValidationInfo,
  49    WrapSerializer,
  50    field_validator,
  51    model_serializer,
  52    model_validator,
  53)
  54from typing_extensions import Annotated, Self, assert_never, get_args
  55
  56from .._internal.common_nodes import (
  57    InvalidDescr,
  58    Node,
  59    NodeWithExplicitlySetFields,
  60)
  61from .._internal.constants import DTYPE_LIMITS
  62from .._internal.field_warning import issue_warning, warn
  63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  64from .._internal.io import FileDescr as FileDescr
  65from .._internal.io import (
  66    FileSource,
  67    WithSuffix,
  68    YamlValue,
  69    get_reader,
  70    wo_special_file_name,
  71)
  72from .._internal.io_basics import Sha256 as Sha256
  73from .._internal.io_packaging import (
  74    FileDescr_,
  75    FileSource_,
  76    package_file_descr_serializer,
  77)
  78from .._internal.io_utils import load_array
  79from .._internal.node_converter import Converter
  80from .._internal.type_guards import is_dict, is_sequence
  81from .._internal.types import (
  82    AbsoluteTolerance,
  83    LowerCaseIdentifier,
  84    LowerCaseIdentifierAnno,
  85    MismatchedElementsPerMillion,
  86    RelativeTolerance,
  87)
  88from .._internal.types import Datetime as Datetime
  89from .._internal.types import Identifier as Identifier
  90from .._internal.types import NotEmpty as NotEmpty
  91from .._internal.types import SiUnit as SiUnit
  92from .._internal.url import HttpUrl as HttpUrl
  93from .._internal.validation_context import get_validation_context
  94from .._internal.validator_annotations import RestrictCharacters
  95from .._internal.version_type import Version as Version
  96from .._internal.warning_levels import INFO
  97from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  98from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
  99from ..dataset.v0_3 import DatasetDescr as DatasetDescr
 100from ..dataset.v0_3 import DatasetId as DatasetId
 101from ..dataset.v0_3 import LinkedDataset as LinkedDataset
 102from ..dataset.v0_3 import Uploader as Uploader
 103from ..generic.v0_3 import (
 104    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
 105)
 106from ..generic.v0_3 import Author as Author
 107from ..generic.v0_3 import BadgeDescr as BadgeDescr
 108from ..generic.v0_3 import CiteEntry as CiteEntry
 109from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
 110from ..generic.v0_3 import Doi as Doi
 111from ..generic.v0_3 import (
 112    FileSource_documentation,
 113    GenericModelDescrBase,
 114    LinkedResourceBase,
 115    _author_conv,  # pyright: ignore[reportPrivateUsage]
 116    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
 117)
 118from ..generic.v0_3 import LicenseId as LicenseId
 119from ..generic.v0_3 import LinkedResource as LinkedResource
 120from ..generic.v0_3 import Maintainer as Maintainer
 121from ..generic.v0_3 import OrcidId as OrcidId
 122from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 123from ..generic.v0_3 import ResourceId as ResourceId
 124from .v0_4 import Author as _Author_v0_4
 125from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 126from .v0_4 import CallableFromDepencency as CallableFromDepencency
 127from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 128from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 129from .v0_4 import ClipDescr as _ClipDescr_v0_4
 130from .v0_4 import ClipKwargs as ClipKwargs
 131from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 132from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 133from .v0_4 import KnownRunMode as KnownRunMode
 134from .v0_4 import ModelDescr as _ModelDescr_v0_4
 135from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 136from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 137from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 138from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 139from .v0_4 import ProcessingKwargs as ProcessingKwargs
 140from .v0_4 import RunMode as RunMode
 141from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 142from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 143from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 144from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 145from .v0_4 import TensorName as _TensorName_v0_4
 146from .v0_4 import WeightsFormat as WeightsFormat
 147from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 148from .v0_4 import package_weights
 149
 150SpaceUnit = Literal[
 151    "attometer",
 152    "angstrom",
 153    "centimeter",
 154    "decimeter",
 155    "exameter",
 156    "femtometer",
 157    "foot",
 158    "gigameter",
 159    "hectometer",
 160    "inch",
 161    "kilometer",
 162    "megameter",
 163    "meter",
 164    "micrometer",
 165    "mile",
 166    "millimeter",
 167    "nanometer",
 168    "parsec",
 169    "petameter",
 170    "picometer",
 171    "terameter",
 172    "yard",
 173    "yoctometer",
 174    "yottameter",
 175    "zeptometer",
 176    "zettameter",
 177]
 178"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 179
 180TimeUnit = Literal[
 181    "attosecond",
 182    "centisecond",
 183    "day",
 184    "decisecond",
 185    "exasecond",
 186    "femtosecond",
 187    "gigasecond",
 188    "hectosecond",
 189    "hour",
 190    "kilosecond",
 191    "megasecond",
 192    "microsecond",
 193    "millisecond",
 194    "minute",
 195    "nanosecond",
 196    "petasecond",
 197    "picosecond",
 198    "second",
 199    "terasecond",
 200    "yoctosecond",
 201    "yottasecond",
 202    "zeptosecond",
 203    "zettasecond",
 204]
 205"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 206
 207AxisType = Literal["batch", "channel", "index", "time", "space"]
 208
 209_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
 210    "b": "batch",
 211    "t": "time",
 212    "i": "index",
 213    "c": "channel",
 214    "x": "space",
 215    "y": "space",
 216    "z": "space",
 217}
 218
 219_AXIS_ID_MAP = {
 220    "b": "batch",
 221    "t": "time",
 222    "i": "index",
 223    "c": "channel",
 224}
 225
 226
 227class TensorId(LowerCaseIdentifier):
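        """A tensor id: a lower-case identifier of at most 32 characters,
        unique among the tensors of a model description."""
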
 228    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 229        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 230    ]
 231
 232
 233def _normalize_axis_id(a: str):
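        """Map single-letter axis ids to long forms via `_AXIS_ID_MAP`, e.g. 'c' -> 'channel'."""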
 234    a = str(a)
 235    normalized = _AXIS_ID_MAP.get(a, a)
 236    if a != normalized:
 237        logger.opt(depth=3).warning(
 238            "Normalized axis id from '{}' to '{}'.", a, normalized
 239        )
 240    return normalized
 241
 242
 243class AxisId(LowerCaseIdentifier):
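        """An axis id: a lower-case identifier of at most 16 characters.

        Single-letter ids 'b', 't', 'i' and 'c' are normalized to
        'batch', 'time', 'index' and 'channel' respectively."""
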
 244    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 245        Annotated[
 246            LowerCaseIdentifierAnno,
 247            MaxLen(16),
 248            AfterValidator(_normalize_axis_id),
 249        ]
 250    ]
 251
 252
 253def _is_batch(a: str) -> bool:
 254    return str(a) == "batch"
 255
 256
 257def _is_not_batch(a: str) -> bool:
 258    return not _is_batch(a)
 259
 260
 261NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
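    """An `AxisId` that is not the batch axis (i.e. not "batch")."""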
 262
 263PostprocessingId = Literal[
 264    "binarize",
 265    "clip",
 266    "ensure_dtype",
 267    "fixed_zero_mean_unit_variance",
 268    "scale_linear",
 269    "scale_mean_variance",
 270    "scale_range",
 271    "sigmoid",
 272    "zero_mean_unit_variance",
 273]
 274PreprocessingId = Literal[
 275    "binarize",
 276    "clip",
 277    "ensure_dtype",
 278    "scale_linear",
 279    "sigmoid",
 280    "zero_mean_unit_variance",
 281    "scale_range",
 282]
 283
 284
 285SAME_AS_TYPE = "<same as type>"
 286
 287
 288ParameterizedSize_N = int
 289"""
 290Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
 291"""
 292
 293
 294class ParameterizedSize(Node):
 295    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
 296
 297    - **min** and **step** are given by the model description.
 298    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
 299    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
 300      This allows adjusting the axis size more generically.
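
        For illustration, with assumed values `min=32` and `step=16`
        (valid sizes 32, 48, 64, ...):

        >>> axis_size = ParameterizedSize(min=32, step=16)
        >>> axis_size.get_size(2)
        64
        >>> axis_size.get_n(49)
        2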
 301    """
 302
 303    N: ClassVar[Type[int]] = ParameterizedSize_N
 304    """Non-negative integer to parameterize this axis"""
 305
 306    min: Annotated[int, Gt(0)]
 307    step: Annotated[int, Gt(0)]
 308
 309    def validate_size(self, size: int) -> int:
 310        if size < self.min:
 311            raise ValueError(f"size {size} < {self.min}")
 312        if (size - self.min) % self.step != 0:
 313            raise ValueError(
 314                f"axis of size {size} is not parameterized by `min + n*step` ="
 315                + f" `{self.min} + n*{self.step}`"
 316            )
 317
 318        return size
 319
 320    def get_size(self, n: ParameterizedSize_N) -> int:
 321        return self.min + self.step * n
 322
 323    def get_n(self, s: int) -> ParameterizedSize_N:
 324        """Return the smallest n parameterizing a size greater than or equal to `s`."""
 325        return ceil((s - self.min) / self.step)
 326
 327
 328class DataDependentSize(Node):
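        """An axis size that is only known after model inference, optionally bounded by `min` and `max`."""
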
 329    min: Annotated[int, Gt(0)] = 1
 330    max: Annotated[Optional[int], Gt(1)] = None
 331
 332    @model_validator(mode="after")
 333    def _validate_max_gt_min(self):
 334        if self.max is not None and self.min >= self.max:
 335            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 336
 337        return self
 338
 339    def validate_size(self, size: int) -> int:
 340        if size < self.min:
 341            raise ValueError(f"size {size} < {self.min}")
 342
 343        if self.max is not None and size > self.max:
 344            raise ValueError(f"size {size} > {self.max}")
 345
 346        return size
 347
 348
 349class SizeReference(Node):
 350    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 351
 352    `axis.size = reference.size * reference.scale / axis.scale + offset`
 353
 354    Note:
 355    1. The axis and the referenced axis need to have the same unit (or no unit).
 356    2. Batch axes may not be referenced.
 357    3. Fractions are rounded down.
 358    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 359        `concatenable` as well with the same block order.
 360
 361    Example:
 362    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
 363    Let's assume that we want to express the image height h in relation to its width w
 364    instead of only accepting input images of exactly 100*49 pixels
 365    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
 366
 367    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 368    >>> h = SpaceInputAxis(
 369    ...     id=AxisId("h"),
 370    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 371    ...     unit="millimeter",
 372    ...     scale=4,
 373    ... )
 374    >>> print(h.size.get_size(h, w))
 375    49
 376
 377    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 378    """
 379
 380    tensor_id: TensorId
 381    """tensor id of the reference axis"""
 382
 383    axis_id: AxisId
 384    """axis id of the reference axis"""
 385
 386    offset: StrictInt = 0
 387
 388    def get_size(
 389        self,
 390        axis: Union[
 391            ChannelAxis,
 392            IndexInputAxis,
 393            IndexOutputAxis,
 394            TimeInputAxis,
 395            SpaceInputAxis,
 396            TimeOutputAxis,
 397            TimeOutputAxisWithHalo,
 398            SpaceOutputAxis,
 399            SpaceOutputAxisWithHalo,
 400        ],
 401        ref_axis: Union[
 402            ChannelAxis,
 403            IndexInputAxis,
 404            IndexOutputAxis,
 405            TimeInputAxis,
 406            SpaceInputAxis,
 407            TimeOutputAxis,
 408            TimeOutputAxisWithHalo,
 409            SpaceOutputAxis,
 410            SpaceOutputAxisWithHalo,
 411        ],
 412        n: ParameterizedSize_N = 0,
 413        ref_size: Optional[int] = None,
 414    ):
 415        """Compute the concrete size for a given axis and its reference axis.
 416
 417        Args:
 418            axis: The axis this `SizeReference` is the size of.
 419            ref_axis: The reference axis to compute the size from.
 420            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 421                and no fixed **ref_size** is given,
 422                **n** is used to compute the size of the parameterized **ref_axis**.
 423            ref_size: Overwrite the reference size instead of deriving it from
 424                **ref_axis**
 425                (**ref_axis.scale** is still used; any given **n** is ignored).
 426        """
 427        assert (
 428            axis.size == self
 429        ), "Given `axis.size` is not defined by this `SizeReference`"
 430
 431        assert (
 432            ref_axis.id == self.axis_id
 433        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 434
 435        assert axis.unit == ref_axis.unit, (
 436            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 437            f" but {axis.unit}!={ref_axis.unit}"
 438        )
 439        if ref_size is None:
 440            if isinstance(ref_axis.size, (int, float)):
 441                ref_size = ref_axis.size
 442            elif isinstance(ref_axis.size, ParameterizedSize):
 443                ref_size = ref_axis.size.get_size(n)
 444            elif isinstance(ref_axis.size, DataDependentSize):
 445                raise ValueError(
 446                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 447                )
 448            elif isinstance(ref_axis.size, SizeReference):
 449                raise ValueError(
 450                    "Reference axis referenced in `SizeReference` may not be sized by a"
 451                    + " `SizeReference` itself."
 452                )
 453            else:
 454                assert_never(ref_axis.size)
 455
 456        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 457
 458    @staticmethod
 459    def _get_unit(
 460        axis: Union[
 461            ChannelAxis,
 462            IndexInputAxis,
 463            IndexOutputAxis,
 464            TimeInputAxis,
 465            SpaceInputAxis,
 466            TimeOutputAxis,
 467            TimeOutputAxisWithHalo,
 468            SpaceOutputAxis,
 469            SpaceOutputAxisWithHalo,
 470        ],
 471    ):
 472        return axis.unit
 473
 474
 475class AxisBase(NodeWithExplicitlySetFields):
 476    id: AxisId
 477    """An axis id unique across all axes of one tensor."""
 478
 479    description: Annotated[str, MaxLen(128)] = ""
 480
 481
 482class WithHalo(Node):
 483    halo: Annotated[int, Ge(1)]
 484    """The halo should be cropped from the output tensor to avoid boundary effects.
 485    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 486    To document a halo that is already cropped by the model use `size.offset` instead."""
 487
 488    size: Annotated[
 489        SizeReference,
 490        Field(
 491            examples=[
 492                10,
 493                SizeReference(
 494                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 495                ).model_dump(mode="json"),
 496            ]
 497        ),
 498    ]
 499    """reference to another axis with an optional offset (see `SizeReference`)"""
 500
 501
 502BATCH_AXIS_ID = AxisId("batch")
 503
 504
 505class BatchAxis(AxisBase):
 506    implemented_type: ClassVar[Literal["batch"]] = "batch"
 507    if TYPE_CHECKING:
 508        type: Literal["batch"] = "batch"
 509    else:
 510        type: Literal["batch"]
 511
 512    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 513    size: Optional[Literal[1]] = None
 514    """The batch size may be fixed to 1,
 515    otherwise (the default) it may be chosen arbitrarily depending on available memory."""
 516
 517    @property
 518    def scale(self):
 519        return 1.0
 520
 521    @property
 522    def concatenable(self):
 523        return True
 524
 525    @property
 526    def unit(self):
 527        return None
 528
 529
 530class ChannelAxis(AxisBase):
 531    implemented_type: ClassVar[Literal["channel"]] = "channel"
 532    if TYPE_CHECKING:
 533        type: Literal["channel"] = "channel"
 534    else:
 535        type: Literal["channel"]
 536
 537    id: NonBatchAxisId = AxisId("channel")
 538    channel_names: NotEmpty[List[Identifier]]
 539
 540    @property
 541    def size(self) -> int:
 542        return len(self.channel_names)
 543
 544    @property
 545    def concatenable(self):
 546        return False
 547
 548    @property
 549    def scale(self) -> float:
 550        return 1.0
 551
 552    @property
 553    def unit(self):
 554        return None
 555
 556
 557class IndexAxisBase(AxisBase):
 558    implemented_type: ClassVar[Literal["index"]] = "index"
 559    if TYPE_CHECKING:
 560        type: Literal["index"] = "index"
 561    else:
 562        type: Literal["index"]
 563
 564    id: NonBatchAxisId = AxisId("index")
 565
 566    @property
 567    def scale(self) -> float:
 568        return 1.0
 569
 570    @property
 571    def unit(self):
 572        return None
 573
 574
 575class _WithInputAxisSize(Node):
 576    size: Annotated[
 577        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 578        Field(
 579            examples=[
 580                10,
 581                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 582                SizeReference(
 583                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 584                ).model_dump(mode="json"),
 585            ]
 586        ),
 587    ]
 588    """The size/length of this axis can be specified as
 589    - fixed integer
 590    - parameterized series of valid sizes (`ParameterizedSize`)
 591    - reference to another axis with an optional offset (`SizeReference`)
 592    """
 593
 594
 595class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 596    concatenable: bool = False
 597    """If a model has a `concatenable` input axis, it can be processed blockwise,
 598    splitting a longer sample axis into blocks matching its input tensor description.
 599    Output axes are concatenable if they have a `SizeReference` to a concatenable
 600    input axis.
 601    """
 602
 603
 604class IndexOutputAxis(IndexAxisBase):
 605    size: Annotated[
 606        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 607        Field(
 608            examples=[
 609                10,
 610                SizeReference(
 611                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 612                ).model_dump(mode="json"),
 613            ]
 614        ),
 615    ]
 616    """The size/length of this axis can be specified as
 617    - fixed integer
 618    - reference to another axis with an optional offset (`SizeReference`)
 619    - data dependent size using `DataDependentSize` (size is only known after model inference)
 620    """
 621
 622
 623class TimeAxisBase(AxisBase):
 624    implemented_type: ClassVar[Literal["time"]] = "time"
 625    if TYPE_CHECKING:
 626        type: Literal["time"] = "time"
 627    else:
 628        type: Literal["time"]
 629
 630    id: NonBatchAxisId = AxisId("time")
 631    unit: Optional[TimeUnit] = None
 632    scale: Annotated[float, Gt(0)] = 1.0
 633
 634
 635class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 636    concatenable: bool = False
 637    """If a model has a `concatenable` input axis, it can be processed blockwise,
 638    splitting a longer sample axis into blocks matching its input tensor description.
 639    Output axes are concatenable if they have a `SizeReference` to a concatenable
 640    input axis.
 641    """
 642
 643
 644class SpaceAxisBase(AxisBase):
 645    implemented_type: ClassVar[Literal["space"]] = "space"
 646    if TYPE_CHECKING:
 647        type: Literal["space"] = "space"
 648    else:
 649        type: Literal["space"]
 650
 651    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 652    unit: Optional[SpaceUnit] = None
 653    scale: Annotated[float, Gt(0)] = 1.0
 654
 655
 656class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 657    concatenable: bool = False
 658    """If a model has a `concatenable` input axis, it can be processed blockwise,
 659    splitting a longer sample axis into blocks matching its input tensor description.
 660    Output axes are concatenable if they have a `SizeReference` to a concatenable
 661    input axis.
 662    """
 663
 664
 665INPUT_AXIS_TYPES = (
 666    BatchAxis,
 667    ChannelAxis,
 668    IndexInputAxis,
 669    TimeInputAxis,
 670    SpaceInputAxis,
 671)
 672"""intended for isinstance comparisons in py<3.10"""
 673
 674_InputAxisUnion = Union[
 675    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 676]
 677InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 678
 679
 680class _WithOutputAxisSize(Node):
 681    size: Annotated[
 682        Union[Annotated[int, Gt(0)], SizeReference],
 683        Field(
 684            examples=[
 685                10,
 686                SizeReference(
 687                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 688                ).model_dump(mode="json"),
 689            ]
 690        ),
 691    ]
 692    """The size/length of this axis can be specified as
 693    - fixed integer
 694    - reference to another axis with an optional offset (see `SizeReference`)
 695    """
 696
 697
 698class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 699    pass
 700
 701
 702class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 703    pass
 704
 705
 706def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 707    if isinstance(v, dict):
 708        return "with_halo" if "halo" in v else "wo_halo"
 709    else:
 710        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 711
 712
 713_TimeOutputAxisUnion = Annotated[
 714    Union[
 715        Annotated[TimeOutputAxis, Tag("wo_halo")],
 716        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 717    ],
 718    Discriminator(_get_halo_axis_discriminator_value),
 719]
 720
 721
 722class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 723    pass
 724
 725
 726class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 727    pass
 728
 729
 730_SpaceOutputAxisUnion = Annotated[
 731    Union[
 732        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 733        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 734    ],
 735    Discriminator(_get_halo_axis_discriminator_value),
 736]
 737
 738
 739_OutputAxisUnion = Union[
 740    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 741]
 742OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 743
 744OUTPUT_AXIS_TYPES = (
 745    BatchAxis,
 746    ChannelAxis,
 747    IndexOutputAxis,
 748    TimeOutputAxis,
 749    TimeOutputAxisWithHalo,
 750    SpaceOutputAxis,
 751    SpaceOutputAxisWithHalo,
 752)
 753"""intended for isinstance comparisons in py<3.10"""
 754
 755
 756AnyAxis = Union[InputAxis, OutputAxis]
 757
 758ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
 759"""intended for isinstance comparisons in py<3.10"""
 760
 761TVs = Union[
 762    NotEmpty[List[int]],
 763    NotEmpty[List[float]],
 764    NotEmpty[List[bool]],
 765    NotEmpty[List[str]],
 766]
 767
 768
 769NominalOrOrdinalDType = Literal[
 770    "float32",
 771    "float64",
 772    "uint8",
 773    "int8",
 774    "uint16",
 775    "int16",
 776    "uint32",
 777    "int32",
 778    "uint64",
 779    "int64",
 780    "bool",
 781]
 782
 783
 784class NominalOrOrdinalDataDescr(Node):
 785    values: TVs
 786    """A fixed set of nominal or an ascending sequence of ordinal values.
 787    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
 788    String `values` are interpreted as labels for tensor values 0, ..., N.
 789    Note: as YAML 1.2 does not natively support a "set" datatype,
 790    nominal values should be given as a sequence (aka list/array) as well.
 791    """
 792
 793    type: Annotated[
 794        NominalOrOrdinalDType,
 795        Field(
 796            examples=[
 797                "float32",
 798                "uint8",
 799                "uint16",
 800                "int64",
 801                "bool",
 802            ],
 803        ),
 804    ] = "uint8"
 805
 806    @model_validator(mode="after")
 807    def _validate_values_match_type(
 808        self,
 809    ) -> Self:
 810        incompatible: List[Any] = []
 811        for v in self.values:
 812            if self.type == "bool":
 813                if not isinstance(v, bool):
 814                    incompatible.append(v)
 815            elif self.type in DTYPE_LIMITS:
 816                if (
 817                    isinstance(v, (int, float))
 818                    and (
 819                        v < DTYPE_LIMITS[self.type].min
 820                        or v > DTYPE_LIMITS[self.type].max
 821                    )
 822                    or (isinstance(v, str) and "uint" not in self.type)
 823                    or (isinstance(v, float) and "int" in self.type)
 824                ):
 825                    incompatible.append(v)
 826            else:
 827                incompatible.append(v)
 828
 829            if len(incompatible) == 5:
 830                incompatible.append("...")
 831                break
 832
 833        if incompatible:
 834            raise ValueError(
 835                f"data type '{self.type}' incompatible with values {incompatible}"
 836            )
 837
 838        return self
 839
 840    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 841
 842    @property
 843    def range(self):
 844        if isinstance(self.values[0], str):
 845            return 0, len(self.values) - 1
 846        else:
 847            return min(self.values), max(self.values)
 848
 849
 850IntervalOrRatioDType = Literal[
 851    "float32",
 852    "float64",
 853    "uint8",
 854    "int8",
 855    "uint16",
 856    "int16",
 857    "uint32",
 858    "int32",
 859    "uint64",
 860    "int64",
 861]
 862
 863
 864class IntervalOrRatioDataDescr(Node):
 865    type: Annotated[  # todo: rename to dtype
 866        IntervalOrRatioDType,
 867        Field(
 868            examples=["float32", "float64", "uint8", "uint16"],
 869        ),
 870    ] = "float32"
 871    range: Tuple[Optional[float], Optional[float]] = (
 872        None,
 873        None,
 874    )
 875    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 876    `None` corresponds to min/max of what can be expressed by **type**."""
 877    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 878    scale: float = 1.0
 879    """Scale for data on an interval (or ratio) scale."""
 880    offset: Optional[float] = None
 881    """Offset for data on a ratio scale."""
 882
 883    @model_validator(mode="before")
 884    def _replace_inf(cls, data: Any):
 885        if is_dict(data):
 886            if "range" in data and is_sequence(data["range"]):
 887                forbidden = (
 888                    "inf",
 889                    "-inf",
 890                    ".inf",
 891                    "-.inf",
 892                    float("inf"),
 893                    float("-inf"),
 894                )
 895                if any(v in forbidden for v in data["range"]):
 896                    issue_warning("replaced 'inf' value", value=data["range"])
 897
 898                data["range"] = tuple(
 899                    (None if v in forbidden else v) for v in data["range"]
 900                )
 901
 902        return data
 903
 904
 905TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 906
 907
 908class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 909    """processing base class"""
 910
 911
 912class BinarizeKwargs(ProcessingKwargs):
 913    """keyword arguments for `BinarizeDescr`"""
 914
 915    threshold: float
 916    """The fixed threshold"""
 917
 918
 919class BinarizeAlongAxisKwargs(ProcessingKwargs):
 920    """keyword arguments for `BinarizeDescr`"""
 921
 922    threshold: NotEmpty[List[float]]
 923    """The fixed threshold values along `axis`"""
 924
 925    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 926    """The `threshold` axis"""
 927
 928
 929class BinarizeDescr(ProcessingDescrBase):
 930    """Binarize the tensor with a fixed threshold.
 931
 932    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 933    will be set to one, values below the threshold to zero.
 934
 935    Examples:
 936    - in YAML
 937        ```yaml
 938        postprocessing:
 939          - id: binarize
 940            kwargs:
 941              axis: 'channel'
 942              threshold: [0.25, 0.5, 0.75]
 943        ```
 944    - in Python:
 945        >>> postprocessing = [BinarizeDescr(
 946        ...   kwargs=BinarizeAlongAxisKwargs(
 947        ...       axis=AxisId('channel'),
 948        ...       threshold=[0.25, 0.5, 0.75],
 949        ...   )
 950        ... )]
 951    """
 952
 953    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
 954    if TYPE_CHECKING:
 955        id: Literal["binarize"] = "binarize"
 956    else:
 957        id: Literal["binarize"]
 958    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 959
 960
 961class ClipDescr(ProcessingDescrBase):
 962    """Set tensor values below min to min and above max to max.
 963
 964    See `ScaleRangeDescr` for examples.
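
        A minimal standalone example (bounds chosen for illustration):

        >>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]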
 965    """
 966
 967    implemented_id: ClassVar[Literal["clip"]] = "clip"
 968    if TYPE_CHECKING:
 969        id: Literal["clip"] = "clip"
 970    else:
 971        id: Literal["clip"]
 972
 973    kwargs: ClipKwargs
 974
 975
 976class EnsureDtypeKwargs(ProcessingKwargs):
 977    """keyword arguments for `EnsureDtypeDescr`"""
 978
 979    dtype: Literal[
 980        "float32",
 981        "float64",
 982        "uint8",
 983        "int8",
 984        "uint16",
 985        "int16",
 986        "uint32",
 987        "int32",
 988        "uint64",
 989        "int64",
 990        "bool",
 991    ]
 992
 993
 994class EnsureDtypeDescr(ProcessingDescrBase):
 995    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 996
 997    This can for example be used to ensure the inner neural network model gets a
 998    different input tensor data type than the fully described bioimage.io model does.
 999
1000    Examples:
1001        The described bioimage.io model (incl. preprocessing) accepts any
1002        float32-compatible tensor, normalizes it with percentiles and clipping and then
1003        casts it to uint8, which is what the neural network in this example expects.
1004        - in YAML
1005            ```yaml
1006            inputs:
1007            - data:
1008                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1009              preprocessing:
1010              - id: scale_range
1011                kwargs:
1012                  axes: ['y', 'x']
1013                  max_percentile: 99.8
1014                  min_percentile: 5.0
1015              - id: clip
1016                kwargs:
1017                  min: 0.0
1018                  max: 1.0
1019              - id: ensure_dtype  # the neural network of the model requires uint8
1020                kwargs:
1021                  dtype: uint8
1022            ```
1023        - in Python:
1024            >>> preprocessing = [
1025            ...     ScaleRangeDescr(
1026            ...         kwargs=ScaleRangeKwargs(
1027            ...           axes= (AxisId('y'), AxisId('x')),
1028            ...           max_percentile= 99.8,
1029            ...           min_percentile= 5.0,
1030            ...         )
1031            ...     ),
1032            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1033            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1034            ... ]
1035    """
1036
1037    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1038    if TYPE_CHECKING:
1039        id: Literal["ensure_dtype"] = "ensure_dtype"
1040    else:
1041        id: Literal["ensure_dtype"]
1042
1043    kwargs: EnsureDtypeKwargs
1044
1045
1046class ScaleLinearKwargs(ProcessingKwargs):
1047    """Keyword arguments for `ScaleLinearDescr`"""
1048
1049    gain: float = 1.0
1050    """multiplicative factor"""
1051
1052    offset: float = 0.0
1053    """additive term"""
1054
1055    @model_validator(mode="after")
1056    def _validate(self) -> Self:
1057        if self.gain == 1.0 and self.offset == 0.0:
1058            raise ValueError(
1059                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1060                + " != 0.0."
1061            )
1062
1063        return self
1064
1065
1066class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1067    """Keyword arguments for `ScaleLinearDescr`"""
1068
1069    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1070    """The axis of gain and offset values."""
1071
1072    gain: Union[float, NotEmpty[List[float]]] = 1.0
1073    """multiplicative factor"""
1074
1075    offset: Union[float, NotEmpty[List[float]]] = 0.0
1076    """additive term"""
1077
1078    @model_validator(mode="after")
1079    def _validate(self) -> Self:
1080
1081        if isinstance(self.gain, list):
1082            if isinstance(self.offset, list):
1083                if len(self.gain) != len(self.offset):
1084                    raise ValueError(
1085                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1086                    )
1087            else:
1088                self.offset = [float(self.offset)] * len(self.gain)
1089        elif isinstance(self.offset, list):
1090            self.gain = [float(self.gain)] * len(self.offset)
1091        else:
1092            raise ValueError(
1093                "Do not specify an `axis` for scalar gain and offset values."
1094            )
1095
1096        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1097            raise ValueError(
1098                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1099                + " != 0.0."
1100            )
1101
1102        return self
1103
1104
1105class ScaleLinearDescr(ProcessingDescrBase):
1106    """Fixed linear scaling.
1107
1108    Examples:
1109      1. Scale with scalar gain and offset
1110        - in YAML
1111        ```yaml
1112        preprocessing:
1113          - id: scale_linear
1114            kwargs:
1115              gain: 2.0
1116              offset: 3.0
1117        ```
1118        - in Python:
1119        >>> preprocessing = [
1120        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1121        ... ]
1122
1123      2. Independent scaling along an axis
1124        - in YAML
1125        ```yaml
1126        preprocessing:
1127          - id: scale_linear
1128            kwargs:
1129              axis: 'channel'
1130              gain: [1.0, 2.0, 3.0]
1131        ```
1132        - in Python:
1133        >>> preprocessing = [
1134        ...     ScaleLinearDescr(
1135        ...         kwargs=ScaleLinearAlongAxisKwargs(
1136        ...             axis=AxisId("channel"),
1137        ...             gain=[1.0, 2.0, 3.0],
1138        ...         )
1139        ...     )
1140        ... ]
1141
1142    """
1143
1144    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1145    if TYPE_CHECKING:
1146        id: Literal["scale_linear"] = "scale_linear"
1147    else:
1148        id: Literal["scale_linear"]
1149    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1150
1151
1152class SigmoidDescr(ProcessingDescrBase):
1153    """The logistic sigmoid function, a.k.a. expit function.
1154
1155    Examples:
1156    - in YAML
1157        ```yaml
1158        postprocessing:
1159          - id: sigmoid
1160        ```
1161    - in Python:
1162        >>> postprocessing = [SigmoidDescr()]
1163    """
1164
1165    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1166    if TYPE_CHECKING:
1167        id: Literal["sigmoid"] = "sigmoid"
1168    else:
1169        id: Literal["sigmoid"]
1170
1171    @property
1172    def kwargs(self) -> ProcessingKwargs:
1173        """empty kwargs"""
1174        return ProcessingKwargs()
1175
1176
1177class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1178    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1179
1180    mean: float
1181    """The mean value to normalize with."""
1182
1183    std: Annotated[float, Ge(1e-6)]
1184    """The standard deviation value to normalize with."""
1185
1186
1187class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1188    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1189
1190    mean: NotEmpty[List[float]]
1191    """The mean value(s) to normalize with."""
1192
1193    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1194    """The standard deviation value(s) to normalize with.
1195    Size must match `mean` values."""
1196
1197    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1198    """The axis of the mean/std values to normalize each entry along that dimension
1199    separately."""
1200
1201    @model_validator(mode="after")
1202    def _mean_and_std_match(self) -> Self:
1203        if len(self.mean) != len(self.std):
1204            raise ValueError(
1205                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1206                + " must match."
1207            )
1208
1209        return self
1210
1211
1212class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1213    """Subtract a given mean and divide by the standard deviation.
1214
1215    Normalize with fixed, precomputed values for
1216    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1217    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1218    axes.
1219
1220    Examples:
1221    1. scalar value for whole tensor
1222        - in YAML
1223        ```yaml
1224        preprocessing:
1225          - id: fixed_zero_mean_unit_variance
1226            kwargs:
1227              mean: 103.5
1228              std: 13.7
1229        ```
1230        - in Python
1231        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1232        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1233        ... )]
1234
1235    2. independently along an axis
1236        - in YAML
1237        ```yaml
1238        preprocessing:
1239          - id: fixed_zero_mean_unit_variance
1240            kwargs:
1241              axis: channel
1242              mean: [101.5, 102.5, 103.5]
1243              std: [11.7, 12.7, 13.7]
1244        ```
1245        - in Python
1246        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1247        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1248        ...     axis=AxisId("channel"),
1249        ...     mean=[101.5, 102.5, 103.5],
1250        ...     std=[11.7, 12.7, 13.7],
1251        ...   )
1252        ... )]
1253    """
1254
1255    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1256        "fixed_zero_mean_unit_variance"
1257    )
1258    if TYPE_CHECKING:
1259        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1260    else:
1261        id: Literal["fixed_zero_mean_unit_variance"]
1262
1263    kwargs: Union[
1264        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1265    ]
1266
1267
1268class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1269    """keyword arguments for `ZeroMeanUnitVarianceDescr`"""
1270
1271    axes: Annotated[
1272        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1273    ] = None
1274    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1275    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1276    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1277    To normalize each sample independently leave out the 'batch' axis.
1278    Default: Scale all axes jointly."""
1279
1280    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1281    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1282
1283
1284class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1285    """Subtract mean and divide by standard deviation.
1286
1287    Examples:
1288        Subtract tensor mean and divide by tensor standard deviation
1289        - in YAML
1290        ```yaml
1291        preprocessing:
1292          - id: zero_mean_unit_variance
1293        ```
1294        - in Python
1295        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1296    """
1297
1298    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1299        "zero_mean_unit_variance"
1300    )
1301    if TYPE_CHECKING:
1302        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1303    else:
1304        id: Literal["zero_mean_unit_variance"]
1305
1306    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1307        default_factory=ZeroMeanUnitVarianceKwargs
1308    )
1309
1310
1311class ScaleRangeKwargs(ProcessingKwargs):
1312    """keyword arguments for `ScaleRangeDescr`
1313
1314    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1315    this processing step normalizes data to the [0, 1] interval.
1316    For other percentiles the normalized values will partially be outside the [0, 1]
1317    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1318    normalized values to a range.
1319    """
1320
1321    axes: Annotated[
1322        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1323    ] = None
1324    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1325    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1326    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1327    To normalize samples independently, leave out the "batch" axis.
1328    Default: Scale all axes jointly."""
1329
1330    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1331    """The lower percentile used to determine the value to align with zero."""
1332
1333    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1334    """The upper percentile used to determine the value to align with one.
1335    Has to be bigger than `min_percentile`.
1336    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1337    accepting percentiles specified in the range 0.0 to 1.0."""
1338
1339    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1340    """Epsilon for numeric stability.
1341    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1342    with `v_lower,v_upper` values at the respective percentiles."""
1343
1344    reference_tensor: Optional[TensorId] = None
1345    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1346    For any tensor in `inputs` only input tensor references are allowed."""
1347
1348    @field_validator("max_percentile", mode="after")
1349    @classmethod
1350    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1351        if (min_p := info.data["min_percentile"]) >= value:
1352            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1353
1354        return value
1355
1356
1357class ScaleRangeDescr(ProcessingDescrBase):
1358    """Scale with percentiles.
1359
1360    Examples:
1361    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1362        - in YAML
1363        ```yaml
1364        preprocessing:
1365          - id: scale_range
1366            kwargs:
1367              axes: ['y', 'x']
1368              max_percentile: 99.8
1369              min_percentile: 5.0
1370        ```
1371        - in Python
1372        >>> preprocessing = [
1373        ...     ScaleRangeDescr(
1374        ...         kwargs=ScaleRangeKwargs(
1375        ...           axes= (AxisId('y'), AxisId('x')),
1376        ...           max_percentile= 99.8,
1377        ...           min_percentile= 5.0,
1378        ...         )
1379        ...     ),
1380        ...     ClipDescr(
1381        ...         kwargs=ClipKwargs(
1382        ...             min=0.0,
1383        ...             max=1.0,
1384        ...         )
1385        ...     ),
1386        ... ]
1387
1388    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1389        - in YAML
1390        ```yaml
1391        preprocessing:
1392          - id: scale_range
1393            kwargs:
1394              axes: ['y', 'x']
1395              max_percentile: 99.8
1396              min_percentile: 5.0
1397          - id: clip
1398            kwargs:
1399              min: 0.0
1400              max: 1.0
1401
1402        ```
1403        - in Python
1404        >>> preprocessing = [ScaleRangeDescr(
1405        ...   kwargs=ScaleRangeKwargs(
1406        ...       axes= (AxisId('y'), AxisId('x')),
1407        ...       max_percentile= 99.8,
1408        ...       min_percentile= 5.0,
1409        ...   )
1410        ... )]
1411
1412    """
1413
1414    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1415    if TYPE_CHECKING:
1416        id: Literal["scale_range"] = "scale_range"
1417    else:
1418        id: Literal["scale_range"]
1419    kwargs: ScaleRangeKwargs
1420
1421
1422class ScaleMeanVarianceKwargs(ProcessingKwargs):
1423    """keyword arguments for `ScaleMeanVarianceDescr`"""
1424
1425    reference_tensor: TensorId
1426    """Name of tensor to match."""
1427
1428    axes: Annotated[
1429        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1430    ] = None
1431    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1432    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1433    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1434    To normalize samples independently, leave out the 'batch' axis.
1435    Default: Scale all axes jointly."""
1436
1437    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1438    """Epsilon for numeric stability:
1439    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1440
1441
1442class ScaleMeanVarianceDescr(ProcessingDescrBase):
1443    """Scale a tensor's data distribution to match another tensor's mean/std.
1444    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
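
        Example (a sketch; the reference tensor id "raw" is illustrative):

        >>> postprocessing = [ScaleMeanVarianceDescr(
        ...     kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("raw"))
        ... )]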
1445    """
1446
1447    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1448    if TYPE_CHECKING:
1449        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1450    else:
1451        id: Literal["scale_mean_variance"]
1452    kwargs: ScaleMeanVarianceKwargs
1453
1454
1455PreprocessingDescr = Annotated[
1456    Union[
1457        BinarizeDescr,
1458        ClipDescr,
1459        EnsureDtypeDescr,
1460        ScaleLinearDescr,
1461        SigmoidDescr,
1462        FixedZeroMeanUnitVarianceDescr,
1463        ZeroMeanUnitVarianceDescr,
1464        ScaleRangeDescr,
1465    ],
1466    Discriminator("id"),
1467]
1468PostprocessingDescr = Annotated[
1469    Union[
1470        BinarizeDescr,
1471        ClipDescr,
1472        EnsureDtypeDescr,
1473        ScaleLinearDescr,
1474        SigmoidDescr,
1475        FixedZeroMeanUnitVarianceDescr,
1476        ZeroMeanUnitVarianceDescr,
1477        ScaleRangeDescr,
1478        ScaleMeanVarianceDescr,
1479    ],
1480    Discriminator("id"),
1481]
1482
1483IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1484
1485
1486class TensorDescrBase(Node, Generic[IO_AxisT]):
1487    id: TensorId
1488    """Tensor id. No duplicates are allowed."""
1489
1490    description: Annotated[str, MaxLen(128)] = ""
1491    """free text description"""
1492
1493    axes: NotEmpty[Sequence[IO_AxisT]]
1494    """tensor axes"""
1495
1496    @property
1497    def shape(self):
1498        return tuple(a.size for a in self.axes)
1499
1500    @field_validator("axes", mode="after", check_fields=False)
1501    @classmethod
1502    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1503        batch_axes = [a for a in axes if a.type == "batch"]
1504        if len(batch_axes) > 1:
1505            raise ValueError(
1506                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1507            )
1508
1509        seen_ids: Set[AxisId] = set()
1510        duplicate_axes_ids: Set[AxisId] = set()
1511        for a in axes:
1512            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1513
1514        if duplicate_axes_ids:
1515            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1516
1517        return axes
1518
1519    test_tensor: FileDescr_
1520    """An example tensor to use for testing.
1521    Using the model with the test input tensors is expected to yield the test output tensors.
1522    Each test tensor has to be an ndarray in the
1523    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1524    The file extension must be '.npy'."""
1525
1526    sample_tensor: Optional[FileDescr_] = None
1527    """A sample tensor to illustrate a possible input/output for the model.
1528    The sample image primarily serves to inform a human user about an example use case
1529    and is typically stored as .hdf5, .png or .tiff.
1530    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1531    (numpy's `.npy` format is not supported).
1532    The image dimensionality has to match the number of axes specified in this tensor description.
1533    """
1534
1535    @model_validator(mode="after")
1536    def _validate_sample_tensor(self) -> Self:
1537        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1538            return self
1539
1540        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1541        tensor: NDArray[Any] = imread(
1542            reader.read(),
1543            extension=PurePosixPath(reader.original_file_name).suffix,
1544        )
1545        n_dims = len(tensor.squeeze().shape)
1546        n_dims_min = n_dims_max = len(self.axes)
1547
1548        for a in self.axes:
1549            if isinstance(a, BatchAxis):
1550                n_dims_min -= 1
1551            elif isinstance(a.size, int):
1552                if a.size == 1:
1553                    n_dims_min -= 1
1554            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1555                if a.size.min == 1:
1556                    n_dims_min -= 1
1557            elif isinstance(a.size, SizeReference):
1558                if a.size.offset < 2:
1559                    # size reference may result in singleton axis
1560                    n_dims_min -= 1
1561            else:
1562                assert_never(a.size)
1563
1564        n_dims_min = max(0, n_dims_min)
1565        if n_dims < n_dims_min or n_dims > n_dims_max:
1566            raise ValueError(
1567                f"Expected sample tensor to have {n_dims_min} to"
1568                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1569            )
1570
1571        return self
1572
1573    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1574        IntervalOrRatioDataDescr()
1575    )
1576    """Description of the tensor's data values, optionally per channel.
1577    If specified per channel, the data `type` needs to match across channels."""
1578
1579    @property
1580    def dtype(
1581        self,
1582    ) -> Literal[
1583        "float32",
1584        "float64",
1585        "uint8",
1586        "int8",
1587        "uint16",
1588        "int16",
1589        "uint32",
1590        "int32",
1591        "uint64",
1592        "int64",
1593        "bool",
1594    ]:
1595        """dtype as specified under `data.type` or `data[i].type`"""
1596        if isinstance(self.data, collections.abc.Sequence):
1597            return self.data[0].type
1598        else:
1599            return self.data.type
1600
1601    @field_validator("data", mode="after")
1602    @classmethod
1603    def _check_data_type_across_channels(
1604        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1605    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1606        if not isinstance(value, list):
1607            return value
1608
1609        dtypes = {t.type for t in value}
1610        if len(dtypes) > 1:
1611            raise ValueError(
1612                "Tensor data descriptions per channel need to agree in their data"
1613                + f" `type`, but found {dtypes}."
1614            )
1615
1616        return value
1617
1618    @model_validator(mode="after")
1619    def _check_data_matches_channelaxis(self) -> Self:
1620        if not isinstance(self.data, (list, tuple)):
1621            return self
1622
1623        for a in self.axes:
1624            if isinstance(a, ChannelAxis):
1625                size = a.size
1626                assert isinstance(size, int)
1627                break
1628        else:
1629            return self
1630
1631        if len(self.data) != size:
1632            raise ValueError(
1633                f"Got tensor data descriptions for {len(self.data)} channels, but"
1634                + f" '{a.id}' axis has size {size}."
1635            )
1636
1637        return self
1638
1639    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1640        if len(array.shape) != len(self.axes):
1641            raise ValueError(
1642                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1643                + f" incompatible with {len(self.axes)} axes."
1644            )
1645        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1646
1647
1648class InputTensorDescr(TensorDescrBase[InputAxis]):
1649    id: TensorId = TensorId("input")
1650    """Input tensor id.
1651    No duplicates are allowed across all inputs and outputs."""
1652
1653    optional: bool = False
1654    """indicates that this tensor may be `None`"""
1655
1656    preprocessing: List[PreprocessingDescr] = Field(
1657        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1658    )
1659
1660    """Description of how this input should be preprocessed.
1661
1662    notes:
1663    - If preprocessing does not start with an 'ensure_dtype' entry, one is added
1664      to ensure the input tensor's data type matches the input tensor's data description.
1665    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1666      'ensure_dtype' step is added to ensure the preprocessing steps do not unintentionally
1667      change the data type.
1668    """
1669
1670    @model_validator(mode="after")
1671    def _validate_preprocessing_kwargs(self) -> Self:
1672        axes_ids = [a.id for a in self.axes]
1673        for p in self.preprocessing:
1674            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1675            if kwargs_axes is None:
1676                continue
1677
1678            if not isinstance(kwargs_axes, collections.abc.Sequence):
1679                raise ValueError(
1680                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1681                )
1682
1683            if any(a not in axes_ids for a in kwargs_axes):
1684                raise ValueError(
1685                    "`preprocessing.i.kwargs.axes` needs to be a subset of the axes ids"
1686                )
1687
1688        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1689            dtype = self.data.type
1690        else:
1691            dtype = self.data[0].type
1692
1693        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1694        if not self.preprocessing or not isinstance(
1695            self.preprocessing[0], EnsureDtypeDescr
1696        ):
1697            self.preprocessing.insert(
1698                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1699            )
1700
1701        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1702        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1703            self.preprocessing.append(
1704                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1705            )
1706
1707        return self
1708
1709
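# Illustrative sketch (not part of the spec): the validator above guarantees that a
# preprocessing chain starts with an 'ensure_dtype' step and ends with an
# 'ensure_dtype' (or 'binarize') step. The hypothetical helper below mirrors that
# behavior on a plain list of step names, using example names only.
def _example_wrap_with_ensure_dtype(steps: List[str]) -> List[str]:
    wrapped = list(steps)
    if not wrapped or wrapped[0] != "ensure_dtype":
        wrapped.insert(0, "ensure_dtype")  # cast to the described dtype first
    if wrapped[-1] not in ("ensure_dtype", "binarize"):
        wrapped.append("ensure_dtype")  # undo any unintended dtype change
    return wrapped


# e.g. _example_wrap_with_ensure_dtype(["scale_range"])
# -> ["ensure_dtype", "scale_range", "ensure_dtype"]
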
1710def convert_axes(
1711    axes: str,
1712    *,
1713    shape: Union[
1714        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1715    ],
1716    tensor_type: Literal["input", "output"],
1717    halo: Optional[Sequence[int]],
1718    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1719):
1720    ret: List[AnyAxis] = []
1721    for i, a in enumerate(axes):
1722        axis_type = _AXIS_TYPE_MAP.get(a, a)
1723        if axis_type == "batch":
1724            ret.append(BatchAxis())
1725            continue
1726
1727        scale = 1.0
1728        if isinstance(shape, _ParameterizedInputShape_v0_4):
1729            if shape.step[i] == 0:
1730                size = shape.min[i]
1731            else:
1732                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1733        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1734            ref_t = str(shape.reference_tensor)
1735            if ref_t.count(".") == 1:
1736                t_id, orig_a_id = ref_t.split(".")
1737            else:
1738                t_id = ref_t
1739                orig_a_id = a
1740
1741            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1742            if not (orig_scale := shape.scale[i]):
1743                # old way to insert a new axis dimension
1744                size = int(2 * shape.offset[i])
1745            else:
1746                scale = 1 / orig_scale
1747                if axis_type in ("channel", "index"):
1748                    # these axes no longer have a scale
1749                    offset_from_scale = orig_scale * size_refs.get(
1750                        _TensorName_v0_4(t_id), {}
1751                    ).get(orig_a_id, 0)
1752                else:
1753                    offset_from_scale = 0
1754                size = SizeReference(
1755                    tensor_id=TensorId(t_id),
1756                    axis_id=AxisId(a_id),
1757                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1758                )
1759        else:
1760            size = shape[i]
1761
1762        if axis_type == "time":
1763            if tensor_type == "input":
1764                ret.append(TimeInputAxis(size=size, scale=scale))
1765            else:
1766                assert not isinstance(size, ParameterizedSize)
1767                if halo is None:
1768                    ret.append(TimeOutputAxis(size=size, scale=scale))
1769                else:
1770                    assert not isinstance(size, int)
1771                    ret.append(
1772                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1773                    )
1774
1775        elif axis_type == "index":
1776            if tensor_type == "input":
1777                ret.append(IndexInputAxis(size=size))
1778            else:
1779                if isinstance(size, ParameterizedSize):
1780                    size = DataDependentSize(min=size.min)
1781
1782                ret.append(IndexOutputAxis(size=size))
1783        elif axis_type == "channel":
1784            assert not isinstance(size, ParameterizedSize)
1785            if isinstance(size, SizeReference):
1786                warnings.warn(
1787                    "Conversion of channel size from an implicit output shape may be"
1788                    + " wrong"
1789                )
1790                ret.append(
1791                    ChannelAxis(
1792                        channel_names=[
1793                            Identifier(f"channel{i}") for i in range(size.offset)
1794                        ]
1795                    )
1796                )
1797            else:
1798                ret.append(
1799                    ChannelAxis(
1800                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1801                    )
1802                )
1803        elif axis_type == "space":
1804            if tensor_type == "input":
1805                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1806            else:
1807                assert not isinstance(size, ParameterizedSize)
1808                if halo is None or halo[i] == 0:
1809                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1810                elif isinstance(size, int):
1811                    raise NotImplementedError(
1812                        f"output axis with halo and fixed size (here {size}) not allowed"
1813                    )
1814                else:
1815                    ret.append(
1816                        SpaceOutputAxisWithHalo(
1817                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1818                        )
1819                    )
1820
1821    return ret
1822
1823
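# Illustrative sketch (assumption: `_AXIS_TYPE_MAP` maps the 0.4 axis letters
# "b", "c", "y" and "x" to "batch", "channel", "space" and "space"): converting a
# fixed-shape 0.4 input with axes "bcyx" yields a batch axis, a channel axis with
# generated channel names and two spatial input axes.
def _example_convert_fixed_input_axes() -> List[AnyAxis]:
    return convert_axes(
        "bcyx",
        shape=[1, 3, 512, 512],
        tensor_type="input",
        halo=None,
        size_refs={},
    )
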
1824def _axes_letters_to_ids(
1825    axes: Optional[str],
1826) -> Optional[List[AxisId]]:
1827    if axes is None:
1828        return None
1829
1830    return [AxisId(a) for a in axes]
1831
1832
1833def _get_complement_v04_axis(
1834    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1835) -> Optional[AxisId]:
1836    if axes is None:
1837        return None
1838
1839    non_complement_axes = set(axes) | {"b"}
1840    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1841    if len(complement_axes) > 1:
1842        raise ValueError(
1843            f"Expected none or a single complement axis, but axes '{axes}' "
1844            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1845        )
1846
1847    return None if not complement_axes else AxisId(complement_axes[0])
1848
1849
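# Illustrative sketch: for a 0.4 tensor with axes "bcyx" and a processing step
# restricted to the spatial axes, the single remaining non-batch axis "c" is the
# complement axis; if no axis remains, None is returned.
def _example_complement_axis() -> Optional[AxisId]:
    return _get_complement_v04_axis("bcyx", ["x", "y"])  # -> AxisId("c")
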
1850def _convert_proc(
1851    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1852    tensor_axes: Sequence[str],
1853) -> Union[PreprocessingDescr, PostprocessingDescr]:
1854    if isinstance(p, _BinarizeDescr_v0_4):
1855        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1856    elif isinstance(p, _ClipDescr_v0_4):
1857        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1858    elif isinstance(p, _SigmoidDescr_v0_4):
1859        return SigmoidDescr()
1860    elif isinstance(p, _ScaleLinearDescr_v0_4):
1861        axes = _axes_letters_to_ids(p.kwargs.axes)
1862        if p.kwargs.axes is None:
1863            axis = None
1864        else:
1865            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1866
1867        if axis is None:
1868            assert not isinstance(p.kwargs.gain, list)
1869            assert not isinstance(p.kwargs.offset, list)
1870            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1871        else:
1872            kwargs = ScaleLinearAlongAxisKwargs(
1873                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1874            )
1875        return ScaleLinearDescr(kwargs=kwargs)
1876    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1877        return ScaleMeanVarianceDescr(
1878            kwargs=ScaleMeanVarianceKwargs(
1879                axes=_axes_letters_to_ids(p.kwargs.axes),
1880                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1881                eps=p.kwargs.eps,
1882            )
1883        )
1884    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1885        if p.kwargs.mode == "fixed":
1886            mean = p.kwargs.mean
1887            std = p.kwargs.std
1888            assert mean is not None
1889            assert std is not None
1890
1891            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1892
1893            if axis is None:
1894                return FixedZeroMeanUnitVarianceDescr(
1895                    kwargs=FixedZeroMeanUnitVarianceKwargs(
1896                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
1897                    )
1898                )
1899            else:
1900                if not isinstance(mean, list):
1901                    mean = [float(mean)]
1902                if not isinstance(std, list):
1903                    std = [float(std)]
1904
1905                return FixedZeroMeanUnitVarianceDescr(
1906                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1907                        axis=axis, mean=mean, std=std
1908                    )
1909                )
1910
1911        else:
1912            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1913            if p.kwargs.mode == "per_dataset":
1914                axes = [AxisId("batch")] + axes
1915            if not axes:
1916                axes = None
1917            return ZeroMeanUnitVarianceDescr(
1918                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1919            )
1920
1921    elif isinstance(p, _ScaleRangeDescr_v0_4):
1922        return ScaleRangeDescr(
1923            kwargs=ScaleRangeKwargs(
1924                axes=_axes_letters_to_ids(p.kwargs.axes),
1925                min_percentile=p.kwargs.min_percentile,
1926                max_percentile=p.kwargs.max_percentile,
1927                eps=p.kwargs.eps,
1928            )
1929        )
1930    else:
1931        assert_never(p)
1932
1933
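# Illustrative sketch (a simplified mirror of the zero_mean_unit_variance handling
# above, without the 0.4 descriptor classes): a 0.4 "per_dataset" mode becomes a
# 0.5 description that additionally normalizes over the batch axis.
def _example_zmuv_axes(mode: str, axes_letters: str) -> Optional[List[AxisId]]:
    axes = [AxisId(a) for a in axes_letters]
    if mode == "per_dataset":
        axes = [AxisId("batch")] + axes
    return axes or None


# e.g. _example_zmuv_axes("per_dataset", "xy")
# -> [AxisId("batch"), AxisId("x"), AxisId("y")]
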
1934class _InputTensorConv(
1935    Converter[
1936        _InputTensorDescr_v0_4,
1937        InputTensorDescr,
1938        FileSource_,
1939        Optional[FileSource_],
1940        Mapping[_TensorName_v0_4, Mapping[str, int]],
1941    ]
1942):
1943    def _convert(
1944        self,
1945        src: _InputTensorDescr_v0_4,
1946        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1947        test_tensor: FileSource_,
1948        sample_tensor: Optional[FileSource_],
1949        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1950    ) -> "InputTensorDescr | dict[str, Any]":
1951        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1952            src.axes,
1953            shape=src.shape,
1954            tensor_type="input",
1955            halo=None,
1956            size_refs=size_refs,
1957        )
1958        prep: List[PreprocessingDescr] = []
1959        for p in src.preprocessing:
1960            cp = _convert_proc(p, src.axes)
1961            assert not isinstance(cp, ScaleMeanVarianceDescr)
1962            prep.append(cp)
1963
1964        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
1965
1966        return tgt(
1967            axes=axes,
1968            id=TensorId(str(src.name)),
1969            test_tensor=FileDescr(source=test_tensor),
1970            sample_tensor=(
1971                None if sample_tensor is None else FileDescr(source=sample_tensor)
1972            ),
1973            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
1974            preprocessing=prep,
1975        )
1976
1977
1978_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1979
1980
1981class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1982    id: TensorId = TensorId("output")
1983    """Output tensor id.
1984    No duplicates are allowed across all inputs and outputs."""
1985
1986    postprocessing: List[PostprocessingDescr] = Field(
1987        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
1988    )
1989    """Description of how this output should be postprocessed.
1990
1991    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1992          If not given this is added to cast to this tensor's `data.type`.
1993    """
1994
1995    @model_validator(mode="after")
1996    def _validate_postprocessing_kwargs(self) -> Self:
1997        axes_ids = [a.id for a in self.axes]
1998        for p in self.postprocessing:
1999            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2000            if kwargs_axes is None:
2001                continue
2002
2003            if not isinstance(kwargs_axes, collections.abc.Sequence):
2004                raise ValueError(
2005                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2006                )
2007
2008            if any(a not in axes_ids for a in kwargs_axes):
2009                raise ValueError("`kwargs.axes` needs to be a subset of the axes ids")
2010
2011        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2012            dtype = self.data.type
2013        else:
2014            dtype = self.data[0].type
2015
2016        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2017        if not self.postprocessing or not isinstance(
2018            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2019        ):
2020            self.postprocessing.append(
2021                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2022            )
2023        return self
2024
2025
2026class _OutputTensorConv(
2027    Converter[
2028        _OutputTensorDescr_v0_4,
2029        OutputTensorDescr,
2030        FileSource_,
2031        Optional[FileSource_],
2032        Mapping[_TensorName_v0_4, Mapping[str, int]],
2033    ]
2034):
2035    def _convert(
2036        self,
2037        src: _OutputTensorDescr_v0_4,
2038        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2039        test_tensor: FileSource_,
2040        sample_tensor: Optional[FileSource_],
2041        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2042    ) -> "OutputTensorDescr | dict[str, Any]":
2043        # TODO: split convert_axes into convert_output_axes and convert_input_axes
2044        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
2045            src.axes,
2046            shape=src.shape,
2047            tensor_type="output",
2048            halo=src.halo,
2049            size_refs=size_refs,
2050        )
2051        data_descr: Dict[str, Any] = dict(type=src.data_type)
2052        if data_descr["type"] == "bool":
2053            data_descr["values"] = [False, True]
2054
2055        return tgt(
2056            axes=axes,
2057            id=TensorId(str(src.name)),
2058            test_tensor=FileDescr(source=test_tensor),
2059            sample_tensor=(
2060                None if sample_tensor is None else FileDescr(source=sample_tensor)
2061            ),
2062            data=data_descr,  # pyright: ignore[reportArgumentType]
2063            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2064        )
2065
2066
2067_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2068
2069
2070TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2071
2072
2073def validate_tensors(
2074    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2075    tensor_origin: Literal[
2076        "test_tensor"
2077    ],  # for more precise error messages, e.g. 'test_tensor'
2078):
2079    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2080
2081    def e_msg(d: TensorDescr):
2082        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2083
2084    for descr, array in tensors.values():
2085        try:
2086            axis_sizes = descr.get_axis_sizes_for_array(array)
2087        except ValueError as e:
2088            raise ValueError(f"{e_msg(descr)} {e}")
2089        else:
2090            all_tensor_axes[descr.id] = {
2091                a.id: (a, axis_sizes[a.id]) for a in descr.axes
2092            }
2093
2094    for descr, array in tensors.values():
2095        if descr.dtype in ("float32", "float64"):
2096            invalid_test_tensor_dtype = array.dtype.name not in (
2097                "float32",
2098                "float64",
2099                "uint8",
2100                "int8",
2101                "uint16",
2102                "int16",
2103                "uint32",
2104                "int32",
2105                "uint64",
2106                "int64",
2107            )
2108        else:
2109            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2110
2111        if invalid_test_tensor_dtype:
2112            raise ValueError(
2113                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2114                + f" match described dtype '{descr.dtype}'"
2115            )
2116
2117        if array.min() > -1e-4 and array.max() < 1e-4:
2118            raise ValueError(
2119                "Tensor values are too small for reliable testing."
2120                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2121            )
2122
2123        for a in descr.axes:
2124            actual_size = all_tensor_axes[descr.id][a.id][1]
2125            if a.size is None:
2126                continue
2127
2128            if isinstance(a.size, int):
2129                if actual_size != a.size:
2130                    raise ValueError(
2131                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2132                        + f"has incompatible size {actual_size}, expected {a.size}"
2133                    )
2134            elif isinstance(a.size, ParameterizedSize):
2135                _ = a.size.validate_size(actual_size)
2136            elif isinstance(a.size, DataDependentSize):
2137                _ = a.size.validate_size(actual_size)
2138            elif isinstance(a.size, SizeReference):
2139                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2140                if ref_tensor_axes is None:
2141                    raise ValueError(
2142                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2143                        + f" reference '{a.size.tensor_id}'"
2144                    )
2145
2146                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2147                if ref_axis is None or ref_size is None:
2148                    raise ValueError(
2149                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2150                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2151                    )
2152
2153                if a.unit != ref_axis.unit:
2154                    raise ValueError(
2155                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2156                        + " axis and reference axis to have the same `unit`, but"
2157                        + f" {a.unit}!={ref_axis.unit}"
2158                    )
2159
2160                if actual_size != (
2161                    expected_size := (
2162                        ref_size * ref_axis.scale / a.scale + a.size.offset
2163                    )
2164                ):
2165                    raise ValueError(
2166                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2167                        + f" {actual_size} invalid for referenced size {ref_size};"
2168                        + f" expected {expected_size}"
2169                    )
2170            else:
2171                assert_never(a.size)
2172
2173
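# Illustrative sketch (a standalone mirror of the dtype check in `validate_tensors`
# above): a tensor described as float32/float64 accepts any (u)int or float test
# tensor dtype; any other described dtype requires an exact match.
def _example_dtype_is_compatible(described: str, actual: str) -> bool:
    if described in ("float32", "float64"):
        return actual in (
            "float32", "float64",
            "uint8", "int8", "uint16", "int16",
            "uint32", "int32", "uint64", "int64",
        )
    return actual == described


# e.g. _example_dtype_is_compatible("float32", "uint8") -> True
#      _example_dtype_is_compatible("uint8", "float32") -> False
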
2174FileDescr_dependencies = Annotated[
2175    FileDescr_,
2176    WithSuffix((".yaml", ".yml"), case_sensitive=True),
2177    Field(examples=[dict(source="environment.yaml")]),
2178]
2179
2180
2181class _ArchitectureCallableDescr(Node):
2182    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2183    """Identifier of the callable that returns a torch.nn.Module instance."""
2184
2185    kwargs: Dict[str, YamlValue] = Field(
2186        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2187    )
2188    """keyword arguments for the `callable`"""
2189
2190
2191class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2192    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2193    """Architecture source file"""
2194
2195    @model_serializer(mode="wrap", when_used="unless-none")
2196    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2197        return package_file_descr_serializer(self, nxt, info)
2198
2199
2200class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2201    import_from: str
2202    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2203
2204
2205class _ArchFileConv(
2206    Converter[
2207        _CallableFromFile_v0_4,
2208        ArchitectureFromFileDescr,
2209        Optional[Sha256],
2210        Dict[str, Any],
2211    ]
2212):
2213    def _convert(
2214        self,
2215        src: _CallableFromFile_v0_4,
2216        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2217        sha256: Optional[Sha256],
2218        kwargs: Dict[str, Any],
2219    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2220        if src.startswith("http") and src.count(":") == 2:
2221            http, source, callable_ = src.split(":")
2222            source = ":".join((http, source))
2223        elif not src.startswith("http") and src.count(":") == 1:
2224            source, callable_ = src.split(":")
2225        else:
2226            source = str(src)
2227            callable_ = str(src)
2228        return tgt(
2229            callable=Identifier(callable_),
2230            source=cast(FileSource_, source),
2231            sha256=sha256,
2232            kwargs=kwargs,
2233        )
2234
2235
2236_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2237
2238
2239class _ArchLibConv(
2240    Converter[
2241        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2242    ]
2243):
2244    def _convert(
2245        self,
2246        src: _CallableFromDepencency_v0_4,
2247        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2248        kwargs: Dict[str, Any],
2249    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2250        *mods, callable_ = src.split(".")
2251        import_from = ".".join(mods)
2252        return tgt(
2253            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2254        )
2255
2256
2257_arch_lib_conv = _ArchLibConv(
2258    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2259)
2260
2261
2262class WeightsEntryDescrBase(FileDescr):
2263    type: ClassVar[WeightsFormat]
2264    weights_format_name: ClassVar[str]  # human readable
2265
2266    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2267    """Source of the weights file."""
2268
2269    authors: Optional[List[Author]] = None
2270    """Authors
2271    Either the person(s) that have trained this model resulting in the original weights file.
2272        (If this is the initial weights entry, i.e. it does not have a `parent`)
2273    Or the person(s) who have converted the weights to this weights format.
2274        (If this is a child weight, i.e. it has a `parent` field)
2275    """
2276
2277    parent: Annotated[
2278        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2279    ] = None
2280    """The source weights these weights were converted from.
2281    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2282    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2283    All weights entries except one (the initial set of weights resulting from training the model)
2284    need to have this field."""
2285
2286    comment: str = ""
2287    """A comment about this weights entry, for example how these weights were created."""
2288
2289    @model_validator(mode="after")
2290    def _validate(self) -> Self:
2291        if self.type == self.parent:
2292            raise ValueError("Weights entry can't be its own parent.")
2293
2294        return self
2295
2296    @model_serializer(mode="wrap", when_used="unless-none")
2297    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2298        return package_file_descr_serializer(self, nxt, info)
2299
2300
2301class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2302    type = "keras_hdf5"
2303    weights_format_name: ClassVar[str] = "Keras HDF5"
2304    tensorflow_version: Version
2305    """TensorFlow version used to create these weights."""
2306
2307
2308class OnnxWeightsDescr(WeightsEntryDescrBase):
2309    type = "onnx"
2310    weights_format_name: ClassVar[str] = "ONNX"
2311    opset_version: Annotated[int, Ge(7)]
2312    """ONNX opset version"""
2313
2314
2315class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2316    type = "pytorch_state_dict"
2317    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2318    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2319    pytorch_version: Version
2320    """Version of the PyTorch library used.
2321    If `dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
2322    """
2323    dependencies: Optional[FileDescr_dependencies] = None
2324    """Custom dependencies beyond pytorch, described in a Conda environment file.
2325    Allows specifying custom dependencies; see the conda docs:
2326    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2327    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2328
2329    The conda environment file should include pytorch, and any version pinning has to be compatible with
2330    **pytorch_version**.
2331    """
2332
2333
2334class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2335    type = "tensorflow_js"
2336    weights_format_name: ClassVar[str] = "Tensorflow.js"
2337    tensorflow_version: Version
2338    """Version of the TensorFlow library used."""
2339
2340    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2341    """The multi-file weights.
2342    All required files/folders should be bundled in a zip archive."""
2343
2344
2345class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2346    type = "tensorflow_saved_model_bundle"
2347    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2348    tensorflow_version: Version
2349    """Version of the TensorFlow library used."""
2350
2351    dependencies: Optional[FileDescr_dependencies] = None
2352    """Custom dependencies beyond tensorflow.
2353    Should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**."""
2354
2355    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2356    """The multi-file weights.
2357    All required files/folders should be bundled in a zip archive."""
2358
2359
2360class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2361    type = "torchscript"
2362    weights_format_name: ClassVar[str] = "TorchScript"
2363    pytorch_version: Version
2364    """Version of the PyTorch library used."""
2365
2366
2367class WeightsDescr(Node):
2368    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2369    onnx: Optional[OnnxWeightsDescr] = None
2370    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2371    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2372    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2373        None
2374    )
2375    torchscript: Optional[TorchscriptWeightsDescr] = None
2376
2377    @model_validator(mode="after")
2378    def check_entries(self) -> Self:
2379        entries = {wtype for wtype, entry in self if entry is not None}
2380
2381        if not entries:
2382            raise ValueError("Missing weights entry")
2383
2384        entries_wo_parent = {
2385            wtype
2386            for wtype, entry in self
2387            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2388        }
2389        if len(entries_wo_parent) != 1:
2390            issue_warning(
2391                "Exactly one weights entry is expected to not specify the `parent` field (got"
2392                + " {value}). That entry is considered the original set of model weights."
2393                + " Other weight formats are created through conversion of the original or"
2394                + " already converted weights. They have to reference the weights format"
2395                + " they were converted from as their `parent`.",
2396                value=len(entries_wo_parent),
2397                field="weights",
2398            )
2399
2400        for wtype, entry in self:
2401            if entry is None:
2402                continue
2403
2404            assert hasattr(entry, "type")
2405            assert hasattr(entry, "parent")
2406            assert wtype == entry.type
2407            if (
2408                entry.parent is not None and entry.parent not in entries
2409            ):  # self reference checked for `parent` field
2410                raise ValueError(
2411                    f"`weights.{wtype}.parent={entry.parent}` not in specified weight"
2412                    + f" formats: {entries}"
2413                )
2414
2415        return self
2416
2417    def __getitem__(
2418        self,
2419        key: Literal[
2420            "keras_hdf5",
2421            "onnx",
2422            "pytorch_state_dict",
2423            "tensorflow_js",
2424            "tensorflow_saved_model_bundle",
2425            "torchscript",
2426        ],
2427    ):
2428        if key == "keras_hdf5":
2429            ret = self.keras_hdf5
2430        elif key == "onnx":
2431            ret = self.onnx
2432        elif key == "pytorch_state_dict":
2433            ret = self.pytorch_state_dict
2434        elif key == "tensorflow_js":
2435            ret = self.tensorflow_js
2436        elif key == "tensorflow_saved_model_bundle":
2437            ret = self.tensorflow_saved_model_bundle
2438        elif key == "torchscript":
2439            ret = self.torchscript
2440        else:
2441            raise KeyError(key)
2442
2443        if ret is None:
2444            raise KeyError(key)
2445
2446        return ret
2447
2448    @property
2449    def available_formats(self):
2450        return {
2451            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2452            **({} if self.onnx is None else {"onnx": self.onnx}),
2453            **(
2454                {}
2455                if self.pytorch_state_dict is None
2456                else {"pytorch_state_dict": self.pytorch_state_dict}
2457            ),
2458            **(
2459                {}
2460                if self.tensorflow_js is None
2461                else {"tensorflow_js": self.tensorflow_js}
2462            ),
2463            **(
2464                {}
2465                if self.tensorflow_saved_model_bundle is None
2466                else {
2467                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2468                }
2469            ),
2470            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2471        }
2472
2473    @property
2474    def missing_formats(self):
2475        return {
2476            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2477        }
2478
2479
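# Illustrative sketch (a standalone mirror of `check_entries` above, using a plain
# mapping of weights format to parent format): exactly one entry, the original
# training result, has no parent; every other entry must reference an existing
# entry it was converted from.
def _example_check_weights_parents(parents: Dict[str, Optional[str]]) -> None:
    originals = [wf for wf, parent in parents.items() if parent is None]
    if len(originals) != 1:
        raise ValueError(f"expected exactly one entry without `parent`, got {originals}")
    for wf, parent in parents.items():
        if parent is not None and parent not in parents:
            raise ValueError(f"{wf}.parent={parent} not among {set(parents)}")


# e.g. _example_check_weights_parents(
#     {"pytorch_state_dict": None, "torchscript": "pytorch_state_dict"}
# ) passes silently.
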
2480class ModelId(ResourceId):
2481    pass
2482
2483
2484class LinkedModel(LinkedResourceBase):
2485    """Reference to a bioimage.io model."""
2486
2487    id: ModelId
2488    """A valid model `id` from the bioimage.io collection."""
2489
2490
2491class _DataDepSize(NamedTuple):
2492    min: StrictInt
2493    max: Optional[StrictInt]
2494
2495
2496class _AxisSizes(NamedTuple):
2497    """the lengths of all axes of model inputs and outputs"""
2498
2499    inputs: Dict[Tuple[TensorId, AxisId], int]
2500    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2501
2502
2503class _TensorSizes(NamedTuple):
2504    """_AxisSizes as nested dicts"""
2505
2506    inputs: Dict[TensorId, Dict[AxisId, int]]
2507    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2508
2509
2510class ReproducibilityTolerance(Node, extra="allow"):
2511    """Describes what small numerical differences -- if any -- may be tolerated
2512    in the generated output when executing in different environments.
2513
2514    A tensor element *output* is considered mismatched to the **test_tensor** if
2515    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2516    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2517
2518    Motivation:
2519        For testing we can request the respective deep learning frameworks to be as
2520        reproducible as possible by setting seeds and choosing deterministic algorithms,
2521        but differences in operating systems, available hardware and installed drivers
2522        may still lead to numerical differences.
2523    """
2524
2525    relative_tolerance: RelativeTolerance = 1e-3
2526    """Maximum relative tolerance of reproduced test tensor."""
2527
2528    absolute_tolerance: AbsoluteTolerance = 1e-4
2529    """Maximum absolute tolerance of reproduced test tensor."""
2530
2531    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2532    """Maximum number of mismatched elements/pixels per million to tolerate."""
2533
2534    output_ids: Sequence[TensorId] = ()
2535    """Limits the output tensor IDs these reproducibility details apply to."""
2536
2537    weights_formats: Sequence[WeightsFormat] = ()
2538    """Limits the weights formats these details apply to."""
2539
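# Illustrative sketch (standalone, numpy only): the mismatch criterion described in
# the docstring above, reported as mismatched elements per million.
def _example_mismatched_elements_per_million(
    output: NDArray[Any],
    test_tensor: NDArray[Any],
    absolute_tolerance: float = 1e-4,
    relative_tolerance: float = 1e-3,
) -> float:
    out = output.astype(np.float64)
    ref = test_tensor.astype(np.float64)
    mismatch = np.abs(out - ref) > absolute_tolerance + relative_tolerance * np.abs(ref)
    return float(mismatch.sum()) / mismatch.size * 1e6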
2540
2541class BioimageioConfig(Node, extra="allow"):
2542    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2543    """Tolerances to allow when reproducing the model's test outputs
2544    from the model's test inputs.
2545    Only the first entry matching tensor id and weights format is considered.
2546    """
2547
2548
2549class Config(Node, extra="allow"):
2550    bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2551
2552
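# Illustrative sketch (the values are examples, not recommendations): a packager may
# relax the reproducibility checks for all outputs and weights formats by adding an
# entry under `config.bioimageio.reproducibility_tolerance`.
def _example_relaxed_reproducibility_config() -> Config:
    return Config(
        bioimageio=BioimageioConfig(
            reproducibility_tolerance=[
                ReproducibilityTolerance(mismatched_elements_per_million=500)
            ]
        )
    )
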
2553class ModelDescr(GenericModelDescrBase):
2554    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2555    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2556    """
2557
2558    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2559    if TYPE_CHECKING:
2560        format_version: Literal["0.5.4"] = "0.5.4"
2561    else:
2562        format_version: Literal["0.5.4"]
2563        """Version of the bioimage.io model description specification used.
2564        When creating a new model always use the latest micro/patch version described here.
2565        The `format_version` is important for any consumer software to understand how to parse the fields.
2566        """
2567
2568    implemented_type: ClassVar[Literal["model"]] = "model"
2569    if TYPE_CHECKING:
2570        type: Literal["model"] = "model"
2571    else:
2572        type: Literal["model"]
2573        """Specialized resource type 'model'"""
2574
2575    id: Optional[ModelId] = None
2576    """bioimage.io-wide unique resource identifier
2577    assigned by bioimage.io; version **un**specific."""
2578
2579    authors: NotEmpty[List[Author]]
2580    """The authors are the creators of the model RDF and the primary points of contact."""
2581
2582    documentation: FileSource_documentation
2583    """URL or relative path to a markdown file with additional documentation.
2584    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2585    The documentation should include a '#[#] Validation' (sub)section
2586    with details on how to quantitatively validate the model on unseen data."""
2587
2588    @field_validator("documentation", mode="after")
2589    @classmethod
2590    def _validate_documentation(
2591        cls, value: FileSource_documentation
2592    ) -> FileSource_documentation:
2593        if not get_validation_context().perform_io_checks:
2594            return value
2595
2596        doc_reader = get_reader(value)
2597        doc_content = doc_reader.read().decode(encoding="utf-8")
2598        if not re.search("#.*[vV]alidation", doc_content):
2599            issue_warning(
2600                "No '# Validation' (sub)section found in {value}.",
2601                value=value,
2602                field="documentation",
2603            )
2604
2605        return value
2606
2607    inputs: NotEmpty[Sequence[InputTensorDescr]]
2608    """Describes the input tensors expected by this model."""
2609
2610    @field_validator("inputs", mode="after")
2611    @classmethod
2612    def _validate_input_axes(
2613        cls, inputs: Sequence[InputTensorDescr]
2614    ) -> Sequence[InputTensorDescr]:
2615        input_size_refs = cls._get_axes_with_independent_size(inputs)
2616
2617        for i, ipt in enumerate(inputs):
2618            valid_independent_refs: Dict[
2619                Tuple[TensorId, AxisId],
2620                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2621            ] = {
2622                **{
2623                    (ipt.id, a.id): (ipt, a, a.size)
2624                    for a in ipt.axes
2625                    if not isinstance(a, BatchAxis)
2626                    and isinstance(a.size, (int, ParameterizedSize))
2627                },
2628                **input_size_refs,
2629            }
2630            for a, ax in enumerate(ipt.axes):
2631                cls._validate_axis(
2632                    "inputs",
2633                    i=i,
2634                    tensor_id=ipt.id,
2635                    a=a,
2636                    axis=ax,
2637                    valid_independent_refs=valid_independent_refs,
2638                )
2639        return inputs
2640
2641    @staticmethod
2642    def _validate_axis(
2643        field_name: str,
2644        i: int,
2645        tensor_id: TensorId,
2646        a: int,
2647        axis: AnyAxis,
2648        valid_independent_refs: Dict[
2649            Tuple[TensorId, AxisId],
2650            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2651        ],
2652    ):
2653        if isinstance(axis, BatchAxis) or isinstance(
2654            axis.size, (int, ParameterizedSize, DataDependentSize)
2655        ):
2656            return
2657        elif not isinstance(axis.size, SizeReference):
2658            assert_never(axis.size)
2659
2660        # validate axis.size SizeReference
2661        ref = (axis.size.tensor_id, axis.size.axis_id)
2662        if ref not in valid_independent_refs:
2663            raise ValueError(
2664                "Invalid tensor axis reference at"
2665                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2666            )
2667        if ref == (tensor_id, axis.id):
2668            raise ValueError(
2669                "Self-referencing not allowed for"
2670                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2671            )
2672        if axis.type == "channel":
2673            if valid_independent_refs[ref][1].type != "channel":
2674                raise ValueError(
2675                    "A channel axis' size may only reference another fixed size"
2676                    + " channel axis."
2677                )
2678            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2679                ref_size = valid_independent_refs[ref][2]
2680                assert isinstance(ref_size, int), (
2681                    "channel axis ref (another channel axis) has to specify fixed"
2682                    + " size"
2683                )
2684                generated_channel_names = [
2685                    Identifier(axis.channel_names.format(i=i))
2686                    for i in range(1, ref_size + 1)
2687                ]
2688                axis.channel_names = generated_channel_names
2689
2690        if (ax_unit := getattr(axis, "unit", None)) != (
2691            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2692        ):
2693            raise ValueError(
2694                "The units of an axis and its reference axis need to match, but"
2695                + f" '{ax_unit}' != '{ref_unit}'."
2696            )
2697        ref_axis = valid_independent_refs[ref][1]
2698        if isinstance(ref_axis, BatchAxis):
2699            raise ValueError(
2700                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2701                + " (a batch axis is not allowed as reference)."
2702            )
2703
2704        if isinstance(axis, WithHalo):
2705            min_size = axis.size.get_size(axis, ref_axis, n=0)
2706            if (min_size - 2 * axis.halo) < 1:
2707                raise ValueError(
2708                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2709                    + f" {axis.halo}."
2710                )
2711
2712            input_halo = axis.halo * axis.scale / ref_axis.scale
2713            if input_halo != int(input_halo) or input_halo % 2 == 1:
2714                raise ValueError(
2715                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2716                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2717                    + f" is not an even integer for {tensor_id}.{axis.id}."
2718                )
2719
2720    @model_validator(mode="after")
2721    def _validate_test_tensors(self) -> Self:
2722        if not get_validation_context().perform_io_checks:
2723            return self
2724
2725        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2726        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2727
2728        tensors = {
2729            descr.id: (descr, array)
2730            for descr, array in zip(
2731                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2732            )
2733        }
2734        validate_tensors(tensors, tensor_origin="test_tensor")
2735
2736        output_arrays = {
2737            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2738        }
2739        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2740            if not rep_tol.absolute_tolerance:
2741                continue
2742
2743            if rep_tol.output_ids:
2744                out_arrays = {
2745                    oid: a
2746                    for oid, a in output_arrays.items()
2747                    if oid in rep_tol.output_ids
2748                }
2749            else:
2750                out_arrays = output_arrays
2751
2752            for out_id, array in out_arrays.items():
2753                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2754                    raise ValueError(
2755                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2756                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2757                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2758                    )
2759
2760        return self
2761
2762    @model_validator(mode="after")
2763    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2764        ipt_refs = {t.id for t in self.inputs}
2765        out_refs = {t.id for t in self.outputs}
2766        for ipt in self.inputs:
2767            for p in ipt.preprocessing:
2768                ref = p.kwargs.get("reference_tensor")
2769                if ref is None:
2770                    continue
2771                if ref not in ipt_refs:
2772                    raise ValueError(
2773                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2774                        + f" references are: {ipt_refs}."
2775                    )
2776
2777        for out in self.outputs:
2778            for p in out.postprocessing:
2779                ref = p.kwargs.get("reference_tensor")
2780                if ref is None:
2781                    continue
2782
2783                if ref not in ipt_refs and ref not in out_refs:
2784                    raise ValueError(
2785                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2786                        + f" are: {ipt_refs | out_refs}."
2787                    )
2788
2789        return self
2790
2791    # TODO: use validate funcs in validate_test_tensors
2792    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2793
2794    name: Annotated[
2795        Annotated[
2796            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2797        ],
2798        MinLen(5),
2799        MaxLen(128),
2800        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2801    ]
2802    """A human-readable name of this model.
2803    It should be no longer than 64 characters
2804    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2805    We recommend choosing a name that refers to the model's task and image modality.
2806    """
2807
2808    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2809    """Describes the output tensors."""
2810
2811    @field_validator("outputs", mode="after")
2812    @classmethod
2813    def _validate_tensor_ids(
2814        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2815    ) -> Sequence[OutputTensorDescr]:
2816        tensor_ids = [
2817            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2818        ]
2819        duplicate_tensor_ids: List[str] = []
2820        seen: Set[str] = set()
2821        for t in tensor_ids:
2822            if t in seen:
2823                duplicate_tensor_ids.append(t)
2824
2825            seen.add(t)
2826
2827        if duplicate_tensor_ids:
2828            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2829
2830        return outputs
2831
2832    @staticmethod
2833    def _get_axes_with_parameterized_size(
2834        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2835    ):
2836        return {
2837            f"{t.id}.{a.id}": (t, a, a.size)
2838            for t in io
2839            for a in t.axes
2840            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2841        }
2842
2843    @staticmethod
2844    def _get_axes_with_independent_size(
2845        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2846    ):
2847        return {
2848            (t.id, a.id): (t, a, a.size)
2849            for t in io
2850            for a in t.axes
2851            if not isinstance(a, BatchAxis)
2852            and isinstance(a.size, (int, ParameterizedSize))
2853        }
2854
2855    @field_validator("outputs", mode="after")
2856    @classmethod
2857    def _validate_output_axes(
2858        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2859    ) -> List[OutputTensorDescr]:
2860        input_size_refs = cls._get_axes_with_independent_size(
2861            info.data.get("inputs", [])
2862        )
2863        output_size_refs = cls._get_axes_with_independent_size(outputs)
2864
2865        for i, out in enumerate(outputs):
2866            valid_independent_refs: Dict[
2867                Tuple[TensorId, AxisId],
2868                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2869            ] = {
2870                **{
2871                    (out.id, a.id): (out, a, a.size)
2872                    for a in out.axes
2873                    if not isinstance(a, BatchAxis)
2874                    and isinstance(a.size, (int, ParameterizedSize))
2875                },
2876                **input_size_refs,
2877                **output_size_refs,
2878            }
2879            for a, ax in enumerate(out.axes):
2880                cls._validate_axis(
2881                    "outputs",
2882                    i,
2883                    out.id,
2884                    a,
2885                    ax,
2886                    valid_independent_refs=valid_independent_refs,
2887                )
2888
2889        return outputs
2890
2891    packaged_by: List[Author] = Field(
2892        default_factory=cast(Callable[[], List[Author]], list)
2893    )
2894    """The persons that have packaged and uploaded this model.
2895    Only required if those persons differ from the `authors`."""
2896
2897    parent: Optional[LinkedModel] = None
2898    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2899
2900    @model_validator(mode="after")
2901    def _validate_parent_is_not_self(self) -> Self:
2902        if self.parent is not None and self.parent.id == self.id:
2903            raise ValueError("A model description may not reference itself as parent.")
2904
2905        return self
2906
2907    run_mode: Annotated[
2908        Optional[RunMode],
2909        warn(None, "Run mode '{value}' has limited support across consumer software."),
2910    ] = None
2911    """Custom run mode for this model: for more complex prediction procedures like test time
2912    data augmentation that currently cannot be expressed in the specification.
2913    No standard run modes are defined yet."""
2914
2915    timestamp: Datetime = Field(default_factory=Datetime.now)
2916    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2917    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2918    (In Python a datetime object is valid, too)."""
2919
2920    training_data: Annotated[
2921        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2922        Field(union_mode="left_to_right"),
2923    ] = None
2924    """The dataset used to train this model"""
2925
2926    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2927    """The weights for this model.
2928    Weights can be given for different formats, but should otherwise be equivalent.
2929    The available weight formats determine which consumers can use this model."""
2930
2931    config: Config = Field(default_factory=Config)
2932
2933    @model_validator(mode="after")
2934    def _add_default_cover(self) -> Self:
2935        if not get_validation_context().perform_io_checks or self.covers:
2936            return self
2937
2938        try:
2939            generated_covers = generate_covers(
2940                [(t, load_array(t.test_tensor)) for t in self.inputs],
2941                [(t, load_array(t.test_tensor)) for t in self.outputs],
2942            )
2943        except Exception as e:
2944            issue_warning(
2945                "Failed to generate cover image(s): {e}",
2946                value=self.covers,
2947                msg_context=dict(e=e),
2948                field="covers",
2949            )
2950        else:
2951            self.covers.extend(generated_covers)
2952
2953        return self
2954
2955    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2956        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2957        assert all(isinstance(d, np.ndarray) for d in data)
2958        return data
2959
2960    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2961        data = [load_array(out.test_tensor) for out in self.outputs]
2962        assert all(isinstance(d, np.ndarray) for d in data)
2963        return data
2964
2965    @staticmethod
2966    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2967        batch_size = 1
2968        tensor_with_batchsize: Optional[TensorId] = None
2969        for tid in tensor_sizes:
2970            for aid, s in tensor_sizes[tid].items():
2971                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2972                    continue
2973
2974                if batch_size != 1:
2975                    assert tensor_with_batchsize is not None
2976                    raise ValueError(
2977                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2978                    )
2979
2980                batch_size = s
2981                tensor_with_batchsize = tid
2982
2983        return batch_size
2984
2985    def get_output_tensor_sizes(
2986        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2987    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2988        """Returns the tensor output sizes for given **input_sizes**.
2989        The returned output sizes are exact only if **input_sizes** describes a valid input shape;
2990        otherwise they may be larger than the actual (valid) output sizes."""
2991        batch_size = self.get_batch_size(input_sizes)
2992        ns = self.get_ns(input_sizes)
2993
2994        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2995        return tensor_sizes.outputs
2996
2997    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2998        """get parameter `n` for each parameterized axis
2999        such that the valid input size is >= the given input size"""
3000        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3001        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3002        for tid in input_sizes:
3003            for aid, s in input_sizes[tid].items():
3004                size_descr = axes[tid][aid].size
3005                if isinstance(size_descr, ParameterizedSize):
3006                    ret[(tid, aid)] = size_descr.get_n(s)
3007                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3008                    pass
3009                else:
3010                    assert_never(size_descr)
3011
3012        return ret
3013
3014    def get_tensor_sizes(
3015        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3016    ) -> _TensorSizes:
3017        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3018        return _TensorSizes(
3019            {
3020                t: {
3021                    aa: axis_sizes.inputs[(tt, aa)]
3022                    for tt, aa in axis_sizes.inputs
3023                    if tt == t
3024                }
3025                for t in {tt for tt, _ in axis_sizes.inputs}
3026            },
3027            {
3028                t: {
3029                    aa: axis_sizes.outputs[(tt, aa)]
3030                    for tt, aa in axis_sizes.outputs
3031                    if tt == t
3032                }
3033                for t in {tt for tt, _ in axis_sizes.outputs}
3034            },
3035        )
3036
3037    def get_axis_sizes(
3038        self,
3039        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3040        batch_size: Optional[int] = None,
3041        *,
3042        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3043    ) -> _AxisSizes:
3044        """Determine input and output block shape for scale factors **ns**
3045        of parameterized input sizes.
3046
3047        Args:
3048            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3049                that is parameterized as `size = min + n * step`.
3050            batch_size: The desired size of the batch dimension.
3051                If given **batch_size** overwrites any batch size present in
3052                **max_input_shape**. Default 1.
3053            max_input_shape: Limits the derived block shapes.
3054                For each axis whose parameterized input size (given its `n`) would
3055                exceed **max_input_shape**, `n` is capped at the smallest value whose
3056                resulting size still reaches **max_input_shape**.
3057                Use this for small input samples or large values of **ns**,
3058                or simply whenever you know the full input shape.
3059
3060        Returns:
3061            Resolved axis sizes for model inputs and outputs.
3062        """
3063        max_input_shape = max_input_shape or {}
3064        if batch_size is None:
3065            for (_t_id, a_id), s in max_input_shape.items():
3066                if a_id == BATCH_AXIS_ID:
3067                    batch_size = s
3068                    break
3069            else:
3070                batch_size = 1
3071
3072        all_axes = {
3073            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3074        }
3075
3076        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3077        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3078
3079        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3080            if isinstance(a, BatchAxis):
3081                if (t_descr.id, a.id) in ns:
3082                    logger.warning(
3083                        "Ignoring unexpected size increment factor (n) for batch axis"
3084                        + " of tensor '{}'.",
3085                        t_descr.id,
3086                    )
3087                return batch_size
3088            elif isinstance(a.size, int):
3089                if (t_descr.id, a.id) in ns:
3090                    logger.warning(
3091                        "Ignoring unexpected size increment factor (n) for fixed size"
3092                        + " axis '{}' of tensor '{}'.",
3093                        a.id,
3094                        t_descr.id,
3095                    )
3096                return a.size
3097            elif isinstance(a.size, ParameterizedSize):
3098                if (t_descr.id, a.id) not in ns:
3099                    raise ValueError(
3100                        "Size increment factor (n) missing for parametrized axis"
3101                        + f" '{a.id}' of tensor '{t_descr.id}'."
3102                    )
3103                n = ns[(t_descr.id, a.id)]
3104                s_max = max_input_shape.get((t_descr.id, a.id))
3105                if s_max is not None:
3106                    n = min(n, a.size.get_n(s_max))
3107
3108                return a.size.get_size(n)
3109
3110            elif isinstance(a.size, SizeReference):
3111                if (t_descr.id, a.id) in ns:
3112                    logger.warning(
3113                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3114                        + " of tensor '{}' with size reference.",
3115                        a.id,
3116                        t_descr.id,
3117                    )
3118                assert not isinstance(a, BatchAxis)
3119                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3120                assert not isinstance(ref_axis, BatchAxis)
3121                ref_key = (a.size.tensor_id, a.size.axis_id)
3122                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3123                assert ref_size is not None, ref_key
3124                assert not isinstance(ref_size, _DataDepSize), ref_key
3125                return a.size.get_size(
3126                    axis=a,
3127                    ref_axis=ref_axis,
3128                    ref_size=ref_size,
3129                )
3130            elif isinstance(a.size, DataDependentSize):
3131                if (t_descr.id, a.id) in ns:
3132                    logger.warning(
3133                        "Ignoring unexpected increment factor (n) for data dependent"
3134                        + " size axis '{}' of tensor '{}'.",
3135                        a.id,
3136                        t_descr.id,
3137                    )
3138                return _DataDepSize(a.size.min, a.size.max)
3139            else:
3140                assert_never(a.size)
3141
3142        # first resolve all input axis sizes except those given by a `SizeReference`
3143        for t_descr in self.inputs:
3144            for a in t_descr.axes:
3145                if not isinstance(a.size, SizeReference):
3146                    s = get_axis_size(a)
3147                    assert not isinstance(s, _DataDepSize)
3148                    inputs[t_descr.id, a.id] = s
3149
3150        # resolve all other input axis sizes
3151        for t_descr in self.inputs:
3152            for a in t_descr.axes:
3153                if isinstance(a.size, SizeReference):
3154                    s = get_axis_size(a)
3155                    assert not isinstance(s, _DataDepSize)
3156                    inputs[t_descr.id, a.id] = s
3157
3158        # resolve all output axis sizes
3159        for t_descr in self.outputs:
3160            for a in t_descr.axes:
3161                assert not isinstance(a.size, ParameterizedSize)
3162                s = get_axis_size(a)
3163                outputs[t_descr.id, a.id] = s
3164
3165        return _AxisSizes(inputs=inputs, outputs=outputs)
3166
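    # Editorial illustration (not part of the original source): for a hypothetical input
    # tensor "raw" with a batch axis and one axis "x" sized by ParameterizedSize(min=32, step=16),
    # get_axis_sizes({(TensorId("raw"), AxisId("x")): 2}, batch_size=1) resolves the
    # x axis to 32 + 2*16 = 64 and the batch axis to 1.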
3167    @model_validator(mode="before")
3168    @classmethod
3169    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3170        cls.convert_from_old_format_wo_validation(data)
3171        return data
3172
3173    @classmethod
3174    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3175        """Convert metadata following an older format version to this class's format
3176        without validating the result.
3177        """
3178        if (
3179            data.get("type") == "model"
3180            and isinstance(fv := data.get("format_version"), str)
3181            and fv.count(".") == 2
3182        ):
3183            fv_parts = fv.split(".")
3184            if any(not p.isdigit() for p in fv_parts):
3185                return
3186
3187            fv_tuple = tuple(map(int, fv_parts))
3188
3189            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3190            if fv_tuple[:2] in ((0, 3), (0, 4)):
3191                m04 = _ModelDescr_v0_4.load(data)
3192                if isinstance(m04, InvalidDescr):
3193                    try:
3194                        updated = _model_conv.convert_as_dict(
3195                            m04  # pyright: ignore[reportArgumentType]
3196                        )
3197                    except Exception as e:
3198                        logger.error(
3199                            "Failed to convert from invalid model 0.4 description."
3200                            + f"\nerror: {e}"
3201                            + "\nProceeding with model 0.5 validation without conversion."
3202                        )
3203                        updated = None
3204                else:
3205                    updated = _model_conv.convert_as_dict(m04)
3206
3207                if updated is not None:
3208                    data.clear()
3209                    data.update(updated)
3210
3211            elif fv_tuple[:2] == (0, 5):
3212                # bump patch version
3213                data["format_version"] = cls.implemented_format_version
3214
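# Editorial illustration (not part of the original source): a raw dict with
# {"type": "model", "format_version": "0.4.10", ...} is converted in place to the
# 0.5 layout, while one with "format_version": "0.5.3" only has its patch version
# bumped to the currently implemented format version.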
3215
3216class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3217    def _convert(
3218        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3219    ) -> "ModelDescr | dict[str, Any]":
3220        name = "".join(
3221            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3222            for c in src.name
3223        )
3224
3225        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3226            conv = (
3227                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3228            )
3229            return None if auths is None else [conv(a) for a in auths]
3230
3231        if TYPE_CHECKING:
3232            arch_file_conv = _arch_file_conv.convert
3233            arch_lib_conv = _arch_lib_conv.convert
3234        else:
3235            arch_file_conv = _arch_file_conv.convert_as_dict
3236            arch_lib_conv = _arch_lib_conv.convert_as_dict
3237
3238        input_size_refs = {
3239            ipt.name: {
3240                a: s
3241                for a, s in zip(
3242                    ipt.axes,
3243                    (
3244                        ipt.shape.min
3245                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3246                        else ipt.shape
3247                    ),
3248                )
3249            }
3250            for ipt in src.inputs
3251            if ipt.shape
3252        }
3253        output_size_refs = {
3254            **{
3255                out.name: {a: s for a, s in zip(out.axes, out.shape)}
3256                for out in src.outputs
3257                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3258            },
3259            **input_size_refs,
3260        }
3261
3262        return tgt(
3263            attachments=(
3264                []
3265                if src.attachments is None
3266                else [FileDescr(source=f) for f in src.attachments.files]
3267            ),
3268            authors=[
3269                _author_conv.convert_as_dict(a) for a in src.authors
3270            ],  # pyright: ignore[reportArgumentType]
3271            cite=[
3272                {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3273            ],  # pyright: ignore[reportArgumentType]
3274            config=src.config,  # pyright: ignore[reportArgumentType]
3275            covers=src.covers,
3276            description=src.description,
3277            documentation=src.documentation,
3278            format_version="0.5.4",
3279            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3280            icon=src.icon,
3281            id=None if src.id is None else ModelId(src.id),
3282            id_emoji=src.id_emoji,
3283            license=src.license,  # type: ignore
3284            links=src.links,
3285            maintainers=[
3286                _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3287            ],  # pyright: ignore[reportArgumentType]
3288            name=name,
3289            tags=src.tags,
3290            type=src.type,
3291            uploader=src.uploader,
3292            version=src.version,
3293            inputs=[  # pyright: ignore[reportArgumentType]
3294                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3295                for ipt, tt, st, in zip(
3296                    src.inputs,
3297                    src.test_inputs,
3298                    src.sample_inputs or [None] * len(src.test_inputs),
3299                )
3300            ],
3301            outputs=[  # pyright: ignore[reportArgumentType]
3302                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3303                for out, tt, st, in zip(
3304                    src.outputs,
3305                    src.test_outputs,
3306                    src.sample_outputs or [None] * len(src.test_outputs),
3307                )
3308            ],
3309            parent=(
3310                None
3311                if src.parent is None
3312                else LinkedModel(
3313                    id=ModelId(
3314                        str(src.parent.id)
3315                        + (
3316                            ""
3317                            if src.parent.version_number is None
3318                            else f"/{src.parent.version_number}"
3319                        )
3320                    )
3321                )
3322            ),
3323            training_data=(
3324                None
3325                if src.training_data is None
3326                else (
3327                    LinkedDataset(
3328                        id=DatasetId(
3329                            str(src.training_data.id)
3330                            + (
3331                                ""
3332                                if src.training_data.version_number is None
3333                                else f"/{src.training_data.version_number}"
3334                            )
3335                        )
3336                    )
3337                    if isinstance(src.training_data, LinkedDataset02)
3338                    else src.training_data
3339                )
3340            ),
3341            packaged_by=[
3342                _author_conv.convert_as_dict(a) for a in src.packaged_by
3343            ],  # pyright: ignore[reportArgumentType]
3344            run_mode=src.run_mode,
3345            timestamp=src.timestamp,
3346            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3347                keras_hdf5=(w := src.weights.keras_hdf5)
3348                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3349                    authors=conv_authors(w.authors),
3350                    source=w.source,
3351                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3352                    parent=w.parent,
3353                ),
3354                onnx=(w := src.weights.onnx)
3355                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3356                    source=w.source,
3357                    authors=conv_authors(w.authors),
3358                    parent=w.parent,
3359                    opset_version=w.opset_version or 15,
3360                ),
3361                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3362                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3363                    source=w.source,
3364                    authors=conv_authors(w.authors),
3365                    parent=w.parent,
3366                    architecture=(
3367                        arch_file_conv(
3368                            w.architecture,
3369                            w.architecture_sha256,
3370                            w.kwargs,
3371                        )
3372                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3373                        else arch_lib_conv(w.architecture, w.kwargs)
3374                    ),
3375                    pytorch_version=w.pytorch_version or Version("1.10"),
3376                    dependencies=(
3377                        None
3378                        if w.dependencies is None
3379                        else (FileDescr if TYPE_CHECKING else dict)(
3380                            source=cast(
3381                                FileSource,
3382                                str(deps := w.dependencies)[
3383                                    (
3384                                        len("conda:")
3385                                        if str(deps).startswith("conda:")
3386                                        else 0
3387                                    ) :
3388                                ],
3389                            )
3390                        )
3391                    ),
3392                ),
3393                tensorflow_js=(w := src.weights.tensorflow_js)
3394                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3395                    source=w.source,
3396                    authors=conv_authors(w.authors),
3397                    parent=w.parent,
3398                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3399                ),
3400                tensorflow_saved_model_bundle=(
3401                    w := src.weights.tensorflow_saved_model_bundle
3402                )
3403                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3404                    authors=conv_authors(w.authors),
3405                    parent=w.parent,
3406                    source=w.source,
3407                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3408                    dependencies=(
3409                        None
3410                        if w.dependencies is None
3411                        else (FileDescr if TYPE_CHECKING else dict)(
3412                            source=cast(
3413                                FileSource,
3414                                (
3415                                    str(w.dependencies)[len("conda:") :]
3416                                    if str(w.dependencies).startswith("conda:")
3417                                    else str(w.dependencies)
3418                                ),
3419                            )
3420                        )
3421                    ),
3422                ),
3423                torchscript=(w := src.weights.torchscript)
3424                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3425                    source=w.source,
3426                    authors=conv_authors(w.authors),
3427                    parent=w.parent,
3428                    pytorch_version=w.pytorch_version or Version("1.10"),
3429                ),
3430            ),
3431        )
3432
3433
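# module-level converter instance used by `ModelDescr.convert_from_old_format_wo_validation`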
3434_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3435
3436
3437# create better cover images for 3d data and non-image outputs
3438def generate_covers(
3439    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3440    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3441) -> List[Path]:
3442    def squeeze(
3443        data: NDArray[Any], axes: Sequence[AnyAxis]
3444    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3445        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3446        if data.ndim != len(axes):
3447            raise ValueError(
3448                f"tensor shape {data.shape} does not match described axes"
3449                + f" {[a.id for a in axes]}"
3450            )
3451
3452        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3453        return data.squeeze(), axes
3454
3455    def normalize(
3456        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3457    ) -> NDArray[np.float32]:
3458        data = data.astype("float32")
3459        data -= data.min(axis=axis, keepdims=True)
3460        data /= data.max(axis=axis, keepdims=True) + eps
3461        return data
3462
3463    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3464        original_shape = data.shape
3465        data, axes = squeeze(data, axes)
3466
3467        # take a slice from any batch or index axis if needed
3468        # and map the first channel axis to RGB, taking a slice from any additional channel axes
3469        slices: Tuple[slice, ...] = ()
3470        ndim = data.ndim
3471        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3472        has_c_axis = False
3473        for i, a in enumerate(axes):
3474            s = data.shape[i]
3475            assert s > 1
3476            if (
3477                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3478                and ndim > ndim_need
3479            ):
3480                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3481                ndim -= 1
3482            elif isinstance(a, ChannelAxis):
3483                if has_c_axis:
3484                    # second channel axis
3485                    data = data[slices + (slice(0, 1),)]
3486                    ndim -= 1
3487                else:
3488                    has_c_axis = True
3489                    if s == 2:
3490                        # visualize two channels with cyan and magenta
3491                        data = np.concatenate(
3492                            [
3493                                data[slices + (slice(1, 2),)],
3494                                data[slices + (slice(0, 1),)],
3495                                (
3496                                    data[slices + (slice(0, 1),)]
3497                                    + data[slices + (slice(1, 2),)]
3498                                )
3499                                / 2,  # TODO: take maximum instead?
3500                            ],
3501                            axis=i,
3502                        )
3503                    elif data.shape[i] == 3:
3504                        pass  # visualize 3 channels as RGB
3505                    else:
3506                        # visualize first 3 channels as RGB
3507                        data = data[slices + (slice(3),)]
3508
3509                    assert data.shape[i] == 3
3510
3511            slices += (slice(None),)
3512
3513        data, axes = squeeze(data, axes)
3514        assert len(axes) == ndim
3515        # take slice from z axis if needed
3516        slices = ()
3517        if ndim > ndim_need:
3518            for i, a in enumerate(axes):
3519                s = data.shape[i]
3520                if a.id == AxisId("z"):
3521                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3522                    data, axes = squeeze(data, axes)
3523                    ndim -= 1
3524                    break
3525
3526            slices += (slice(None),)
3527
3528        # take slice from any space or time axis
3529        slices = ()
3530
3531        for i, a in enumerate(axes):
3532            if ndim <= ndim_need:
3533                break
3534
3535            s = data.shape[i]
3536            assert s > 1
3537            if isinstance(
3538                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3539            ):
3540                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3541                ndim -= 1
3542
3543            slices += (slice(None),)
3544
3545        del slices
3546        data, axes = squeeze(data, axes)
3547        assert len(axes) == ndim
3548
3549        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3550            raise ValueError(
3551                f"Failed to construct cover image from shape {original_shape}"
3552            )
3553
3554        if not has_c_axis:
3555            assert ndim == 2
3556            data = np.repeat(data[:, :, None], 3, axis=2)
3557            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3558            ndim += 1
3559
3560        assert ndim == 3
3561
3562        # transpose axis order such that longest axis comes first...
3563        axis_order: List[int] = list(np.argsort(list(data.shape)))
3564        axis_order.reverse()
3565        # ... and channel axis is last
3566        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3567        axis_order.append(axis_order.pop(c))
3568        axes = [axes[ao] for ao in axis_order]
3569        data = data.transpose(axis_order)
3570
3571        # h, w = data.shape[:2]
3572        # if h / w  in (1.0 or 2.0):
3573        #     pass
3574        # elif h / w < 2:
3575        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3576
3577        norm_along = (
3578            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3579        )
3580        # normalize the data and map to 8 bit
3581        data = normalize(data, norm_along)
3582        data = (data * 255).astype("uint8")
3583
3584        return data
3585
3586    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3587        assert im0.dtype == im1.dtype == np.uint8
3588        assert im0.shape == im1.shape
3589        assert im0.ndim == 3
3590        N, M, C = im0.shape
3591        assert C == 3
3592        out = np.ones((N, M, C), dtype="uint8")
3593        for c in range(C):
3594            outc = np.tril(im0[..., c])
3595            mask = outc == 0
3596            outc[mask] = np.triu(im1[..., c])[mask]
3597            out[..., c] = outc
3598
3599        return out
3600
3601    ipt_descr, ipt = inputs[0]
3602    out_descr, out = outputs[0]
3603
3604    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3605    out_img = to_2d_image(out, out_descr.axes)
3606
3607    cover_folder = Path(mkdtemp())
3608    if ipt_img.shape == out_img.shape:
3609        covers = [cover_folder / "cover.png"]
3610        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3611    else:
3612        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3613        imwrite(covers[0], ipt_img)
3614        imwrite(covers[1], out_img)
3615
3616    return covers
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible with the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible with the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
228class TensorId(LowerCaseIdentifier):
229    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
230        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
231    ]

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

244class AxisId(LowerCaseIdentifier):
245    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
246        Annotated[
247            LowerCaseIdentifierAnno,
248            MaxLen(16),
249            AfterValidator(_normalize_axis_id),
250        ]
251    ]

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

the pydantic root model to validate the string

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'zero_mean_unit_variance']
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'scale_linear', 'sigmoid', 'zero_mean_unit_variance', 'scale_range']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>

Annotates an integer to calculate a concrete axis size from a ParameterizedSize.

class ParameterizedSize(bioimageio.spec._internal.node.Node):
295class ParameterizedSize(Node):
296    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
297
298    - **min** and **step** are given by the model description.
299    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
300    - A greater blocksize parameter n results in a greater **size**.
301      This allows the axis size to be adjusted generically.
302    """
303
304    N: ClassVar[Type[int]] = ParameterizedSize_N
305    """Positive integer to parameterize this axis"""
306
307    min: Annotated[int, Gt(0)]
308    step: Annotated[int, Gt(0)]
309
310    def validate_size(self, size: int) -> int:
311        if size < self.min:
312            raise ValueError(f"size {size} < {self.min}")
313        if (size - self.min) % self.step != 0:
314            raise ValueError(
315                f"axis of size {size} is not parameterized by `min + n*step` ="
316                + f" `{self.min} + n*{self.step}`"
317            )
318
319        return size
320
321    def get_size(self, n: ParameterizedSize_N) -> int:
322        return self.min + self.step * n
323
324    def get_n(self, s: int) -> ParameterizedSize_N:
325        """return the smallest n parameterizing a size greater than or equal to `s`"""
326        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

  • min and step are given by the model description.
  • All blocksize parameters n = 0,1,2,... yield a valid size.
  • A greater blocksize parameter n results in a greater size. This allows the axis size to be adjusted generically.
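A brief, editorially added illustration (values chosen arbitrarily) of how get_size and get_n relate:

>>> p = ParameterizedSize(min=32, step=16)
>>> p.get_size(2)  # 32 + 2*16
64
>>> p.get_n(40)  # smallest n with size >= 40
1
>>> p.get_size(p.get_n(40))
48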
N: ClassVar[Type[int]] = <class 'int'>

Positive integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
310    def validate_size(self, size: int) -> int:
311        if size < self.min:
312            raise ValueError(f"size {size} < {self.min}")
313        if (size - self.min) % self.step != 0:
314            raise ValueError(
315                f"axis of size {size} is not parameterized by `min + n*step` ="
316                + f" `{self.min} + n*{self.step}`"
317            )
318
319        return size
def get_size(self, n: int) -> int:
321    def get_size(self, n: ParameterizedSize_N) -> int:
322        return self.min + self.step * n
def get_n(self, s: int) -> int:
324    def get_n(self, s: int) -> ParameterizedSize_N:
325        """return the smallest n parameterizing a size greater than or equal to `s`"""
326        return ceil((s - self.min) / self.step)

return the smallest n parameterizing a size greater than or equal to s

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DataDependentSize(bioimageio.spec._internal.node.Node):
329class DataDependentSize(Node):
330    min: Annotated[int, Gt(0)] = 1
331    max: Annotated[Optional[int], Gt(1)] = None
332
333    @model_validator(mode="after")
334    def _validate_max_gt_min(self):
335        if self.max is not None and self.min >= self.max:
336            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
337
338        return self
339
340    def validate_size(self, size: int) -> int:
341        if size < self.min:
342            raise ValueError(f"size {size} < {self.min}")
343
344        if self.max is not None and size > self.max:
345            raise ValueError(f"size {size} > {self.max}")
346
347        return size
min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
340    def validate_size(self, size: int) -> int:
341        if size < self.min:
342            raise ValueError(f"size {size} < {self.min}")
343
344        if self.max is not None and size > self.max:
345            raise ValueError(f"size {size} > {self.max}")
346
347        return size
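An editorially added sketch (arbitrary bounds) of how validate_size enforces min and max:

>>> DataDependentSize(min=2, max=64).validate_size(32)
32
>>> DataDependentSize(min=2, max=64).validate_size(1)
Traceback (most recent call last):
    ...
ValueError: size 1 < 2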
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class SizeReference(bioimageio.spec._internal.node.Node):
350class SizeReference(Node):
351    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
352
353    `axis.size = reference.size * reference.scale / axis.scale + offset`
354
355    Note:
356    1. The axis and the referenced axis need to have the same unit (or no unit).
357    2. Batch axes may not be referenced.
358    3. Fractions are rounded down.
359    4. If the reference axis is `concatenable` the referencing axis is assumed to be
360        `concatenable` as well with the same block order.
361
362    Example:
363    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
364    Let's assume that we want to express the image height h in relation to its width w
365    instead of only accepting input images of exactly 100*49 pixels
366    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
367
368    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
369    >>> h = SpaceInputAxis(
370    ...     id=AxisId("h"),
371    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
372    ...     unit="millimeter",
373    ...     scale=4,
374    ... )
375    >>> print(h.size.get_size(h, w))
376    49
377
378    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
379    """
380
381    tensor_id: TensorId
382    """tensor id of the reference axis"""
383
384    axis_id: AxisId
385    """axis id of the reference axis"""
386
387    offset: StrictInt = 0
388
389    def get_size(
390        self,
391        axis: Union[
392            ChannelAxis,
393            IndexInputAxis,
394            IndexOutputAxis,
395            TimeInputAxis,
396            SpaceInputAxis,
397            TimeOutputAxis,
398            TimeOutputAxisWithHalo,
399            SpaceOutputAxis,
400            SpaceOutputAxisWithHalo,
401        ],
402        ref_axis: Union[
403            ChannelAxis,
404            IndexInputAxis,
405            IndexOutputAxis,
406            TimeInputAxis,
407            SpaceInputAxis,
408            TimeOutputAxis,
409            TimeOutputAxisWithHalo,
410            SpaceOutputAxis,
411            SpaceOutputAxisWithHalo,
412        ],
413        n: ParameterizedSize_N = 0,
414        ref_size: Optional[int] = None,
415    ):
416        """Compute the concrete size for a given axis and its reference axis.
417
418        Args:
419            axis: The axis this `SizeReference` is the size of.
420            ref_axis: The reference axis to compute the size from.
421            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
422                and no fixed **ref_size** is given,
423                **n** is used to compute the size of the parameterized **ref_axis**.
424            ref_size: Overwrite the reference size instead of deriving it from
425                **ref_axis**
426                (**ref_axis.scale** is still used; any given **n** is ignored).
427        """
428        assert (
429            axis.size == self
430        ), "Given `axis.size` is not defined by this `SizeReference`"
431
432        assert (
433            ref_axis.id == self.axis_id
434        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
435
436        assert axis.unit == ref_axis.unit, (
437            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
438            f" but {axis.unit}!={ref_axis.unit}"
439        )
440        if ref_size is None:
441            if isinstance(ref_axis.size, (int, float)):
442                ref_size = ref_axis.size
443            elif isinstance(ref_axis.size, ParameterizedSize):
444                ref_size = ref_axis.size.get_size(n)
445            elif isinstance(ref_axis.size, DataDependentSize):
446                raise ValueError(
447                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
448                )
449            elif isinstance(ref_axis.size, SizeReference):
450                raise ValueError(
451                    "Reference axis referenced in `SizeReference` may not be sized by a"
452                    + " `SizeReference` itself."
453                )
454            else:
455                assert_never(ref_axis.size)
456
457        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
458
459    @staticmethod
460    def _get_unit(
461        axis: Union[
462            ChannelAxis,
463            IndexInputAxis,
464            IndexOutputAxis,
465            TimeInputAxis,
466            SpaceInputAxis,
467            TimeOutputAxis,
468            TimeOutputAxisWithHalo,
469            SpaceOutputAxis,
470            SpaceOutputAxisWithHalo,
471        ],
472    ):
473        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 100*49 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: Annotated[int, Strict(strict=True)]
389    def get_size(
390        self,
391        axis: Union[
392            ChannelAxis,
393            IndexInputAxis,
394            IndexOutputAxis,
395            TimeInputAxis,
396            SpaceInputAxis,
397            TimeOutputAxis,
398            TimeOutputAxisWithHalo,
399            SpaceOutputAxis,
400            SpaceOutputAxisWithHalo,
401        ],
402        ref_axis: Union[
403            ChannelAxis,
404            IndexInputAxis,
405            IndexOutputAxis,
406            TimeInputAxis,
407            SpaceInputAxis,
408            TimeOutputAxis,
409            TimeOutputAxisWithHalo,
410            SpaceOutputAxis,
411            SpaceOutputAxisWithHalo,
412        ],
413        n: ParameterizedSize_N = 0,
414        ref_size: Optional[int] = None,
415    ):
416        """Compute the concrete size for a given axis and its reference axis.
417
418        Args:
419            axis: The axis this `SizeReference` is the size of.
420            ref_axis: The reference axis to compute the size from.
421            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
422                and no fixed **ref_size** is given,
423                **n** is used to compute the size of the parameterized **ref_axis**.
424            ref_size: Overwrite the reference size instead of deriving it from
425                **ref_axis**
426                (**ref_axis.scale** is still used; any given **n** is ignored).
427        """
428        assert (
429            axis.size == self
430        ), "Given `axis.size` is not defined by this `SizeReference`"
431
432        assert (
433            ref_axis.id == self.axis_id
434        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
435
436        assert axis.unit == ref_axis.unit, (
437            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
438            f" but {axis.unit}!={ref_axis.unit}"
439        )
440        if ref_size is None:
441            if isinstance(ref_axis.size, (int, float)):
442                ref_size = ref_axis.size
443            elif isinstance(ref_axis.size, ParameterizedSize):
444                ref_size = ref_axis.size.get_size(n)
445            elif isinstance(ref_axis.size, DataDependentSize):
446                raise ValueError(
447                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
448                )
449            elif isinstance(ref_axis.size, SizeReference):
450                raise ValueError(
451                    "Reference axis referenced in `SizeReference` may not be sized by a"
452                    + " `SizeReference` itself."
453                )
454            else:
455                assert_never(ref_axis.size)
456
457        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

476class AxisBase(NodeWithExplicitlySetFields):
477    id: AxisId
478    """An axis id unique across all axes of one tensor."""
479
480    description: Annotated[str, MaxLen(128)] = ""
id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WithHalo(bioimageio.spec._internal.node.Node):
483class WithHalo(Node):
484    halo: Annotated[int, Ge(1)]
485    """The halo should be cropped from the output tensor to avoid boundary effects.
486    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
487    To document a halo that is already cropped by the model use `size.offset` instead."""
488
489    size: Annotated[
490        SizeReference,
491        Field(
492            examples=[
493                10,
494                SizeReference(
495                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
496                ).model_dump(mode="json"),
497            ]
498        ),
499    ]
500    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.
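For example (illustrative numbers added editorially): with halo=16 and an output axis size of 256, the usable region after cropping both sides is 256 - 2*16 = 224.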

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
506class BatchAxis(AxisBase):
507    implemented_type: ClassVar[Literal["batch"]] = "batch"
508    if TYPE_CHECKING:
509        type: Literal["batch"] = "batch"
510    else:
511        type: Literal["batch"]
512
513    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
514    size: Optional[Literal[1]] = None
515    """The batch size may be fixed to 1,
516    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
517
518    @property
519    def scale(self):
520        return 1.0
521
522    @property
523    def concatenable(self):
524        return True
525
526    @property
527    def unit(self):
528        return None
implemented_type: ClassVar[Literal['batch']] = 'batch'
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
518    @property
519    def scale(self):
520        return 1.0
concatenable
522    @property
523    def concatenable(self):
524        return True
unit
526    @property
527    def unit(self):
528        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['batch']
class ChannelAxis(AxisBase):
531class ChannelAxis(AxisBase):
532    implemented_type: ClassVar[Literal["channel"]] = "channel"
533    if TYPE_CHECKING:
534        type: Literal["channel"] = "channel"
535    else:
536        type: Literal["channel"]
537
538    id: NonBatchAxisId = AxisId("channel")
539    channel_names: NotEmpty[List[Identifier]]
540
541    @property
542    def size(self) -> int:
543        return len(self.channel_names)
544
545    @property
546    def concatenable(self):
547        return False
548
549    @property
550    def scale(self) -> float:
551        return 1.0
552
553    @property
554    def unit(self):
555        return None
implemented_type: ClassVar[Literal['channel']] = 'channel'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
541    @property
542    def size(self) -> int:
543        return len(self.channel_names)
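An editorially added example (hypothetical channel names) showing that size is derived from channel_names:

>>> ChannelAxis(channel_names=[Identifier("r"), Identifier("g"), Identifier("b")]).size
3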
concatenable
545    @property
546    def concatenable(self):
547        return False
scale: float
549    @property
550    def scale(self) -> float:
551        return 1.0
unit
553    @property
554    def unit(self):
555        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['channel']
class IndexAxisBase(AxisBase):
558class IndexAxisBase(AxisBase):
559    implemented_type: ClassVar[Literal["index"]] = "index"
560    if TYPE_CHECKING:
561        type: Literal["index"] = "index"
562    else:
563        type: Literal["index"]
564
565    id: NonBatchAxisId = AxisId("index")
566
567    @property
568    def scale(self) -> float:
569        return 1.0
570
571    @property
572    def unit(self):
573        return None
implemented_type: ClassVar[Literal['index']] = 'index'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
567    @property
568    def scale(self) -> float:
569        return 1.0
unit
571    @property
572    def unit(self):
573        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
596class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
597    concatenable: bool = False
598    """If a model has a `concatenable` input axis, it can be processed blockwise,
599    splitting a longer sample axis into blocks matching its input tensor description.
600    Output axes are concatenable if they have a `SizeReference` to a concatenable
601    input axis.
602    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexOutputAxis(IndexAxisBase):
605class IndexOutputAxis(IndexAxisBase):
606    size: Annotated[
607        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
608        Field(
609            examples=[
610                10,
611                SizeReference(
612                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
613                ).model_dump(mode="json"),
614            ]
615        ),
616    ]
617    """The size/length of this axis can be specified as
618    - fixed integer
619    - reference to another axis with an optional offset (`SizeReference`)
620    - data dependent size using `DataDependentSize` (size is only known after model inference)
621    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
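Two editorially added sketches of the first and third option (arbitrary values, default axis id assumed):

>>> IndexOutputAxis(size=10).size
10
>>> IndexOutputAxis(size=DataDependentSize(min=1)).size.min
1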
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class TimeAxisBase(AxisBase):
624class TimeAxisBase(AxisBase):
625    implemented_type: ClassVar[Literal["time"]] = "time"
626    if TYPE_CHECKING:
627        type: Literal["time"] = "time"
628    else:
629        type: Literal["time"]
630
631    id: NonBatchAxisId = AxisId("time")
632    unit: Optional[TimeUnit] = None
633    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['time']] = 'time'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
636class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
637    concatenable: bool = False
638    """If a model has a `concatenable` input axis, it can be processed blockwise,
639    splitting a longer sample axis into blocks matching its input tensor description.
640    Output axes are concatenable if they have a `SizeReference` to a concatenable
641    input axis.
642    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceAxisBase(AxisBase):
645class SpaceAxisBase(AxisBase):
646    implemented_type: ClassVar[Literal["space"]] = "space"
647    if TYPE_CHECKING:
648        type: Literal["space"] = "space"
649    else:
650        type: Literal["space"]
651
652    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
653    unit: Optional[SpaceUnit] = None
654    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['space']] = 'space'
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
657class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
658    concatenable: bool = False
659    """If a model has a `concatenable` input axis, it can be processed blockwise,
660    splitting a longer sample axis into blocks matching its input tensor description.
661    Output axes are concatenable if they have a `SizeReference` to a concatenable
662    input axis.
663    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
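
A minimal usage sketch (not from the generated docstrings; the axis ids, channel name, and fixed size 512 are illustrative assumptions) of how these input axis classes are combined when describing an input tensor:

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, BatchAxis, ChannelAxis, Identifier, SpaceInputAxis
    ... )
    >>> input_axes = [
    ...     BatchAxis(),
    ...     ChannelAxis(channel_names=[Identifier("raw")]),
    ...     SpaceInputAxis(id=AxisId("y"), size=512),
    ...     SpaceInputAxis(id=AxisId("x"), size=512),
    ... ]

Instead of a fixed integer, size may also be a ParameterizedSize or a SizeReference (and, for output axes, a DataDependentSize), as described above.
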
INPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>)

intended for isinstance comparisons in py<3.10

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
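
A short, hedged illustration of the note above: where isinstance cannot be used with the annotated union directly, the *_AXIS_TYPES tuples serve the same purpose.

    >>> from bioimageio.spec.model.v0_5 import INPUT_AXIS_TYPES, BatchAxis
    >>> isinstance(BatchAxis(), INPUT_AXIS_TYPES)
    True
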
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
699class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
700    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
703class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
704    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
723class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
724    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
727class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
728    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
OUTPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
ANY_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>, <class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
785class NominalOrOrdinalDataDescr(Node):
786    values: TVs
787    """A fixed set of nominal or an ascending sequence of ordinal values.
788    In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'.
789    String `values` are interpreted as labels for tensor values 0, ..., N.
790    Note: as YAML 1.2 does not natively support a "set" datatype,
791    nominal values should be given as a sequence (aka list/array) as well.
792    """
793
794    type: Annotated[
795        NominalOrOrdinalDType,
796        Field(
797            examples=[
798                "float32",
799                "uint8",
800                "uint16",
801                "int64",
802                "bool",
803            ],
804        ),
805    ] = "uint8"
806
807    @model_validator(mode="after")
808    def _validate_values_match_type(
809        self,
810    ) -> Self:
811        incompatible: List[Any] = []
812        for v in self.values:
813            if self.type == "bool":
814                if not isinstance(v, bool):
815                    incompatible.append(v)
816            elif self.type in DTYPE_LIMITS:
817                if (
818                    isinstance(v, (int, float))
819                    and (
820                        v < DTYPE_LIMITS[self.type].min
821                        or v > DTYPE_LIMITS[self.type].max
822                    )
823                    or (isinstance(v, str) and "uint" not in self.type)
824                    or (isinstance(v, float) and "int" in self.type)
825                ):
826                    incompatible.append(v)
827            else:
828                incompatible.append(v)
829
830            if len(incompatible) == 5:
831                incompatible.append("...")
832                break
833
834        if incompatible:
835            raise ValueError(
836                f"data type '{self.type}' incompatible with values {incompatible}"
837            )
838
839        return self
840
841    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
842
843    @property
844    def range(self):
845        if isinstance(self.values[0], str):
846            return 0, len(self.values) - 1
847        else:
848            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data.type is required to be an unsigned integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
843    @property
844    def range(self):
845        if isinstance(self.values[0], str):
846            return 0, len(self.values) - 1
847        else:
848            return min(self.values), max(self.values)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
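
A minimal, hedged sketch of a nominal data description using string labels (the label names are illustrative assumptions); the labels are mapped to tensor values 0, ..., N as described above:

    >>> from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr
    >>> data = NominalOrOrdinalDataDescr(
    ...     values=["background", "cell", "membrane"],
    ...     type="uint8",
    ... )
    >>> data.range
    (0, 2)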

IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
865class IntervalOrRatioDataDescr(Node):
866    type: Annotated[  # todo: rename to dtype
867        IntervalOrRatioDType,
868        Field(
869            examples=["float32", "float64", "uint8", "uint16"],
870        ),
871    ] = "float32"
872    range: Tuple[Optional[float], Optional[float]] = (
873        None,
874        None,
875    )
876    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
877    `None` corresponds to min/max of what can be expressed by **type**."""
878    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
879    scale: float = 1.0
880    """Scale for data on an interval (or ratio) scale."""
881    offset: Optional[float] = None
882    """Offset for data on a ratio scale."""
883
884    @model_validator(mode="before")
885    def _replace_inf(cls, data: Any):
886        if is_dict(data):
887            if "range" in data and is_sequence(data["range"]):
888                forbidden = (
889                    "inf",
890                    "-inf",
891                    ".inf",
892                    "-.inf",
893                    float("inf"),
894                    float("-inf"),
895                )
896                if any(v in forbidden for v in data["range"]):
897                    issue_warning("replaced 'inf' value", value=data["range"])
898
899                data["range"] = tuple(
900                    (None if v in forbidden else v) for v in data["range"]
901                )
902
903        return data
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
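
A minimal, hedged sketch of an interval/ratio data description (the concrete range is an illustrative assumption); note that infinite range bounds are replaced by None with a warning, as implemented by the _replace_inf validator above:

    >>> from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr
    >>> data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))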

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
909class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
910    """processing base class"""

processing base class

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
913class BinarizeKwargs(ProcessingKwargs):
914    """key word arguments for `BinarizeDescr`"""
915
916    threshold: float
917    """The fixed threshold"""

keyword arguments for BinarizeDescr

threshold: float

The fixed threshold

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
920class BinarizeAlongAxisKwargs(ProcessingKwargs):
921    """key word arguments for `BinarizeDescr`"""
922
923    threshold: NotEmpty[List[float]]
924    """The fixed threshold values along `axis`"""
925
926    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
927    """The `threshold` axis"""

keyword arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeDescr(ProcessingDescrBase):
930class BinarizeDescr(ProcessingDescrBase):
931    """Binarize the tensor with a fixed threshold.
932
933    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
934    will be set to one, values below the threshold to zero.
935
936    Examples:
937    - in YAML
938        ```yaml
939        postprocessing:
940          - id: binarize
941            kwargs:
942              axis: 'channel'
943              threshold: [0.25, 0.5, 0.75]
944        ```
945    - in Python:
946        >>> postprocessing = [BinarizeDescr(
947        ...   kwargs=BinarizeAlongAxisKwargs(
948        ...       axis=AxisId('channel'),
949        ...       threshold=[0.25, 0.5, 0.75],
950        ...   )
951        ... )]
952    """
953
954    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
955    if TYPE_CHECKING:
956        id: Literal["binarize"] = "binarize"
957    else:
958        id: Literal["binarize"]
959    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
implemented_id: ClassVar[Literal['binarize']] = 'binarize'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['binarize']
class ClipDescr(ProcessingDescrBase):
962class ClipDescr(ProcessingDescrBase):
963    """Set tensor values below min to min and above max to max.
964
965    See `ScaleRangeDescr` for examples.
966    """
967
968    implemented_id: ClassVar[Literal["clip"]] = "clip"
969    if TYPE_CHECKING:
970        id: Literal["clip"] = "clip"
971    else:
972        id: Literal["clip"]
973
974    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

implemented_id: ClassVar[Literal['clip']] = 'clip'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['clip']
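
The docstring above defers to ScaleRangeDescr for examples; as a minimal, hedged sketch, a standalone clip step restricting values to [0, 1] could look like this:

    >>> from bioimageio.spec.model.v0_5 import ClipDescr, ClipKwargs
    >>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
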
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
977class EnsureDtypeKwargs(ProcessingKwargs):
978    """key word arguments for `EnsureDtypeDescr`"""
979
980    dtype: Literal[
981        "float32",
982        "float64",
983        "uint8",
984        "int8",
985        "uint16",
986        "int16",
987        "uint32",
988        "int32",
989        "uint64",
990        "int64",
991        "bool",
992    ]

keyword arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class EnsureDtypeDescr(ProcessingDescrBase):
 995class EnsureDtypeDescr(ProcessingDescrBase):
 996    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 997
 998    This can for example be used to ensure the inner neural network model gets a
 999    different input tensor data type than the fully described bioimage.io model does.
1000
1001    Examples:
1002        The described bioimage.io model (incl. preprocessing) accepts any
1003        float32-compatible tensor, normalizes it with percentiles and clipping and then
1004        casts it to uint8, which is what the neural network in this example expects.
1005        - in YAML
1006            ```yaml
1007            inputs:
1008            - data:
1009                type: float32  # described bioimage.io model is compatible with any float32 input tensor
1010              preprocessing:
1011              - id: scale_range
1012                  kwargs:
1013                  axes: ['y', 'x']
1014                  max_percentile: 99.8
1015                  min_percentile: 5.0
1016              - id: clip
1017                  kwargs:
1018                  min: 0.0
1019                  max: 1.0
1020              - id: ensure_dtype  # the neural network of the model requires uint8
1021                  kwargs:
1022                  dtype: uint8
1023            ```
1024        - in Python:
1025            >>> preprocessing = [
1026            ...     ScaleRangeDescr(
1027            ...         kwargs=ScaleRangeKwargs(
1028            ...           axes= (AxisId('y'), AxisId('x')),
1029            ...           max_percentile= 99.8,
1030            ...           min_percentile= 5.0,
1031            ...         )
1032            ...     ),
1033            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1034            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1035            ... ]
1036    """
1037
1038    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1039    if TYPE_CHECKING:
1040        id: Literal["ensure_dtype"] = "ensure_dtype"
1041    else:
1042        id: Literal["ensure_dtype"]
1043
1044    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
    kwargs:
      dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
implemented_id: ClassVar[Literal['ensure_dtype']] = 'ensure_dtype'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1047class ScaleLinearKwargs(ProcessingKwargs):
1048    """Key word arguments for `ScaleLinearDescr`"""
1049
1050    gain: float = 1.0
1051    """multiplicative factor"""
1052
1053    offset: float = 0.0
1054    """additive term"""
1055
1056    @model_validator(mode="after")
1057    def _validate(self) -> Self:
1058        if self.gain == 1.0 and self.offset == 0.0:
1059            raise ValueError(
1060                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1061                + " != 0.0."
1062            )
1063
1064        return self

Keyword arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
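
A small, hedged sketch of the _validate behaviour shown above: a non-trivial gain/offset pair is accepted, while the identity transform (gain 1.0, offset 0.0) is rejected as redundant.

    >>> from pydantic import ValidationError
    >>> from bioimageio.spec.model.v0_5 import ScaleLinearKwargs
    >>> _ = ScaleLinearKwargs(gain=2.0, offset=0.0)
    >>> try:
    ...     _ = ScaleLinearKwargs(gain=1.0, offset=0.0)
    ... except ValidationError:
    ...     print("redundant scaling rejected")
    redundant scaling rejected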

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1067class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1068    """Key word arguments for `ScaleLinearDescr`"""
1069
1070    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1071    """The axis of gain and offset values."""
1072
1073    gain: Union[float, NotEmpty[List[float]]] = 1.0
1074    """multiplicative factor"""
1075
1076    offset: Union[float, NotEmpty[List[float]]] = 0.0
1077    """additive term"""
1078
1079    @model_validator(mode="after")
1080    def _validate(self) -> Self:
1081
1082        if isinstance(self.gain, list):
1083            if isinstance(self.offset, list):
1084                if len(self.gain) != len(self.offset):
1085                    raise ValueError(
1086                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1087                    )
1088            else:
1089                self.offset = [float(self.offset)] * len(self.gain)
1090        elif isinstance(self.offset, list):
1091            self.gain = [float(self.gain)] * len(self.offset)
1092        else:
1093            raise ValueError(
1094                "Do not specify an `axis` for scalar gain and offset values."
1095            )
1096
1097        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1098            raise ValueError(
1099                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1100                + " != 0.0."
1101            )
1102
1103        return self

Keyword arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of gain and offset values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleLinearDescr(ProcessingDescrBase):
1106class ScaleLinearDescr(ProcessingDescrBase):
1107    """Fixed linear scaling.
1108
1109    Examples:
1110      1. Scale with scalar gain and offset
1111        - in YAML
1112        ```yaml
1113        preprocessing:
1114          - id: scale_linear
1115            kwargs:
1116              gain: 2.0
1117              offset: 3.0
1118        ```
1119        - in Python:
1120        >>> preprocessing = [
1121        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1122        ... ]
1123
1124      2. Independent scaling along an axis
1125        - in YAML
1126        ```yaml
1127        preprocessing:
1128          - id: scale_linear
1129            kwargs:
1130              axis: 'channel'
1131              gain: [1.0, 2.0, 3.0]
1132        ```
1133        - in Python:
1134        >>> preprocessing = [
1135        ...     ScaleLinearDescr(
1136        ...         kwargs=ScaleLinearAlongAxisKwargs(
1137        ...             axis=AxisId("channel"),
1138        ...             gain=[1.0, 2.0, 3.0],
1139        ...         )
1140        ...     )
1141        ... ]
1142
1143    """
1144
1145    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1146    if TYPE_CHECKING:
1147        id: Literal["scale_linear"] = "scale_linear"
1148    else:
1149        id: Literal["scale_linear"]
1150    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
implemented_id: ClassVar[Literal['scale_linear']] = 'scale_linear'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
1153class SigmoidDescr(ProcessingDescrBase):
1154    """The logistic sigmoid funciton, a.k.a. expit function.
1155
1156    Examples:
1157    - in YAML
1158        ```yaml
1159        postprocessing:
1160          - id: sigmoid
1161        ```
1162    - in Python:
1163        >>> postprocessing = [SigmoidDescr()]
1164    """
1165
1166    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1167    if TYPE_CHECKING:
1168        id: Literal["sigmoid"] = "sigmoid"
1169    else:
1170        id: Literal["sigmoid"]
1171
1172    @property
1173    def kwargs(self) -> ProcessingKwargs:
1174        """empty kwargs"""
1175        return ProcessingKwargs()

The logistic sigmoid function, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
implemented_id: ClassVar[Literal['sigmoid']] = 'sigmoid'
1172    @property
1173    def kwargs(self) -> ProcessingKwargs:
1174        """empty kwargs"""
1175        return ProcessingKwargs()

empty kwargs

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['sigmoid']
class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1178class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1179    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1180
1181    mean: float
1182    """The mean value to normalize with."""
1183
1184    std: Annotated[float, Ge(1e-6)]
1185    """The standard deviation value to normalize with."""

keyword arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1188class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1189    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1190
1191    mean: NotEmpty[List[float]]
1192    """The mean value(s) to normalize with."""
1193
1194    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1195    """The standard deviation value(s) to normalize with.
1196    Size must match `mean` values."""
1197
1198    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1199    """The axis of the mean/std values to normalize each entry along that dimension
1200    separately."""
1201
1202    @model_validator(mode="after")
1203    def _mean_and_std_match(self) -> Self:
1204        if len(self.mean) != len(self.std):
1205            raise ValueError(
1206                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1207                + " must match."
1208            )
1209
1210        return self

keyword arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1213class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1214    """Subtract a given mean and divide by the standard deviation.
1215
1216    Normalize with fixed, precomputed values for
1217    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1218    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1219    axes.
1220
1221    Examples:
1222    1. scalar value for whole tensor
1223        - in YAML
1224        ```yaml
1225        preprocessing:
1226          - id: fixed_zero_mean_unit_variance
1227            kwargs:
1228              mean: 103.5
1229              std: 13.7
1230        ```
1231        - in Python
1232        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1233        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1234        ... )]
1235
1236    2. independently along an axis
1237        - in YAML
1238        ```yaml
1239        preprocessing:
1240          - id: fixed_zero_mean_unit_variance
1241            kwargs:
1242              axis: channel
1243              mean: [101.5, 102.5, 103.5]
1244              std: [11.7, 12.7, 13.7]
1245        ```
1246        - in Python
1247        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1248        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1249        ...     axis=AxisId("channel"),
1250        ...     mean=[101.5, 102.5, 103.5],
1251        ...     std=[11.7, 12.7, 13.7],
1252        ...   )
1253        ... )]
1254    """
1255
1256    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1257        "fixed_zero_mean_unit_variance"
1258    )
1259    if TYPE_CHECKING:
1260        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1261    else:
1262        id: Literal["fixed_zero_mean_unit_variance"]
1263
1264    kwargs: Union[
1265        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1266    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std. Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          mean: 103.5
          std: 13.7

    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
      ... )]

  2. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['fixed_zero_mean_unit_variance']] = 'fixed_zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1269class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1270    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1271
1272    axes: Annotated[
1273        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1274    ] = None
1275    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1276    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1277    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1278    To normalize each sample independently leave out the 'batch' axis.
1279    Default: Scale all axes jointly."""
1280
1281    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1282    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

keyword arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1285class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1286    """Subtract mean and divide by variance.
1287
1288    Examples:
1289        Subtract tensor mean and variance
1290        - in YAML
1291        ```yaml
1292        preprocessing:
1293          - id: zero_mean_unit_variance
1294        ```
1295        - in Python
1296        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1297    """
1298
1299    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1300        "zero_mean_unit_variance"
1301    )
1302    if TYPE_CHECKING:
1303        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1304    else:
1305        id: Literal["zero_mean_unit_variance"]
1306
1307    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1308        default_factory=ZeroMeanUnitVarianceKwargs
1309    )

Subtract mean and divide by the standard deviation.

Examples:

Normalize with the tensor's own mean and standard deviation

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
implemented_id: ClassVar[Literal['zero_mean_unit_variance']] = 'zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['zero_mean_unit_variance']
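
Beyond the default construction shown above, a hedged sketch of joint normalization over a subset of axes (the axis ids are illustrative assumptions):

    >>> from bioimageio.spec.model.v0_5 import (
    ...     AxisId, ZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceKwargs
    ... )
    >>> preprocessing = [ZeroMeanUnitVarianceDescr(
    ...     kwargs=ZeroMeanUnitVarianceKwargs(
    ...         axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
    ...     )
    ... )]
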
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1312class ScaleRangeKwargs(ProcessingKwargs):
1313    """key word arguments for `ScaleRangeDescr`
1314
1315    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1316    this processing step normalizes data to the [0, 1] intervall.
1317    For other percentiles the normalized values will partially be outside the [0, 1]
1318    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1319    normalized values to a range.
1320    """
1321
1322    axes: Annotated[
1323        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1324    ] = None
1325    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1326    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1327    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1328    To normalize samples independently, leave out the "batch" axis.
1329    Default: Scale all axes jointly."""
1330
1331    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1332    """The lower percentile used to determine the value to align with zero."""
1333
1334    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1335    """The upper percentile used to determine the value to align with one.
1336    Has to be bigger than `min_percentile`.
1337    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1338    accepting percentiles specified in the range 0.0 to 1.0."""
1339
1340    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1341    """Epsilon for numeric stability.
1342    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1343    with `v_lower,v_upper` values at the respective percentiles."""
1344
1345    reference_tensor: Optional[TensorId] = None
1346    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1347    For any tensor in `inputs` only input tensor references are allowed."""
1348
1349    @field_validator("max_percentile", mode="after")
1350    @classmethod
1351    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1352        if (min_p := info.data["min_percentile"]) >= value:
1353            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1354
1355        return value

keyword arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use ScaleRangeDescr followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1349    @field_validator("max_percentile", mode="after")
1350    @classmethod
1351    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1352        if (min_p := info.data["min_percentile"]) >= value:
1353            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1354
1355        return value
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleRangeDescr(ProcessingDescrBase):
1358class ScaleRangeDescr(ProcessingDescrBase):
1359    """Scale with percentiles.
1360
1361    Examples:
1362    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1363        - in YAML
1364        ```yaml
1365        preprocessing:
1366          - id: scale_range
1367            kwargs:
1368              axes: ['y', 'x']
1369              max_percentile: 99.8
1370              min_percentile: 5.0
1371        ```
1372        - in Python
1373        >>> preprocessing = [
1374        ...     ScaleRangeDescr(
1375        ...         kwargs=ScaleRangeKwargs(
1376        ...           axes= (AxisId('y'), AxisId('x')),
1377        ...           max_percentile= 99.8,
1378        ...           min_percentile= 5.0,
1379        ...         )
1380        ...     ),
1381        ...     ClipDescr(
1382        ...         kwargs=ClipKwargs(
1383        ...             min=0.0,
1384        ...             max=1.0,
1385        ...         )
1386        ...     ),
1387        ... ]
1388
1389      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1390        - in YAML
1391        ```yaml
1392        preprocessing:
1393          - id: scale_range
1394            kwargs:
1395              axes: ['y', 'x']
1396              max_percentile: 99.8
1397              min_percentile: 5.0
1398                  - id: scale_range
1399           - id: clip
1400             kwargs:
1401              min: 0.0
1402              max: 1.0
1403        ```
1404        - in Python
1405        >>> preprocessing = [ScaleRangeDescr(
1406        ...   kwargs=ScaleRangeKwargs(
1407        ...       axes= (AxisId('y'), AxisId('x')),
1408        ...       max_percentile= 99.8,
1409        ...       min_percentile= 5.0,
1410        ...   )
1411        ... )]
1412
1413    """
1414
1415    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1416    if TYPE_CHECKING:
1417        id: Literal["scale_range"] = "scale_range"
1418    else:
1419        id: Literal["scale_range"]
1420    kwargs: ScaleRangeKwargs

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0

    • in Python
      >>> preprocessing = [
      ...     ScaleRangeDescr(
      ...         kwargs=ScaleRangeKwargs(
      ...           axes= (AxisId('y'), AxisId('x')),
      ...           max_percentile= 99.8,
      ...           min_percentile= 5.0,
      ...         )
      ...     ),
      ...     ClipDescr(
      ...         kwargs=ClipKwargs(
      ...             min=0.0,
      ...             max=1.0,
      ...         )
      ...     ),
      ... ]

  2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
      - id: clip
        kwargs:
          min: 0.0
          max: 1.0
    
    • in Python
      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['scale_range']] = 'scale_range'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1423class ScaleMeanVarianceKwargs(ProcessingKwargs):
1424    """key word arguments for `ScaleMeanVarianceKwargs`"""
1425
1426    reference_tensor: TensorId
1427    """Name of tensor to match."""
1428
1429    axes: Annotated[
1430        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1431    ] = None
1432    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1433    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1434    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1435    To normalize samples independently, leave out the 'batch' axis.
1436    Default: Scale all axes jointly."""
1437
1438    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1439    """Epsilon for numeric stability:
1440    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

keyword arguments for ScaleMeanVarianceDescr

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1443class ScaleMeanVarianceDescr(ProcessingDescrBase):
1444    """Scale a tensor's data distribution to match another tensor's mean/std.
1445    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1446    """
1447
1448    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1449    if TYPE_CHECKING:
1450        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1451    else:
1452        id: Literal["scale_mean_variance"]
1453    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

implemented_id: ClassVar[Literal['scale_mean_variance']] = 'scale_mean_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_mean_variance']
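
A minimal Python sketch of a scale_mean_variance postprocessing entry, analogous to the scale_range example above (the reference tensor id 'raw' is a hypothetical placeholder):

    >>> postprocessing = [ScaleMeanVarianceDescr(
    ...     kwargs=ScaleMeanVarianceKwargs(
    ...         reference_tensor=TensorId("raw"),                # hypothetical tensor to match
    ...         axes=(AxisId("y"), AxisId("x")),                 # normalize y and x jointly
    ...     )
    ... )]
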
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1487class TensorDescrBase(Node, Generic[IO_AxisT]):
1488    id: TensorId
1489    """Tensor id. No duplicates are allowed."""
1490
1491    description: Annotated[str, MaxLen(128)] = ""
1492    """free text description"""
1493
1494    axes: NotEmpty[Sequence[IO_AxisT]]
1495    """tensor axes"""
1496
1497    @property
1498    def shape(self):
1499        return tuple(a.size for a in self.axes)
1500
1501    @field_validator("axes", mode="after", check_fields=False)
1502    @classmethod
1503    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1504        batch_axes = [a for a in axes if a.type == "batch"]
1505        if len(batch_axes) > 1:
1506            raise ValueError(
1507                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1508            )
1509
1510        seen_ids: Set[AxisId] = set()
1511        duplicate_axes_ids: Set[AxisId] = set()
1512        for a in axes:
1513            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1514
1515        if duplicate_axes_ids:
1516            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1517
1518        return axes
1519
1520    test_tensor: FileDescr_
1521    """An example tensor to use for testing.
1522    Using the model with the test input tensors is expected to yield the test output tensors.
1523    Each test tensor has to be an ndarray in the
1524    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1525    The file extension must be '.npy'."""
1526
1527    sample_tensor: Optional[FileDescr_] = None
1528    """A sample tensor to illustrate a possible input/output for the model.
1529    The sample image primarily serves to inform a human user about an example use case
1530    and is typically stored as .hdf5, .png or .tiff.
1531    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1532    (numpy's `.npy` format is not supported).
1533    The image dimensionality has to match the number of axes specified in this tensor description.
1534    """
1535
1536    @model_validator(mode="after")
1537    def _validate_sample_tensor(self) -> Self:
1538        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1539            return self
1540
1541        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1542        tensor: NDArray[Any] = imread(
1543            reader.read(),
1544            extension=PurePosixPath(reader.original_file_name).suffix,
1545        )
1546        n_dims = len(tensor.squeeze().shape)
1547        n_dims_min = n_dims_max = len(self.axes)
1548
1549        for a in self.axes:
1550            if isinstance(a, BatchAxis):
1551                n_dims_min -= 1
1552            elif isinstance(a.size, int):
1553                if a.size == 1:
1554                    n_dims_min -= 1
1555            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1556                if a.size.min == 1:
1557                    n_dims_min -= 1
1558            elif isinstance(a.size, SizeReference):
1559                if a.size.offset < 2:
1560                    # size reference may result in singleton axis
1561                    n_dims_min -= 1
1562            else:
1563                assert_never(a.size)
1564
1565        n_dims_min = max(0, n_dims_min)
1566        if n_dims < n_dims_min or n_dims > n_dims_max:
1567            raise ValueError(
1568                f"Expected sample tensor to have {n_dims_min} to"
1569                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1570            )
1571
1572        return self
1573
1574    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1575        IntervalOrRatioDataDescr()
1576    )
1577    """Description of the tensor's data values, optionally per channel.
1578    If specified per channel, the data `type` needs to match across channels."""
1579
1580    @property
1581    def dtype(
1582        self,
1583    ) -> Literal[
1584        "float32",
1585        "float64",
1586        "uint8",
1587        "int8",
1588        "uint16",
1589        "int16",
1590        "uint32",
1591        "int32",
1592        "uint64",
1593        "int64",
1594        "bool",
1595    ]:
1596        """dtype as specified under `data.type` or `data[i].type`"""
1597        if isinstance(self.data, collections.abc.Sequence):
1598            return self.data[0].type
1599        else:
1600            return self.data.type
1601
1602    @field_validator("data", mode="after")
1603    @classmethod
1604    def _check_data_type_across_channels(
1605        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1606    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1607        if not isinstance(value, list):
1608            return value
1609
1610        dtypes = {t.type for t in value}
1611        if len(dtypes) > 1:
1612            raise ValueError(
1613                "Tensor data descriptions per channel need to agree in their data"
1614                + f" `type`, but found {dtypes}."
1615            )
1616
1617        return value
1618
1619    @model_validator(mode="after")
1620    def _check_data_matches_channelaxis(self) -> Self:
1621        if not isinstance(self.data, (list, tuple)):
1622            return self
1623
1624        for a in self.axes:
1625            if isinstance(a, ChannelAxis):
1626                size = a.size
1627                assert isinstance(size, int)
1628                break
1629        else:
1630            return self
1631
1632        if len(self.data) != size:
1633            raise ValueError(
1634                f"Got tensor data descriptions for {len(self.data)} channels, but"
1635                + f" '{a.id}' axis has size {size}."
1636            )
1637
1638        return self
1639
1640    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1641        if len(array.shape) != len(self.axes):
1642            raise ValueError(
1643                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1644                + f" incompatible with {len(self.axes)} axes."
1645            )
1646        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1497    @property
1498    def shape(self):
1499        return tuple(a.size for a in self.axes)
test_tensor: Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f5fe20b6de0>, return_type=PydanticUndefined, when_used='unless-none')]

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f5fe20b6de0>, return_type=PydanticUndefined, when_used='unless-none')]]

A sample tensor to illustrate a possible input/output for the model. The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1580    @property
1581    def dtype(
1582        self,
1583    ) -> Literal[
1584        "float32",
1585        "float64",
1586        "uint8",
1587        "int8",
1588        "uint16",
1589        "int16",
1590        "uint32",
1591        "int32",
1592        "uint64",
1593        "int64",
1594        "bool",
1595    ]:
1596        """dtype as specified under `data.type` or `data[i].type`"""
1597        if isinstance(self.data, collections.abc.Sequence):
1598            return self.data[0].type
1599        else:
1600            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1640    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1641        if len(array.shape) != len(self.axes):
1642            raise ValueError(
1643                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1644                + f" incompatible with {len(self.axes)} axes."
1645            )
1646        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1649class InputTensorDescr(TensorDescrBase[InputAxis]):
1650    id: TensorId = TensorId("input")
1651    """Input tensor id.
1652    No duplicates are allowed across all inputs and outputs."""
1653
1654    optional: bool = False
1655    """indicates that this tensor may be `None`"""
1656
1657    preprocessing: List[PreprocessingDescr] = Field(
1658        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1659    )
1660
1661    """Description of how this input should be preprocessed.
1662
1663    notes:
1664    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1665      to ensure an input tensor's data type matches the input tensor's data description.
1666    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1667      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1668      changing the data type.
1669    """
1670
1671    @model_validator(mode="after")
1672    def _validate_preprocessing_kwargs(self) -> Self:
1673        axes_ids = [a.id for a in self.axes]
1674        for p in self.preprocessing:
1675            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1676            if kwargs_axes is None:
1677                continue
1678
1679            if not isinstance(kwargs_axes, collections.abc.Sequence):
1680                raise ValueError(
1681                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1682                )
1683
1684            if any(a not in axes_ids for a in kwargs_axes):
1685                raise ValueError(
1686                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1687                )
1688
1689        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1690            dtype = self.data.type
1691        else:
1692            dtype = self.data[0].type
1693
1694        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1695        if not self.preprocessing or not isinstance(
1696            self.preprocessing[0], EnsureDtypeDescr
1697        ):
1698            self.preprocessing.insert(
1699                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1700            )
1701
1702        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1703        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1704            self.preprocessing.append(
1705                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1706            )
1707
1708        return self
id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
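
A minimal sketch of an input tensor description (ids and the test tensor file name are hypothetical placeholders; io checks are disabled via ValidationContext, assuming it is exported from the top-level bioimageio.spec package, because the referenced file does not exist):

    >>> from bioimageio.spec import ValidationContext
    >>> with ValidationContext(perform_io_checks=False):
    ...     input_descr = InputTensorDescr(
    ...         id=TensorId("raw"),
    ...         axes=[
    ...             BatchAxis(),
    ...             SpaceInputAxis(id=AxisId("y"), size=512),
    ...             SpaceInputAxis(id=AxisId("x"), size=512),
    ...         ],
    ...         test_tensor=FileDescr(source="test_input.npy"),  # placeholder .npy file
    ...         preprocessing=[ScaleRangeDescr(kwargs=ScaleRangeKwargs(
    ...             axes=(AxisId("y"), AxisId("x")),
    ...             max_percentile=99.8,
    ...             min_percentile=5.0,
    ...         ))],
    ...     )

After validation the preprocessing list is wrapped with ensure_dtype steps as described in the notes above.
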
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1711def convert_axes(
1712    axes: str,
1713    *,
1714    shape: Union[
1715        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1716    ],
1717    tensor_type: Literal["input", "output"],
1718    halo: Optional[Sequence[int]],
1719    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1720):
1721    ret: List[AnyAxis] = []
1722    for i, a in enumerate(axes):
1723        axis_type = _AXIS_TYPE_MAP.get(a, a)
1724        if axis_type == "batch":
1725            ret.append(BatchAxis())
1726            continue
1727
1728        scale = 1.0
1729        if isinstance(shape, _ParameterizedInputShape_v0_4):
1730            if shape.step[i] == 0:
1731                size = shape.min[i]
1732            else:
1733                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1734        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1735            ref_t = str(shape.reference_tensor)
1736            if ref_t.count(".") == 1:
1737                t_id, orig_a_id = ref_t.split(".")
1738            else:
1739                t_id = ref_t
1740                orig_a_id = a
1741
1742            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1743            if not (orig_scale := shape.scale[i]):
1744                # old way to insert a new axis dimension
1745                size = int(2 * shape.offset[i])
1746            else:
1747                scale = 1 / orig_scale
1748                if axis_type in ("channel", "index"):
1749                    # these axes no longer have a scale
1750                    offset_from_scale = orig_scale * size_refs.get(
1751                        _TensorName_v0_4(t_id), {}
1752                    ).get(orig_a_id, 0)
1753                else:
1754                    offset_from_scale = 0
1755                size = SizeReference(
1756                    tensor_id=TensorId(t_id),
1757                    axis_id=AxisId(a_id),
1758                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1759                )
1760        else:
1761            size = shape[i]
1762
1763        if axis_type == "time":
1764            if tensor_type == "input":
1765                ret.append(TimeInputAxis(size=size, scale=scale))
1766            else:
1767                assert not isinstance(size, ParameterizedSize)
1768                if halo is None:
1769                    ret.append(TimeOutputAxis(size=size, scale=scale))
1770                else:
1771                    assert not isinstance(size, int)
1772                    ret.append(
1773                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1774                    )
1775
1776        elif axis_type == "index":
1777            if tensor_type == "input":
1778                ret.append(IndexInputAxis(size=size))
1779            else:
1780                if isinstance(size, ParameterizedSize):
1781                    size = DataDependentSize(min=size.min)
1782
1783                ret.append(IndexOutputAxis(size=size))
1784        elif axis_type == "channel":
1785            assert not isinstance(size, ParameterizedSize)
1786            if isinstance(size, SizeReference):
1787                warnings.warn(
1788                    "Conversion of channel size from an implicit output shape may be"
1789                    + " wrong"
1790                )
1791                ret.append(
1792                    ChannelAxis(
1793                        channel_names=[
1794                            Identifier(f"channel{i}") for i in range(size.offset)
1795                        ]
1796                    )
1797                )
1798            else:
1799                ret.append(
1800                    ChannelAxis(
1801                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1802                    )
1803                )
1804        elif axis_type == "space":
1805            if tensor_type == "input":
1806                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1807            else:
1808                assert not isinstance(size, ParameterizedSize)
1809                if halo is None or halo[i] == 0:
1810                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1811                elif isinstance(size, int):
1812                    raise NotImplementedError(
1813                        f"output axis with halo and fixed size (here {size}) not allowed"
1814                    )
1815                else:
1816                    ret.append(
1817                        SpaceOutputAxisWithHalo(
1818                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1819                        )
1820                    )
1821
1822    return ret
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1982class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1983    id: TensorId = TensorId("output")
1984    """Output tensor id.
1985    No duplicates are allowed across all inputs and outputs."""
1986
1987    postprocessing: List[PostprocessingDescr] = Field(
1988        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
1989    )
1990    """Description of how this output should be postprocessed.
1991
1992    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1993          If not given this is added to cast to this tensor's `data.type`.
1994    """
1995
1996    @model_validator(mode="after")
1997    def _validate_postprocessing_kwargs(self) -> Self:
1998        axes_ids = [a.id for a in self.axes]
1999        for p in self.postprocessing:
2000            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2001            if kwargs_axes is None:
2002                continue
2003
2004            if not isinstance(kwargs_axes, collections.abc.Sequence):
2005                raise ValueError(
2006                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
2007                )
2008
2009            if any(a not in axes_ids for a in kwargs_axes):
2010                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2011
2012        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2013            dtype = self.data.type
2014        else:
2015            dtype = self.data[0].type
2016
2017        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2018        if not self.postprocessing or not isinstance(
2019            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2020        ):
2021            self.postprocessing.append(
2022                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2023            )
2024        return self
id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.
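
A matching sketch of an output tensor description whose spatial sizes reference the hypothetical input tensor 'raw' from the example above (file name is a placeholder; io checks disabled as before, assuming ValidationContext from bioimageio.spec):

    >>> from bioimageio.spec import ValidationContext
    >>> with ValidationContext(perform_io_checks=False):
    ...     output_descr = OutputTensorDescr(
    ...         id=TensorId("mask"),
    ...         axes=[
    ...             BatchAxis(),
    ...             SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
    ...             SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
    ...         ],
    ...         test_tensor=FileDescr(source="test_output.npy"),  # placeholder .npy file
    ...         postprocessing=[ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))],
    ...     )

Because the last entry is neither an ensure_dtype nor a binarize step, an ensure_dtype entry casting back to the described data type is appended automatically.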

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], tensor_origin: Literal['test_tensor']):
2074def validate_tensors(
2075    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2076    tensor_origin: Literal[
2077        "test_tensor"
2078    ],  # for more precise error messages, e.g. 'test_tensor'
2079):
2080    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2081
2082    def e_msg(d: TensorDescr):
2083        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2084
2085    for descr, array in tensors.values():
2086        try:
2087            axis_sizes = descr.get_axis_sizes_for_array(array)
2088        except ValueError as e:
2089            raise ValueError(f"{e_msg(descr)} {e}")
2090        else:
2091            all_tensor_axes[descr.id] = {
2092                a.id: (a, axis_sizes[a.id]) for a in descr.axes
2093            }
2094
2095    for descr, array in tensors.values():
2096        if descr.dtype in ("float32", "float64"):
2097            invalid_test_tensor_dtype = array.dtype.name not in (
2098                "float32",
2099                "float64",
2100                "uint8",
2101                "int8",
2102                "uint16",
2103                "int16",
2104                "uint32",
2105                "int32",
2106                "uint64",
2107                "int64",
2108            )
2109        else:
2110            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2111
2112        if invalid_test_tensor_dtype:
2113            raise ValueError(
2114                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2115                + f" match described dtype '{descr.dtype}'"
2116            )
2117
2118        if array.min() > -1e-4 and array.max() < 1e-4:
2119            raise ValueError(
2120                "Output values are too small for reliable testing."
2121                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2122            )
2123
2124        for a in descr.axes:
2125            actual_size = all_tensor_axes[descr.id][a.id][1]
2126            if a.size is None:
2127                continue
2128
2129            if isinstance(a.size, int):
2130                if actual_size != a.size:
2131                    raise ValueError(
2132                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2133                        + f"has incompatible size {actual_size}, expected {a.size}"
2134                    )
2135            elif isinstance(a.size, ParameterizedSize):
2136                _ = a.size.validate_size(actual_size)
2137            elif isinstance(a.size, DataDependentSize):
2138                _ = a.size.validate_size(actual_size)
2139            elif isinstance(a.size, SizeReference):
2140                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2141                if ref_tensor_axes is None:
2142                    raise ValueError(
2143                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2144                        + f" reference '{a.size.tensor_id}'"
2145                    )
2146
2147                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2148                if ref_axis is None or ref_size is None:
2149                    raise ValueError(
2150                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2151                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2152                    )
2153
2154                if a.unit != ref_axis.unit:
2155                    raise ValueError(
2156                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2157                        + " axis and reference axis to have the same `unit`, but"
2158                        + f" {a.unit}!={ref_axis.unit}"
2159                    )
2160
2161                if actual_size != (
2162                    expected_size := (
2163                        ref_size * ref_axis.scale / a.scale + a.size.offset
2164                    )
2165                ):
2166                    raise ValueError(
2167                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2168                        + f" {actual_size} invalid for referenced size {ref_size};"
2169                        + f" expected {expected_size}"
2170                    )
2171            else:
2172                assert_never(a.size)
FileDescr_dependencies = typing.Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2192class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2193    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2194    """Architecture source file"""
2195
2196    @model_serializer(mode="wrap", when_used="unless-none")
2197    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2198        return package_file_descr_serializer(self, nxt, info)

A file description

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>)]

Architecture source file

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2201class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2202    import_from: str
2203    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
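
A hypothetical sketch of the two ways to describe an architecture, assuming the callable and kwargs fields defined on _ArchitectureCallableDescr; module, callable and file names are placeholders:

    >>> arch_from_library = ArchitectureFromLibraryDescr(
    ...     import_from="my_model_package.networks",   # hypothetical importable module
    ...     callable=Identifier("UNet2d"),             # hypothetical callable defined there
    ...     kwargs=dict(in_channels=1, out_channels=1),
    ... )
    >>> from bioimageio.spec import ValidationContext
    >>> with ValidationContext(perform_io_checks=False):
    ...     arch_from_file = ArchitectureFromFileDescr(
    ...         source="model_def.py",                 # placeholder architecture source file
    ...         callable=Identifier("UNet2d"),
    ...         kwargs=dict(in_channels=1, out_channels=1),
    ...     )
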

class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2263class WeightsEntryDescrBase(FileDescr):
2264    type: ClassVar[WeightsFormat]
2265    weights_format_name: ClassVar[str]  # human readable
2266
2267    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2268    """Source of the weights file."""
2269
2270    authors: Optional[List[Author]] = None
2271    """Authors
2272    Either the person(s) that have trained this model resulting in the original weights file.
2273        (If this is the initial weights entry, i.e. it does not have a `parent`)
2274    Or the person(s) who have converted the weights to this weights format.
2275        (If this is a child weight, i.e. it has a `parent` field)
2276    """
2277
2278    parent: Annotated[
2279        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2280    ] = None
2281    """The source weights these weights were converted from.
2282    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2283    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2284    All weight entries except one (the initial set of weights resulting from training the model),
2285    need to have this field."""
2286
2287    comment: str = ""
2288    """A comment about this weights entry, for example how these weights were created."""
2289
2290    @model_validator(mode="after")
2291    def _validate(self) -> Self:
2292        if self.type == self.parent:
2293            raise ValueError("Weights entry can't be its own parent.")
2294
2295        return self
2296
2297    @model_serializer(mode="wrap", when_used="unless-none")
2298    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2299        return package_file_descr_serializer(self, nxt, info)

A file description

type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>)]

Source of the weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors: either the person(s) who trained this model, resulting in the original weights file (if this is the initial weights entry, i.e. it does not have a parent), or the person(s) who converted the weights to this weights format (if this is a child weight, i.e. it has a parent field).

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.

comment: str

A comment about this weights entry, for example how these weights were created.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2302class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2303    type = "keras_hdf5"
2304    weights_format_name: ClassVar[str] = "Keras HDF5"
2305    tensorflow_version: Version
2306    """TensorFlow version used to create these weights."""

A file description

type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'

TensorFlow version used to create these weights.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class OnnxWeightsDescr(WeightsEntryDescrBase):
2309class OnnxWeightsDescr(WeightsEntryDescrBase):
2310    type = "onnx"
2311    weights_format_name: ClassVar[str] = "ONNX"
2312    opset_version: Annotated[int, Ge(7)]
2313    """ONNX opset version"""

A file description

type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2316class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2317    type = "pytorch_state_dict"
2318    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2319    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2320    pytorch_version: Version
2321    """Version of the PyTorch library used.
2322    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
2323    """
2324    dependencies: Optional[FileDescr_dependencies] = None
2325    """Custom dependencies beyond pytorch described in a Conda environment file.
2326    Allows to specify custom dependencies, see conda docs:
2327    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2328    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2329
2330    The conda environment file should include pytorch and any version pinning has to be compatible with
2331    **pytorch_version**.
2332    """

A file description

type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'

Version of the PyTorch library used. If architecture.dependencies is specified, it has to include pytorch, and any version pinning has to be compatible.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f5fe20b6de0>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond pytorch, described in a conda environment file. This allows specifying custom dependencies; see the conda docs on exporting an environment file across platforms and on creating an environment file manually.

The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
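
A hypothetical sketch combining the above: a pytorch_state_dict weights entry with an architecture imported from a library (file and module names are placeholders; io checks disabled as before, assuming ValidationContext from bioimageio.spec):

    >>> from bioimageio.spec import ValidationContext
    >>> with ValidationContext(perform_io_checks=False):
    ...     pytorch_weights = PytorchStateDictWeightsDescr(
    ...         source="weights.pt",                   # placeholder weights file
    ...         architecture=ArchitectureFromLibraryDescr(
    ...             import_from="my_model_package.networks",
    ...             callable=Identifier("UNet2d"),
    ...         ),
    ...         pytorch_version=Version("2.1"),
    ...     )
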

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2335class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2336    type = "tensorflow_js"
2337    weights_format_name: ClassVar[str] = "Tensorflow.js"
2338    tensorflow_version: Version
2339    """Version of the TensorFlow library used."""
2340
2341    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2342    """The multi-file weights.
2343    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>)]

The multi-file weights. All required files/folders should be a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2346class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2347    type = "tensorflow_saved_model_bundle"
2348    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2349    tensorflow_version: Version
2350    """Version of the TensorFlow library used."""
2351
2352    dependencies: Optional[FileDescr_dependencies] = None
2353    """Custom dependencies beyond tensorflow.
2354    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2355
2356    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2357    """The multi-file weights.
2358    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'

Version of the TensorFlow library used.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>), WrapSerializer(func=<function package_file_descr_serializer at 0x7f5fe20b6de0>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f5fe201bd80>)]

The multi-file weights. All required files/folders should be a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2361class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2362    type = "torchscript"
2363    weights_format_name: ClassVar[str] = "TorchScript"
2364    pytorch_version: Version
2365    """Version of the PyTorch library used."""

A file description

type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'

Version of the PyTorch library used.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WeightsDescr(bioimageio.spec._internal.node.Node):
2368class WeightsDescr(Node):
2369    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2370    onnx: Optional[OnnxWeightsDescr] = None
2371    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2372    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2373    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2374        None
2375    )
2376    torchscript: Optional[TorchscriptWeightsDescr] = None
2377
2378    @model_validator(mode="after")
2379    def check_entries(self) -> Self:
2380        entries = {wtype for wtype, entry in self if entry is not None}
2381
2382        if not entries:
2383            raise ValueError("Missing weights entry")
2384
2385        entries_wo_parent = {
2386            wtype
2387            for wtype, entry in self
2388            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2389        }
2390        if len(entries_wo_parent) != 1:
2391            issue_warning(
2392                "Exactly one weights entry may not specify the `parent` field (got"
2393                + " {value}). That entry is considered the original set of model weights."
2394                + " Other weight formats are created through conversion of the original or"
2395                + " already converted weights. They have to reference the weights format"
2396                + " they were converted from as their `parent`.",
2397                value=len(entries_wo_parent),
2398                field="weights",
2399            )
2400
2401        for wtype, entry in self:
2402            if entry is None:
2403                continue
2404
2405            assert hasattr(entry, "type")
2406            assert hasattr(entry, "parent")
2407            assert wtype == entry.type
2408            if (
2409                entry.parent is not None and entry.parent not in entries
2410            ):  # self reference checked for `parent` field
2411                raise ValueError(
2412                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2413                    + f" formats: {entries}"
2414                )
2415
2416        return self
2417
2418    def __getitem__(
2419        self,
2420        key: Literal[
2421            "keras_hdf5",
2422            "onnx",
2423            "pytorch_state_dict",
2424            "tensorflow_js",
2425            "tensorflow_saved_model_bundle",
2426            "torchscript",
2427        ],
2428    ):
2429        if key == "keras_hdf5":
2430            ret = self.keras_hdf5
2431        elif key == "onnx":
2432            ret = self.onnx
2433        elif key == "pytorch_state_dict":
2434            ret = self.pytorch_state_dict
2435        elif key == "tensorflow_js":
2436            ret = self.tensorflow_js
2437        elif key == "tensorflow_saved_model_bundle":
2438            ret = self.tensorflow_saved_model_bundle
2439        elif key == "torchscript":
2440            ret = self.torchscript
2441        else:
2442            raise KeyError(key)
2443
2444        if ret is None:
2445            raise KeyError(key)
2446
2447        return ret
2448
2449    @property
2450    def available_formats(self):
2451        return {
2452            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2453            **({} if self.onnx is None else {"onnx": self.onnx}),
2454            **(
2455                {}
2456                if self.pytorch_state_dict is None
2457                else {"pytorch_state_dict": self.pytorch_state_dict}
2458            ),
2459            **(
2460                {}
2461                if self.tensorflow_js is None
2462                else {"tensorflow_js": self.tensorflow_js}
2463            ),
2464            **(
2465                {}
2466                if self.tensorflow_saved_model_bundle is None
2467                else {
2468                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2469                }
2470            ),
2471            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2472        }
2473
2474    @property
2475    def missing_formats(self):
2476        return {
2477            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2478        }
keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
2378    @model_validator(mode="after")
2379    def check_entries(self) -> Self:
2380        entries = {wtype for wtype, entry in self if entry is not None}
2381
2382        if not entries:
2383            raise ValueError("Missing weights entry")
2384
2385        entries_wo_parent = {
2386            wtype
2387            for wtype, entry in self
2388            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2389        }
2390        if len(entries_wo_parent) != 1:
2391            issue_warning(
2392                "Exactly one weights entry may not specify the `parent` field (got"
2393                + " {value}). That entry is considered the original set of model weights."
2394                + " Other weight formats are created through conversion of the original or"
2395                + " already converted weights. They have to reference the weights format"
2396                + " they were converted from as their `parent`.",
2397                value=len(entries_wo_parent),
2398                field="weights",
2399            )
2400
2401        for wtype, entry in self:
2402            if entry is None:
2403                continue
2404
2405            assert hasattr(entry, "type")
2406            assert hasattr(entry, "parent")
2407            assert wtype == entry.type
2408            if (
2409                entry.parent is not None and entry.parent not in entries
2410            ):  # self reference checked for `parent` field
2411                raise ValueError(
2412                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2413                    + f" formats: {entries}"
2414                )
2415
2416        return self
available_formats
2449    @property
2450    def available_formats(self):
2451        return {
2452            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2453            **({} if self.onnx is None else {"onnx": self.onnx}),
2454            **(
2455                {}
2456                if self.pytorch_state_dict is None
2457                else {"pytorch_state_dict": self.pytorch_state_dict}
2458            ),
2459            **(
2460                {}
2461                if self.tensorflow_js is None
2462                else {"tensorflow_js": self.tensorflow_js}
2463            ),
2464            **(
2465                {}
2466                if self.tensorflow_saved_model_bundle is None
2467                else {
2468                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2469                }
2470            ),
2471            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2472        }
missing_formats
2474    @property
2475    def missing_formats(self):
2476        return {
2477            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2478        }
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
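
A hypothetical sketch of a weights description with two formats, where the onnx entry declares the keras_hdf5 entry as its parent (file names are placeholders; io checks disabled as before, assuming ValidationContext from bioimageio.spec):

    >>> from bioimageio.spec import ValidationContext
    >>> with ValidationContext(perform_io_checks=False):
    ...     weights = WeightsDescr(
    ...         keras_hdf5=KerasHdf5WeightsDescr(
    ...             source="weights.h5",               # original weights: no parent
    ...             tensorflow_version=Version("2.15"),
    ...         ),
    ...         onnx=OnnxWeightsDescr(
    ...             source="weights.onnx",             # converted from the keras weights
    ...             opset_version=15,
    ...             parent="keras_hdf5",
    ...         ),
    ...     )
    >>> sorted(weights.available_formats)
    ['keras_hdf5', 'onnx']
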

class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2481class ModelId(ResourceId):
2482    pass


class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2485class LinkedModel(LinkedResourceBase):
2486    """Reference to a bioimage.io model."""
2487
2488    id: ModelId
2489    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
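
A brief sketch; 'affable-shark' stands in for any valid model id from the bioimage.io collection:

    >>> linked = LinkedModel(id=ModelId("affable-shark"))
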

class ReproducibilityTolerance(bioimageio.spec._internal.node.Node):
2511class ReproducibilityTolerance(Node, extra="allow"):
2512    """Describes what small numerical differences -- if any -- may be tolerated
2513    in the generated output when executing in different environments.
2514
2515    A tensor element *output* is considered mismatched to the **test_tensor** if
2516    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2517    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2518
2519    Motivation:
2520        For testing we can request the respective deep learning frameworks to be as
2521        reproducible as possible by setting seeds and choosing deterministic algorithms,
2522        but differences in operating systems, available hardware and installed drivers
2523        may still lead to numerical differences.
2524    """
2525
2526    relative_tolerance: RelativeTolerance = 1e-3
2527    """Maximum relative tolerance of reproduced test tensor."""
2528
2529    absolute_tolerance: AbsoluteTolerance = 1e-4
2530    """Maximum absolute tolerance of reproduced test tensor."""
2531
2532    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2533    """Maximum number of mismatched elements/pixels per million to tolerate."""
2534
2535    output_ids: Sequence[TensorId] = ()
2536    """Limits the output tensor IDs these reproducibility details apply to."""
2537
2538    weights_formats: Sequence[WeightsFormat] = ()
2539    """Limits the weights formats these details apply to."""

Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.

A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)

Motivation:

For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.

relative_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=0.01)]

Maximum relative tolerance of reproduced test tensor.

absolute_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=None)]

Maximum absolute tolerance of reproduced test tensor.

mismatched_elements_per_million: Annotated[int, Interval(gt=None, ge=0, lt=None, le=1000)]

Maximum number of mismatched elements/pixels per million to tolerate.

output_ids: Sequence[TensorId]

Limits the output tensor IDs these reproducibility details apply to.

weights_formats: Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]

Limits the weights formats these details apply to.

model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'allow', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
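
A minimal sketch of a tolerance entry restricted to the hypothetical onnx outputs of a tensor 'mask', nested into the bioimageio config described below:

    >>> tolerance = ReproducibilityTolerance(
    ...     relative_tolerance=1e-3,
    ...     absolute_tolerance=1e-4,
    ...     mismatched_elements_per_million=100,
    ...     output_ids=(TensorId("mask"),),            # placeholder output tensor id
    ...     weights_formats=("onnx",),
    ... )
    >>> bioimageio_config = BioimageioConfig(reproducibility_tolerance=[tolerance])
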

class BioimageioConfig(bioimageio.spec._internal.node.Node):
2542class BioimageioConfig(Node, extra="allow"):
2543    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2544    """Tolerances to allow when reproducing the model's test outputs
2545    from the model's test inputs.
2546    Only the first entry matching tensor id and weights format is considered.
2547    """
reproducibility_tolerance: Sequence[ReproducibilityTolerance]

Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.
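A minimal sketch of how these config nodes could be populated; the tensor id, tolerance values and weights format below are illustrative, not recommended defaults:

    from bioimageio.spec.model.v0_5 import (
        BioimageioConfig,
        Config,
        ReproducibilityTolerance,
        TensorId,
    )

    config = Config(
        bioimageio=BioimageioConfig(
            reproducibility_tolerance=[
                ReproducibilityTolerance(
                    relative_tolerance=1e-3,
                    absolute_tolerance=1e-4,
                    output_ids=[TensorId("prediction")],  # illustrative tensor id
                    weights_formats=["onnx"],  # only applies to the ONNX weights
                )
            ]
        )
    )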


class Config(bioimageio.spec._internal.node.Node):
2550class Config(Node, extra="allow"):
2551    bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
bioimageio: BioimageioConfig

2554class ModelDescr(GenericModelDescrBase):
2555    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2556    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2557    """
2558
2559    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2560    if TYPE_CHECKING:
2561        format_version: Literal["0.5.4"] = "0.5.4"
2562    else:
2563        format_version: Literal["0.5.4"]
2564        """Version of the bioimage.io model description specification used.
2565        When creating a new model always use the latest micro/patch version described here.
2566        The `format_version` is important for any consumer software to understand how to parse the fields.
2567        """
2568
2569    implemented_type: ClassVar[Literal["model"]] = "model"
2570    if TYPE_CHECKING:
2571        type: Literal["model"] = "model"
2572    else:
2573        type: Literal["model"]
2574        """Specialized resource type 'model'"""
2575
2576    id: Optional[ModelId] = None
2577    """bioimage.io-wide unique resource identifier
2578    assigned by bioimage.io; version **un**specific."""
2579
2580    authors: NotEmpty[List[Author]]
2581    """The authors are the creators of the model RDF and the primary points of contact."""
2582
2583    documentation: FileSource_documentation
2584    """URL or relative path to a markdown file with additional documentation.
2585    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2586    The documentation should include a '#[#] Validation' (sub)section
2587    with details on how to quantitatively validate the model on unseen data."""
2588
2589    @field_validator("documentation", mode="after")
2590    @classmethod
2591    def _validate_documentation(
2592        cls, value: FileSource_documentation
2593    ) -> FileSource_documentation:
2594        if not get_validation_context().perform_io_checks:
2595            return value
2596
2597        doc_reader = get_reader(value)
2598        doc_content = doc_reader.read().decode(encoding="utf-8")
2599        if not re.search("#.*[vV]alidation", doc_content):
2600            issue_warning(
2601                "No '# Validation' (sub)section found in {value}.",
2602                value=value,
2603                field="documentation",
2604            )
2605
2606        return value
2607
2608    inputs: NotEmpty[Sequence[InputTensorDescr]]
2609    """Describes the input tensors expected by this model."""
2610
2611    @field_validator("inputs", mode="after")
2612    @classmethod
2613    def _validate_input_axes(
2614        cls, inputs: Sequence[InputTensorDescr]
2615    ) -> Sequence[InputTensorDescr]:
2616        input_size_refs = cls._get_axes_with_independent_size(inputs)
2617
2618        for i, ipt in enumerate(inputs):
2619            valid_independent_refs: Dict[
2620                Tuple[TensorId, AxisId],
2621                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2622            ] = {
2623                **{
2624                    (ipt.id, a.id): (ipt, a, a.size)
2625                    for a in ipt.axes
2626                    if not isinstance(a, BatchAxis)
2627                    and isinstance(a.size, (int, ParameterizedSize))
2628                },
2629                **input_size_refs,
2630            }
2631            for a, ax in enumerate(ipt.axes):
2632                cls._validate_axis(
2633                    "inputs",
2634                    i=i,
2635                    tensor_id=ipt.id,
2636                    a=a,
2637                    axis=ax,
2638                    valid_independent_refs=valid_independent_refs,
2639                )
2640        return inputs
2641
2642    @staticmethod
2643    def _validate_axis(
2644        field_name: str,
2645        i: int,
2646        tensor_id: TensorId,
2647        a: int,
2648        axis: AnyAxis,
2649        valid_independent_refs: Dict[
2650            Tuple[TensorId, AxisId],
2651            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2652        ],
2653    ):
2654        if isinstance(axis, BatchAxis) or isinstance(
2655            axis.size, (int, ParameterizedSize, DataDependentSize)
2656        ):
2657            return
2658        elif not isinstance(axis.size, SizeReference):
2659            assert_never(axis.size)
2660
2661        # validate axis.size SizeReference
2662        ref = (axis.size.tensor_id, axis.size.axis_id)
2663        if ref not in valid_independent_refs:
2664            raise ValueError(
2665                "Invalid tensor axis reference at"
2666                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2667            )
2668        if ref == (tensor_id, axis.id):
2669            raise ValueError(
2670                "Self-referencing not allowed for"
2671                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2672            )
2673        if axis.type == "channel":
2674            if valid_independent_refs[ref][1].type != "channel":
2675                raise ValueError(
2676                    "A channel axis' size may only reference another fixed size"
2677                    + " channel axis."
2678                )
2679            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2680                ref_size = valid_independent_refs[ref][2]
2681                assert isinstance(ref_size, int), (
2682                    "channel axis ref (another channel axis) has to specify fixed"
2683                    + " size"
2684                )
2685                generated_channel_names = [
2686                    Identifier(axis.channel_names.format(i=i))
2687                    for i in range(1, ref_size + 1)
2688                ]
2689                axis.channel_names = generated_channel_names
2690
2691        if (ax_unit := getattr(axis, "unit", None)) != (
2692            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2693        ):
2694            raise ValueError(
2695                "The units of an axis and its reference axis need to match, but"
2696                + f" '{ax_unit}' != '{ref_unit}'."
2697            )
2698        ref_axis = valid_independent_refs[ref][1]
2699        if isinstance(ref_axis, BatchAxis):
2700            raise ValueError(
2701                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2702                + " (a batch axis is not allowed as reference)."
2703            )
2704
2705        if isinstance(axis, WithHalo):
2706            min_size = axis.size.get_size(axis, ref_axis, n=0)
2707            if (min_size - 2 * axis.halo) < 1:
2708                raise ValueError(
2709                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2710                    + f" {axis.halo}."
2711                )
2712
2713            input_halo = axis.halo * axis.scale / ref_axis.scale
2714            if input_halo != int(input_halo) or input_halo % 2 == 1:
2715                raise ValueError(
2716                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2717                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2718                    + f"     {tensor_id}.{axis.id}."
2719                )
2720
2721    @model_validator(mode="after")
2722    def _validate_test_tensors(self) -> Self:
2723        if not get_validation_context().perform_io_checks:
2724            return self
2725
2726        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2727        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2728
2729        tensors = {
2730            descr.id: (descr, array)
2731            for descr, array in zip(
2732                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2733            )
2734        }
2735        validate_tensors(tensors, tensor_origin="test_tensor")
2736
2737        output_arrays = {
2738            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2739        }
2740        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2741            if not rep_tol.absolute_tolerance:
2742                continue
2743
2744            if rep_tol.output_ids:
2745                out_arrays = {
2746                    oid: a
2747                    for oid, a in output_arrays.items()
2748                    if oid in rep_tol.output_ids
2749                }
2750            else:
2751                out_arrays = output_arrays
2752
2753            for out_id, array in out_arrays.items():
2754                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2755                    raise ValueError(
2756                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2757                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2758                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2759                    )
2760
2761        return self
2762
2763    @model_validator(mode="after")
2764    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2765        ipt_refs = {t.id for t in self.inputs}
2766        out_refs = {t.id for t in self.outputs}
2767        for ipt in self.inputs:
2768            for p in ipt.preprocessing:
2769                ref = p.kwargs.get("reference_tensor")
2770                if ref is None:
2771                    continue
2772                if ref not in ipt_refs:
2773                    raise ValueError(
2774                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2775                        + f" references are: {ipt_refs}."
2776                    )
2777
2778        for out in self.outputs:
2779            for p in out.postprocessing:
2780                ref = p.kwargs.get("reference_tensor")
2781                if ref is None:
2782                    continue
2783
2784                if ref not in ipt_refs and ref not in out_refs:
2785                    raise ValueError(
2786                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2787                        + f" are: {ipt_refs | out_refs}."
2788                    )
2789
2790        return self
2791
2792    # TODO: use validate funcs in validate_test_tensors
2793    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2794
2795    name: Annotated[
2796        Annotated[
2797            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2798        ],
2799        MinLen(5),
2800        MaxLen(128),
2801        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2802    ]
2803    """A human-readable name of this model.
2804    It should be no longer than 64 characters
2805    and may only contain letters, numbers, underscores, minus, parentheses and spaces.
2806    We recommend choosing a name that refers to the model's task and image modality.
2807    """
2808
2809    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2810    """Describes the output tensors."""
2811
2812    @field_validator("outputs", mode="after")
2813    @classmethod
2814    def _validate_tensor_ids(
2815        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2816    ) -> Sequence[OutputTensorDescr]:
2817        tensor_ids = [
2818            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2819        ]
2820        duplicate_tensor_ids: List[str] = []
2821        seen: Set[str] = set()
2822        for t in tensor_ids:
2823            if t in seen:
2824                duplicate_tensor_ids.append(t)
2825
2826            seen.add(t)
2827
2828        if duplicate_tensor_ids:
2829            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2830
2831        return outputs
2832
2833    @staticmethod
2834    def _get_axes_with_parameterized_size(
2835        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2836    ):
2837        return {
2838            f"{t.id}.{a.id}": (t, a, a.size)
2839            for t in io
2840            for a in t.axes
2841            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2842        }
2843
2844    @staticmethod
2845    def _get_axes_with_independent_size(
2846        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2847    ):
2848        return {
2849            (t.id, a.id): (t, a, a.size)
2850            for t in io
2851            for a in t.axes
2852            if not isinstance(a, BatchAxis)
2853            and isinstance(a.size, (int, ParameterizedSize))
2854        }
2855
2856    @field_validator("outputs", mode="after")
2857    @classmethod
2858    def _validate_output_axes(
2859        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2860    ) -> List[OutputTensorDescr]:
2861        input_size_refs = cls._get_axes_with_independent_size(
2862            info.data.get("inputs", [])
2863        )
2864        output_size_refs = cls._get_axes_with_independent_size(outputs)
2865
2866        for i, out in enumerate(outputs):
2867            valid_independent_refs: Dict[
2868                Tuple[TensorId, AxisId],
2869                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2870            ] = {
2871                **{
2872                    (out.id, a.id): (out, a, a.size)
2873                    for a in out.axes
2874                    if not isinstance(a, BatchAxis)
2875                    and isinstance(a.size, (int, ParameterizedSize))
2876                },
2877                **input_size_refs,
2878                **output_size_refs,
2879            }
2880            for a, ax in enumerate(out.axes):
2881                cls._validate_axis(
2882                    "outputs",
2883                    i,
2884                    out.id,
2885                    a,
2886                    ax,
2887                    valid_independent_refs=valid_independent_refs,
2888                )
2889
2890        return outputs
2891
2892    packaged_by: List[Author] = Field(
2893        default_factory=cast(Callable[[], List[Author]], list)
2894    )
2895    """The persons that have packaged and uploaded this model.
2896    Only required if those persons differ from the `authors`."""
2897
2898    parent: Optional[LinkedModel] = None
2899    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2900
2901    @model_validator(mode="after")
2902    def _validate_parent_is_not_self(self) -> Self:
2903        if self.parent is not None and self.parent.id == self.id:
2904            raise ValueError("A model description may not reference itself as parent.")
2905
2906        return self
2907
2908    run_mode: Annotated[
2909        Optional[RunMode],
2910        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2911    ] = None
2912    """Custom run mode for this model: for more complex prediction procedures like test time
2913    data augmentation that currently cannot be expressed in the specification.
2914    No standard run modes are defined yet."""
2915
2916    timestamp: Datetime = Field(default_factory=Datetime.now)
2917    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2918    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2919    (In Python a datetime object is valid, too)."""
2920
2921    training_data: Annotated[
2922        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2923        Field(union_mode="left_to_right"),
2924    ] = None
2925    """The dataset used to train this model"""
2926
2927    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2928    """The weights for this model.
2929    Weights can be given for different formats, but should otherwise be equivalent.
2930    The available weight formats determine which consumers can use this model."""
2931
2932    config: Config = Field(default_factory=Config)
2933
2934    @model_validator(mode="after")
2935    def _add_default_cover(self) -> Self:
2936        if not get_validation_context().perform_io_checks or self.covers:
2937            return self
2938
2939        try:
2940            generated_covers = generate_covers(
2941                [(t, load_array(t.test_tensor)) for t in self.inputs],
2942                [(t, load_array(t.test_tensor)) for t in self.outputs],
2943            )
2944        except Exception as e:
2945            issue_warning(
2946                "Failed to generate cover image(s): {e}",
2947                value=self.covers,
2948                msg_context=dict(e=e),
2949                field="covers",
2950            )
2951        else:
2952            self.covers.extend(generated_covers)
2953
2954        return self
2955
2956    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2957        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2958        assert all(isinstance(d, np.ndarray) for d in data)
2959        return data
2960
2961    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2962        data = [load_array(out.test_tensor) for out in self.outputs]
2963        assert all(isinstance(d, np.ndarray) for d in data)
2964        return data
2965
2966    @staticmethod
2967    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2968        batch_size = 1
2969        tensor_with_batchsize: Optional[TensorId] = None
2970        for tid in tensor_sizes:
2971            for aid, s in tensor_sizes[tid].items():
2972                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2973                    continue
2974
2975                if batch_size != 1:
2976                    assert tensor_with_batchsize is not None
2977                    raise ValueError(
2978                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2979                    )
2980
2981                batch_size = s
2982                tensor_with_batchsize = tid
2983
2984        return batch_size
2985
2986    def get_output_tensor_sizes(
2987        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2988    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2989        """Returns the tensor output sizes for given **input_sizes**.
2990        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2991        Otherwise it might be larger than the actual (valid) output"""
2992        batch_size = self.get_batch_size(input_sizes)
2993        ns = self.get_ns(input_sizes)
2994
2995        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2996        return tensor_sizes.outputs
2997
2998    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2999        """get parameter `n` for each parameterized axis
3000        such that the valid input size is >= the given input size"""
3001        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3002        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3003        for tid in input_sizes:
3004            for aid, s in input_sizes[tid].items():
3005                size_descr = axes[tid][aid].size
3006                if isinstance(size_descr, ParameterizedSize):
3007                    ret[(tid, aid)] = size_descr.get_n(s)
3008                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3009                    pass
3010                else:
3011                    assert_never(size_descr)
3012
3013        return ret
3014
3015    def get_tensor_sizes(
3016        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3017    ) -> _TensorSizes:
3018        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3019        return _TensorSizes(
3020            {
3021                t: {
3022                    aa: axis_sizes.inputs[(tt, aa)]
3023                    for tt, aa in axis_sizes.inputs
3024                    if tt == t
3025                }
3026                for t in {tt for tt, _ in axis_sizes.inputs}
3027            },
3028            {
3029                t: {
3030                    aa: axis_sizes.outputs[(tt, aa)]
3031                    for tt, aa in axis_sizes.outputs
3032                    if tt == t
3033                }
3034                for t in {tt for tt, _ in axis_sizes.outputs}
3035            },
3036        )
3037
3038    def get_axis_sizes(
3039        self,
3040        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3041        batch_size: Optional[int] = None,
3042        *,
3043        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3044    ) -> _AxisSizes:
3045        """Determine input and output block shape for scale factors **ns**
3046        of parameterized input sizes.
3047
3048        Args:
3049            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3050                that is parameterized as `size = min + n * step`.
3051            batch_size: The desired size of the batch dimension.
3052                If given **batch_size** overwrites any batch size present in
3053                **max_input_shape**. Default 1.
3054            max_input_shape: Limits the derived block shapes.
3055                Each axis for which the input size, parameterized by `n`, is larger
3056                than **max_input_shape** is set to the minimal value `n_min` for which
3057                this is still true.
3058                Use this for small input samples or large values of **ns**.
3059                Or simply whenever you know the full input shape.
3060
3061        Returns:
3062            Resolved axis sizes for model inputs and outputs.
3063        """
3064        max_input_shape = max_input_shape or {}
3065        if batch_size is None:
3066            for (_t_id, a_id), s in max_input_shape.items():
3067                if a_id == BATCH_AXIS_ID:
3068                    batch_size = s
3069                    break
3070            else:
3071                batch_size = 1
3072
3073        all_axes = {
3074            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3075        }
3076
3077        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3078        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3079
3080        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3081            if isinstance(a, BatchAxis):
3082                if (t_descr.id, a.id) in ns:
3083                    logger.warning(
3084                        "Ignoring unexpected size increment factor (n) for batch axis"
3085                        + " of tensor '{}'.",
3086                        t_descr.id,
3087                    )
3088                return batch_size
3089            elif isinstance(a.size, int):
3090                if (t_descr.id, a.id) in ns:
3091                    logger.warning(
3092                        "Ignoring unexpected size increment factor (n) for fixed size"
3093                        + " axis '{}' of tensor '{}'.",
3094                        a.id,
3095                        t_descr.id,
3096                    )
3097                return a.size
3098            elif isinstance(a.size, ParameterizedSize):
3099                if (t_descr.id, a.id) not in ns:
3100                    raise ValueError(
3101                        "Size increment factor (n) missing for parametrized axis"
3102                        + f" '{a.id}' of tensor '{t_descr.id}'."
3103                    )
3104                n = ns[(t_descr.id, a.id)]
3105                s_max = max_input_shape.get((t_descr.id, a.id))
3106                if s_max is not None:
3107                    n = min(n, a.size.get_n(s_max))
3108
3109                return a.size.get_size(n)
3110
3111            elif isinstance(a.size, SizeReference):
3112                if (t_descr.id, a.id) in ns:
3113                    logger.warning(
3114                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3115                        + " of tensor '{}' with size reference.",
3116                        a.id,
3117                        t_descr.id,
3118                    )
3119                assert not isinstance(a, BatchAxis)
3120                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3121                assert not isinstance(ref_axis, BatchAxis)
3122                ref_key = (a.size.tensor_id, a.size.axis_id)
3123                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3124                assert ref_size is not None, ref_key
3125                assert not isinstance(ref_size, _DataDepSize), ref_key
3126                return a.size.get_size(
3127                    axis=a,
3128                    ref_axis=ref_axis,
3129                    ref_size=ref_size,
3130                )
3131            elif isinstance(a.size, DataDependentSize):
3132                if (t_descr.id, a.id) in ns:
3133                    logger.warning(
3134                        "Ignoring unexpected increment factor (n) for data dependent"
3135                        + " size axis '{}' of tensor '{}'.",
3136                        a.id,
3137                        t_descr.id,
3138                    )
3139                return _DataDepSize(a.size.min, a.size.max)
3140            else:
3141                assert_never(a.size)
3142
3143        # first resolve all but the `SizeReference` input sizes
3144        for t_descr in self.inputs:
3145            for a in t_descr.axes:
3146                if not isinstance(a.size, SizeReference):
3147                    s = get_axis_size(a)
3148                    assert not isinstance(s, _DataDepSize)
3149                    inputs[t_descr.id, a.id] = s
3150
3151        # resolve all other input axis sizes
3152        for t_descr in self.inputs:
3153            for a in t_descr.axes:
3154                if isinstance(a.size, SizeReference):
3155                    s = get_axis_size(a)
3156                    assert not isinstance(s, _DataDepSize)
3157                    inputs[t_descr.id, a.id] = s
3158
3159        # resolve all output axis sizes
3160        for t_descr in self.outputs:
3161            for a in t_descr.axes:
3162                assert not isinstance(a.size, ParameterizedSize)
3163                s = get_axis_size(a)
3164                outputs[t_descr.id, a.id] = s
3165
3166        return _AxisSizes(inputs=inputs, outputs=outputs)
3167
3168    @model_validator(mode="before")
3169    @classmethod
3170    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3171        cls.convert_from_old_format_wo_validation(data)
3172        return data
3173
3174    @classmethod
3175    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3176        """Convert metadata following an older format version to this class's format
3177        without validating the result.
3178        """
3179        if (
3180            data.get("type") == "model"
3181            and isinstance(fv := data.get("format_version"), str)
3182            and fv.count(".") == 2
3183        ):
3184            fv_parts = fv.split(".")
3185            if any(not p.isdigit() for p in fv_parts):
3186                return
3187
3188            fv_tuple = tuple(map(int, fv_parts))
3189
3190            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3191            if fv_tuple[:2] in ((0, 3), (0, 4)):
3192                m04 = _ModelDescr_v0_4.load(data)
3193                if isinstance(m04, InvalidDescr):
3194                    try:
3195                        updated = _model_conv.convert_as_dict(
3196                            m04  # pyright: ignore[reportArgumentType]
3197                        )
3198                    except Exception as e:
3199                        logger.error(
3200                            "Failed to convert from invalid model 0.4 description."
3201                            + f"\nerror: {e}"
3202                            + "\nProceeding with model 0.5 validation without conversion."
3203                        )
3204                        updated = None
3205                else:
3206                    updated = _model_conv.convert_as_dict(m04)
3207
3208                if updated is not None:
3209                    data.clear()
3210                    data.update(updated)
3211
3212            elif fv_tuple[:2] == (0, 5):
3213                # bump patch version
3214                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
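A minimal sketch of obtaining a validated ModelDescr via the package's top-level loader; the RDF path is hypothetical (load_description also accepts URLs and bioimage.io resource ids):

    from bioimageio.spec import load_description
    from bioimageio.spec.model.v0_5 import ModelDescr

    descr = load_description("path/to/rdf.yaml")  # hypothetical local model RDF
    if isinstance(descr, ModelDescr):
        print(descr.name, descr.format_version)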

implemented_format_version: ClassVar[Literal['0.5.4']] = '0.5.4'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>), PlainSerializer(func=<function _package_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author]

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in ISO 8601 format (https://en.wikipedia.org/wiki/ISO_8601) with a few restrictions listed at https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat. (In Python a datetime object is valid, too.)

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model.

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: Config
def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
2956    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2957        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2958        assert all(isinstance(d, np.ndarray) for d in data)
2959        return data
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
2961    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2962        data = [load_array(out.test_tensor) for out in self.outputs]
2963        assert all(isinstance(d, np.ndarray) for d in data)
2964        return data
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2966    @staticmethod
2967    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2968        batch_size = 1
2969        tensor_with_batchsize: Optional[TensorId] = None
2970        for tid in tensor_sizes:
2971            for aid, s in tensor_sizes[tid].items():
2972                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2973                    continue
2974
2975                if batch_size != 1:
2976                    assert tensor_with_batchsize is not None
2977                    raise ValueError(
2978                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2979                    )
2980
2981                batch_size = s
2982                tensor_with_batchsize = tid
2983
2984        return batch_size
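
As a usage sketch (tensor and axis ids are illustrative, and the batch axis id is assumed to be 'batch'), the static method can be called on the class with per-axis sizes; non-batch axes and batch sizes of 1 are ignored:

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    tensor_sizes = {
        TensorId("raw"): {AxisId("batch"): 4, AxisId("y"): 256, AxisId("x"): 256},
        TensorId("mask"): {AxisId("batch"): 4, AxisId("y"): 256, AxisId("x"): 256},
    }
    assert ModelDescr.get_batch_size(tensor_sizes) == 4
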
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
2986    def get_output_tensor_sizes(
2987        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2988    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2989        """Returns the tensor output sizes for given **input_sizes**.
2990        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2991        Otherwise it might be larger than the actual (valid) output"""
2992        batch_size = self.get_batch_size(input_sizes)
2993        ns = self.get_ns(input_sizes)
2994
2995        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2996        return tensor_sizes.outputs

Returns the tensor output sizes for the given input_sizes. The returned sizes are exact only if input_sizes describes a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
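A usage sketch, assuming the model RDF describes an input tensor 'raw' with parameterized spatial axes (the path and ids are illustrative):

    from bioimageio.spec import load_description
    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    model = load_description("path/to/rdf.yaml")  # hypothetical model RDF
    assert isinstance(model, ModelDescr)

    input_sizes = {
        TensorId("raw"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}
    }
    output_sizes = model.get_output_tensor_sizes(input_sizes)
    for tensor_id, axis_sizes in output_sizes.items():
        print(tensor_id, axis_sizes)  # ints, or _DataDepSize for data-dependent axes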

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2998    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2999        """get parameter `n` for each parameterized axis
3000        such that the valid input size is >= the given input size"""
3001        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3002        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3003        for tid in input_sizes:
3004            for aid, s in input_sizes[tid].items():
3005                size_descr = axes[tid][aid].size
3006                if isinstance(size_descr, ParameterizedSize):
3007                    ret[(tid, aid)] = size_descr.get_n(s)
3008                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3009                    pass
3010                else:
3011                    assert_never(size_descr)
3012
3013        return ret

Get the parameter n for each parameterized axis such that the valid input size is >= the given input size.
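The relationship follows size = min + n*step. A small sketch with illustrative numbers:

    from bioimageio.spec.model.v0_5 import ParameterizedSize

    size_descr = ParameterizedSize(min=64, step=16)
    n = size_descr.get_n(100)  # smallest n with 64 + n*16 >= 100, i.e. n == 3
    print(n, size_descr.get_size(n))  # 3 112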

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
3015    def get_tensor_sizes(
3016        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3017    ) -> _TensorSizes:
3018        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3019        return _TensorSizes(
3020            {
3021                t: {
3022                    aa: axis_sizes.inputs[(tt, aa)]
3023                    for tt, aa in axis_sizes.inputs
3024                    if tt == t
3025                }
3026                for t in {tt for tt, _ in axis_sizes.inputs}
3027            },
3028            {
3029                t: {
3030                    aa: axis_sizes.outputs[(tt, aa)]
3031                    for tt, aa in axis_sizes.outputs
3032                    if tt == t
3033                }
3034                for t in {tt for tt, _ in axis_sizes.outputs}
3035            },
3036        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
3038    def get_axis_sizes(
3039        self,
3040        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3041        batch_size: Optional[int] = None,
3042        *,
3043        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3044    ) -> _AxisSizes:
3045        """Determine input and output block shape for scale factors **ns**
3046        of parameterized input sizes.
3047
3048        Args:
3049            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3050                that is parameterized as `size = min + n * step`.
3051            batch_size: The desired size of the batch dimension.
3052                If given **batch_size** overwrites any batch size present in
3053                **max_input_shape**. Default 1.
3054            max_input_shape: Limits the derived block shapes.
3055                Each axis for which the input size, parameterized by `n`, is larger
3056                than **max_input_shape** is set to the minimal value `n_min` for which
3057                this is still true.
3058                Use this for small input samples or large values of **ns**.
3059                Or simply whenever you know the full input shape.
3060
3061        Returns:
3062            Resolved axis sizes for model inputs and outputs.
3063        """
3064        max_input_shape = max_input_shape or {}
3065        if batch_size is None:
3066            for (_t_id, a_id), s in max_input_shape.items():
3067                if a_id == BATCH_AXIS_ID:
3068                    batch_size = s
3069                    break
3070            else:
3071                batch_size = 1
3072
3073        all_axes = {
3074            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3075        }
3076
3077        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3078        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3079
3080        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3081            if isinstance(a, BatchAxis):
3082                if (t_descr.id, a.id) in ns:
3083                    logger.warning(
3084                        "Ignoring unexpected size increment factor (n) for batch axis"
3085                        + " of tensor '{}'.",
3086                        t_descr.id,
3087                    )
3088                return batch_size
3089            elif isinstance(a.size, int):
3090                if (t_descr.id, a.id) in ns:
3091                    logger.warning(
3092                        "Ignoring unexpected size increment factor (n) for fixed size"
3093                        + " axis '{}' of tensor '{}'.",
3094                        a.id,
3095                        t_descr.id,
3096                    )
3097                return a.size
3098            elif isinstance(a.size, ParameterizedSize):
3099                if (t_descr.id, a.id) not in ns:
3100                    raise ValueError(
3101                        "Size increment factor (n) missing for parametrized axis"
3102                        + f" '{a.id}' of tensor '{t_descr.id}'."
3103                    )
3104                n = ns[(t_descr.id, a.id)]
3105                s_max = max_input_shape.get((t_descr.id, a.id))
3106                if s_max is not None:
3107                    n = min(n, a.size.get_n(s_max))
3108
3109                return a.size.get_size(n)
3110
3111            elif isinstance(a.size, SizeReference):
3112                if (t_descr.id, a.id) in ns:
3113                    logger.warning(
3114                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3115                        + " of tensor '{}' with size reference.",
3116                        a.id,
3117                        t_descr.id,
3118                    )
3119                assert not isinstance(a, BatchAxis)
3120                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3121                assert not isinstance(ref_axis, BatchAxis)
3122                ref_key = (a.size.tensor_id, a.size.axis_id)
3123                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3124                assert ref_size is not None, ref_key
3125                assert not isinstance(ref_size, _DataDepSize), ref_key
3126                return a.size.get_size(
3127                    axis=a,
3128                    ref_axis=ref_axis,
3129                    ref_size=ref_size,
3130                )
3131            elif isinstance(a.size, DataDependentSize):
3132                if (t_descr.id, a.id) in ns:
3133                    logger.warning(
3134                        "Ignoring unexpected increment factor (n) for data dependent"
3135                        + " size axis '{}' of tensor '{}'.",
3136                        a.id,
3137                        t_descr.id,
3138                    )
3139                return _DataDepSize(a.size.min, a.size.max)
3140            else:
3141                assert_never(a.size)
3142
3143        # first resolve all but the `SizeReference` input sizes
3144        for t_descr in self.inputs:
3145            for a in t_descr.axes:
3146                if not isinstance(a.size, SizeReference):
3147                    s = get_axis_size(a)
3148                    assert not isinstance(s, _DataDepSize)
3149                    inputs[t_descr.id, a.id] = s
3150
3151        # resolve all other input axis sizes
3152        for t_descr in self.inputs:
3153            for a in t_descr.axes:
3154                if isinstance(a.size, SizeReference):
3155                    s = get_axis_size(a)
3156                    assert not isinstance(s, _DataDepSize)
3157                    inputs[t_descr.id, a.id] = s
3158
3159        # resolve all output axis sizes
3160        for t_descr in self.outputs:
3161            for a in t_descr.axes:
3162                assert not isinstance(a.size, ParameterizedSize)
3163                s = get_axis_size(a)
3164                outputs[t_descr.id, a.id] = s
3165
3166        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
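A usage sketch, assuming `model` is a validated ModelDescr (as in the earlier loading sketch) whose input 'raw' has parameterized x and y axes; ids and numbers are illustrative:

    from bioimageio.spec.model.v0_5 import AxisId, TensorId

    ns = {
        (TensorId("raw"), AxisId("y")): 2,
        (TensorId("raw"), AxisId("x")): 2,
    }
    axis_sizes = model.get_axis_sizes(
        ns,
        batch_size=1,
        max_input_shape={(TensorId("raw"), AxisId("x")): 512},  # caps n for the x axis
    )
    print(axis_sizes.inputs)   # {(tensor_id, axis_id): int, ...}
    print(axis_sizes.outputs)  # may contain _DataDepSize entries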

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3174    @classmethod
3175    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3176        """Convert metadata following an older format version to this class's format
3177        without validating the result.
3178        """
3179        if (
3180            data.get("type") == "model"
3181            and isinstance(fv := data.get("format_version"), str)
3182            and fv.count(".") == 2
3183        ):
3184            fv_parts = fv.split(".")
3185            if any(not p.isdigit() for p in fv_parts):
3186                return
3187
3188            fv_tuple = tuple(map(int, fv_parts))
3189
3190            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3191            if fv_tuple[:2] in ((0, 3), (0, 4)):
3192                m04 = _ModelDescr_v0_4.load(data)
3193                if isinstance(m04, InvalidDescr):
3194                    try:
3195                        updated = _model_conv.convert_as_dict(
3196                            m04  # pyright: ignore[reportArgumentType]
3197                        )
3198                    except Exception as e:
3199                        logger.error(
3200                            "Failed to convert from invalid model 0.4 description."
3201                            + f"\nerror: {e}"
3202                            + "\nProceeding with model 0.5 validation without conversion."
3203                        )
3204                        updated = None
3205                else:
3206                    updated = _model_conv.convert_as_dict(m04)
3207
3208                if updated is not None:
3209                    data.clear()
3210                    data.update(updated)
3211
3212            elif fv_tuple[:2] == (0, 5):
3213                # bump patch version
3214                data["format_version"] = cls.implemented_format_version

Convert metadata following an older format version to this class's format without validating the result.
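A usage sketch: the classmethod mutates the given dictionary in place and returns None. The file name is hypothetical and PyYAML is assumed for loading the raw YAML content:

    import yaml  # assumes PyYAML is installed

    from bioimageio.spec.model.v0_5 import ModelDescr

    with open("old_model_rdf.yaml", encoding="utf-8") as f:  # hypothetical 0.4.x model RDF
        data = yaml.safe_load(f)

    ModelDescr.convert_from_old_format_wo_validation(data)
    print(data.get("format_version"))  # "0.5.4" if the conversion succeeded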

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 4)

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
337def init_private_attributes(self: BaseModel, context: Any, /) -> None:
338    """This function is meant to behave like a BaseModel method to initialise private attributes.
339
340    It takes context as an argument since that's what pydantic-core passes when calling it.
341
342    Args:
343        self: The BaseModel instance.
344        context: The context.
345    """
346    if getattr(self, '__pydantic_private__', None) is None:
347        pydantic_private = {}
348        for name, private_attr in self.__private_attributes__.items():
349            default = private_attr.get_default()
350            if default is not PydanticUndefined:
351                pydantic_private[name] = default
352        object_setattr(self, '__pydantic_private__', pydantic_private)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3439def generate_covers(
3440    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3441    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3442) -> List[Path]:
3443    def squeeze(
3444        data: NDArray[Any], axes: Sequence[AnyAxis]
3445    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3446        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3447        if data.ndim != len(axes):
3448            raise ValueError(
3449                f"tensor shape {data.shape} does not match described axes"
3450                + f" {[a.id for a in axes]}"
3451            )
3452
3453        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3454        return data.squeeze(), axes
3455
3456    def normalize(
3457        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3458    ) -> NDArray[np.float32]:
3459        data = data.astype("float32")
3460        data -= data.min(axis=axis, keepdims=True)
3461        data /= data.max(axis=axis, keepdims=True) + eps
3462        return data
3463
3464    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3465        original_shape = data.shape
3466        data, axes = squeeze(data, axes)
3467
3468        # take slice from any batch or index axis if needed
3469        # and convert the first channel axis and take a slice from any additional channel axes
3470        slices: Tuple[slice, ...] = ()
3471        ndim = data.ndim
3472        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3473        has_c_axis = False
3474        for i, a in enumerate(axes):
3475            s = data.shape[i]
3476            assert s > 1
3477            if (
3478                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3479                and ndim > ndim_need
3480            ):
3481                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3482                ndim -= 1
3483            elif isinstance(a, ChannelAxis):
3484                if has_c_axis:
3485                    # second channel axis
3486                    data = data[slices + (slice(0, 1),)]
3487                    ndim -= 1
3488                else:
3489                    has_c_axis = True
3490                    if s == 2:
3491                        # visualize two channels with cyan and magenta
3492                        data = np.concatenate(
3493                            [
3494                                data[slices + (slice(1, 2),)],
3495                                data[slices + (slice(0, 1),)],
3496                                (
3497                                    data[slices + (slice(0, 1),)]
3498                                    + data[slices + (slice(1, 2),)]
3499                                )
3500                                / 2,  # TODO: take maximum instead?
3501                            ],
3502                            axis=i,
3503                        )
3504                    elif data.shape[i] == 3:
3505                        pass  # visualize 3 channels as RGB
3506                    else:
3507                        # visualize first 3 channels as RGB
3508                        data = data[slices + (slice(3),)]
3509
3510                    assert data.shape[i] == 3
3511
3512            slices += (slice(None),)
3513
3514        data, axes = squeeze(data, axes)
3515        assert len(axes) == ndim
3516        # take slice from z axis if needed
3517        slices = ()
3518        if ndim > ndim_need:
3519            for i, a in enumerate(axes):
3520                s = data.shape[i]
3521                if a.id == AxisId("z"):
3522                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3523                    data, axes = squeeze(data, axes)
3524                    ndim -= 1
3525                    break
3526
3527            slices += (slice(None),)
3528
3529        # take slice from any space or time axis
3530        slices = ()
3531
3532        for i, a in enumerate(axes):
3533            if ndim <= ndim_need:
3534                break
3535
3536            s = data.shape[i]
3537            assert s > 1
3538            if isinstance(
3539                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3540            ):
3541                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3542                ndim -= 1
3543
3544            slices += (slice(None),)
3545
3546        del slices
3547        data, axes = squeeze(data, axes)
3548        assert len(axes) == ndim
3549
3550        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3551            raise ValueError(
3552                f"Failed to construct cover image from shape {original_shape}"
3553            )
3554
3555        if not has_c_axis:
3556            assert ndim == 2
3557            data = np.repeat(data[:, :, None], 3, axis=2)
3558            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3559            ndim += 1
3560
3561        assert ndim == 3
3562
3563        # transpose axis order such that longest axis comes first...
3564        axis_order: List[int] = list(np.argsort(list(data.shape)))
3565        axis_order.reverse()
3566        # ... and channel axis is last
3567        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3568        axis_order.append(axis_order.pop(c))
3569        axes = [axes[ao] for ao in axis_order]
3570        data = data.transpose(axis_order)
3571
3572        # h, w = data.shape[:2]
3573        # if h / w  in (1.0 or 2.0):
3574        #     pass
3575        # elif h / w < 2:
3576        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3577
3578        norm_along = (
3579            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3580        )
3581        # normalize the data and map to 8 bit
3582        data = normalize(data, norm_along)
3583        data = (data * 255).astype("uint8")
3584
3585        return data
3586
3587    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3588        assert im0.dtype == im1.dtype == np.uint8
3589        assert im0.shape == im1.shape
3590        assert im0.ndim == 3
3591        N, M, C = im0.shape
3592        assert C == 3
3593        out = np.ones((N, M, C), dtype="uint8")
3594        for c in range(C):
3595            outc = np.tril(im0[..., c])
3596            mask = outc == 0
3597            outc[mask] = np.triu(im1[..., c])[mask]
3598            out[..., c] = outc
3599
3600        return out
3601
3602    ipt_descr, ipt = inputs[0]
3603    out_descr, out = outputs[0]
3604
3605    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3606    out_img = to_2d_image(out, out_descr.axes)
3607
3608    cover_folder = Path(mkdtemp())
3609    if ipt_img.shape == out_img.shape:
3610        covers = [cover_folder / "cover.png"]
3611        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3612    else:
3613        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3614        imwrite(covers[0], ipt_img)
3615        imwrite(covers[1], out_img)
3616
3617    return covers
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):