bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from datetime import datetime
  10from itertools import chain
  11from math import ceil
  12from pathlib import Path, PurePosixPath
  13from tempfile import mkdtemp
  14from typing import (
  15    TYPE_CHECKING,
  16    Any,
  17    ClassVar,
  18    Dict,
  19    FrozenSet,
  20    Generic,
  21    List,
  22    Literal,
  23    Mapping,
  24    NamedTuple,
  25    Optional,
  26    Sequence,
  27    Set,
  28    Tuple,
  29    Type,
  30    TypeVar,
  31    Union,
  32    cast,
  33)
  34
  35import numpy as np
  36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  37from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  38from loguru import logger
  39from numpy.typing import NDArray
  40from pydantic import (
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    Tag,
  45    ValidationInfo,
  46    WrapSerializer,
  47    field_validator,
  48    model_validator,
  49)
  50from typing_extensions import Annotated, LiteralString, Self, assert_never, get_args
  51
  52from .._internal.common_nodes import (
  53    InvalidDescr,
  54    Node,
  55    NodeWithExplicitlySetFields,
  56)
  57from .._internal.constants import DTYPE_LIMITS
  58from .._internal.field_warning import issue_warning, warn
  59from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  60from .._internal.io import FileDescr as FileDescr
  61from .._internal.io import WithSuffix, YamlValue, download
  62from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
  63from .._internal.io_basics import Sha256 as Sha256
  64from .._internal.io_utils import load_array
  65from .._internal.node_converter import Converter
  66from .._internal.types import Datetime as Datetime
  67from .._internal.types import Identifier as Identifier
  68from .._internal.types import (
  69    ImportantFileSource,
  70    LowerCaseIdentifier,
  71    LowerCaseIdentifierAnno,
  72    SiUnit,
  73)
  74from .._internal.types import NotEmpty as NotEmpty
  75from .._internal.url import HttpUrl as HttpUrl
  76from .._internal.validation_context import validation_context_var
  77from .._internal.validator_annotations import RestrictCharacters
  78from .._internal.version_type import Version as Version
  79from .._internal.warning_levels import INFO
  80from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  81from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
  82from ..dataset.v0_3 import DatasetDescr as DatasetDescr
  83from ..dataset.v0_3 import DatasetId as DatasetId
  84from ..dataset.v0_3 import LinkedDataset as LinkedDataset
  85from ..dataset.v0_3 import Uploader as Uploader
  86from ..generic.v0_3 import (
  87    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
  88)
  89from ..generic.v0_3 import Author as Author
  90from ..generic.v0_3 import BadgeDescr as BadgeDescr
  91from ..generic.v0_3 import CiteEntry as CiteEntry
  92from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
  93from ..generic.v0_3 import (
  94    DocumentationSource,
  95    GenericModelDescrBase,
  96    LinkedResourceBase,
  97    _author_conv,  # pyright: ignore[reportPrivateUsage]
  98    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
  99)
 100from ..generic.v0_3 import Doi as Doi
 101from ..generic.v0_3 import LicenseId as LicenseId
 102from ..generic.v0_3 import LinkedResource as LinkedResource
 103from ..generic.v0_3 import Maintainer as Maintainer
 104from ..generic.v0_3 import OrcidId as OrcidId
 105from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 106from ..generic.v0_3 import ResourceId as ResourceId
 107from .v0_4 import Author as _Author_v0_4
 108from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 109from .v0_4 import CallableFromDepencency as CallableFromDepencency
 110from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 111from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 112from .v0_4 import ClipDescr as _ClipDescr_v0_4
 113from .v0_4 import ClipKwargs as ClipKwargs
 114from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 115from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 116from .v0_4 import KnownRunMode as KnownRunMode
 117from .v0_4 import ModelDescr as _ModelDescr_v0_4
 118from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 119from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 120from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 121from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 122from .v0_4 import ProcessingKwargs as ProcessingKwargs
 123from .v0_4 import RunMode as RunMode
 124from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 125from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 126from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 127from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 128from .v0_4 import TensorName as _TensorName_v0_4
 129from .v0_4 import WeightsFormat as WeightsFormat
 130from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 131from .v0_4 import package_weights
 132
 133SpaceUnit = Literal[
 134    "attometer",
 135    "angstrom",
 136    "centimeter",
 137    "decimeter",
 138    "exameter",
 139    "femtometer",
 140    "foot",
 141    "gigameter",
 142    "hectometer",
 143    "inch",
 144    "kilometer",
 145    "megameter",
 146    "meter",
 147    "micrometer",
 148    "mile",
 149    "millimeter",
 150    "nanometer",
 151    "parsec",
 152    "petameter",
 153    "picometer",
 154    "terameter",
 155    "yard",
 156    "yoctometer",
 157    "yottameter",
 158    "zeptometer",
 159    "zettameter",
 160]
 161"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 162
 163TimeUnit = Literal[
 164    "attosecond",
 165    "centisecond",
 166    "day",
 167    "decisecond",
 168    "exasecond",
 169    "femtosecond",
 170    "gigasecond",
 171    "hectosecond",
 172    "hour",
 173    "kilosecond",
 174    "megasecond",
 175    "microsecond",
 176    "millisecond",
 177    "minute",
 178    "nanosecond",
 179    "petasecond",
 180    "picosecond",
 181    "second",
 182    "terasecond",
 183    "yoctosecond",
 184    "yottasecond",
 185    "zeptosecond",
 186    "zettasecond",
 187]
 188"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 189
 190AxisType = Literal["batch", "channel", "index", "time", "space"]
 191
 192
 193class TensorId(LowerCaseIdentifier):
 194    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 195        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 196    ]
 197
 198
 199class AxisId(LowerCaseIdentifier):
 200    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 201        Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
 202    ]
 203
 204
 205def _is_batch(a: str) -> bool:
 206    return a == BATCH_AXIS_ID
 207
 208
 209def _is_not_batch(a: str) -> bool:
 210    return not _is_batch(a)
 211
 212
 213NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
 214
 215PostprocessingId = Literal[
 216    "binarize",
 217    "clip",
 218    "ensure_dtype",
 219    "fixed_zero_mean_unit_variance",
 220    "scale_linear",
 221    "scale_mean_variance",
 222    "scale_range",
 223    "sigmoid",
 224    "zero_mean_unit_variance",
 225]
 226PreprocessingId = Literal[
 227    "binarize",
 228    "clip",
 229    "ensure_dtype",
 230    "scale_linear",
 231    "sigmoid",
 232    "zero_mean_unit_variance",
 233    "scale_range",
 234]
 235
 236
 237SAME_AS_TYPE = "<same as type>"
 238
 239
 240ParameterizedSize_N = int
 241
 242
 243class ParameterizedSize(Node):
 244    """Describes a range of valid tensor axis sizes as `size = min + n*step`."""
 245
 246    N: ClassVar[Type[int]] = ParameterizedSize_N
 247    """integer to parameterize this axis"""
 248
 249    min: Annotated[int, Gt(0)]
 250    step: Annotated[int, Gt(0)]
 251
 252    def validate_size(self, size: int) -> int:
 253        if size < self.min:
 254            raise ValueError(f"size {size} < {self.min}")
 255        if (size - self.min) % self.step != 0:
 256            raise ValueError(
 257                f"axis of size {size} is not parameterized by `min + n*step` ="
 258                + f" `{self.min} + n*{self.step}`"
 259            )
 260
 261        return size
 262
 263    def get_size(self, n: ParameterizedSize_N) -> int:
 264        return self.min + self.step * n
 265
 266    def get_n(self, s: int) -> ParameterizedSize_N:
 267        """return smallest n parameterizing a size greater or equal than `s`"""
 268        return ceil((s - self.min) / self.step)
 269
 270
 271class DataDependentSize(Node):
 272    min: Annotated[int, Gt(0)] = 1
 273    max: Annotated[Optional[int], Gt(1)] = None
 274
 275    @model_validator(mode="after")
 276    def _validate_max_gt_min(self):
 277        if self.max is not None and self.min >= self.max:
 278            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 279
 280        return self
 281
 282    def validate_size(self, size: int) -> int:
 283        if size < self.min:
 284            raise ValueError(f"size {size} < {self.min}")
 285
 286        if self.max is not None and size > self.max:
 287            raise ValueError(f"size {size} > {self.max}")
 288
 289        return size
 290
 291
 292class SizeReference(Node):
 293    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 294
 295    `axis.size = reference.size * reference.scale / axis.scale + offset`
 296
 297    Note:
 298    1. The axis and the referenced axis need to have the same unit (or no unit).
 299    2. Batch axes may not be referenced.
 300    3. Fractions are rounded down.
 301    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 302        `concatenable` as well with the same block order.
 303
 304    Example:
 305    An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
 306    Let's assume that we want to express the image height h in relation to its width w
 307    instead of only accepting input images of exactly 100*49 pixels
 308    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
 309
 310    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 311    >>> h = SpaceInputAxis(
 312    ...     id=AxisId("h"),
 313    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 314    ...     unit="millimeter",
 315    ...     scale=4,
 316    ... )
 317    >>> print(h.size.get_size(h, w))
 318    49
 319
 320    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 321    """
 322
 323    tensor_id: TensorId
 324    """tensor id of the reference axis"""
 325
 326    axis_id: AxisId
 327    """axis id of the reference axis"""
 328
 329    offset: int = 0
 330
 331    def get_size(
 332        self,
 333        axis: Union[
 334            ChannelAxis,
 335            IndexInputAxis,
 336            IndexOutputAxis,
 337            TimeInputAxis,
 338            SpaceInputAxis,
 339            TimeOutputAxis,
 340            TimeOutputAxisWithHalo,
 341            SpaceOutputAxis,
 342            SpaceOutputAxisWithHalo,
 343        ],
 344        ref_axis: Union[
 345            ChannelAxis,
 346            IndexInputAxis,
 347            IndexOutputAxis,
 348            TimeInputAxis,
 349            SpaceInputAxis,
 350            TimeOutputAxis,
 351            TimeOutputAxisWithHalo,
 352            SpaceOutputAxis,
 353            SpaceOutputAxisWithHalo,
 354        ],
 355        n: ParameterizedSize_N = 0,
 356        ref_size: Optional[int] = None,
 357    ):
 358        """Compute the concrete size for a given axis and its reference axis.
 359
 360        Args:
 361            axis: The axis this `SizeReference` is the size of.
 362            ref_axis: The reference axis to compute the size from.
 363            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 364                and no fixed **ref_size** is given,
 365                **n** is used to compute the size of the parameterized **ref_axis**.
 366            ref_size: Overwrite the reference size instead of deriving it from
 367                **ref_axis**
 368                (**ref_axis.scale** is still used; any given **n** is ignored).
 369        """
 370        assert (
 371            axis.size == self
 372        ), "Given `axis.size` is not defined by this `SizeReference`"
 373
 374        assert (
 375            ref_axis.id == self.axis_id
 376        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 377
 378        assert axis.unit == ref_axis.unit, (
 379            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 380            f" but {axis.unit}!={ref_axis.unit}"
 381        )
 382        if ref_size is None:
 383            if isinstance(ref_axis.size, (int, float)):
 384                ref_size = ref_axis.size
 385            elif isinstance(ref_axis.size, ParameterizedSize):
 386                ref_size = ref_axis.size.get_size(n)
 387            elif isinstance(ref_axis.size, DataDependentSize):
 388                raise ValueError(
 389                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 390                )
 391            elif isinstance(ref_axis.size, SizeReference):
 392                raise ValueError(
 393                    "Reference axis referenced in `SizeReference` may not be sized by a"
 394                    + " `SizeReference` itself."
 395                )
 396            else:
 397                assert_never(ref_axis.size)
 398
 399        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 400
 401    @staticmethod
 402    def _get_unit(
 403        axis: Union[
 404            ChannelAxis,
 405            IndexInputAxis,
 406            IndexOutputAxis,
 407            TimeInputAxis,
 408            SpaceInputAxis,
 409            TimeOutputAxis,
 410            TimeOutputAxisWithHalo,
 411            SpaceOutputAxis,
 412            SpaceOutputAxisWithHalo,
 413        ],
 414    ):
 415        return axis.unit
 416
 417
 418class AxisBase(NodeWithExplicitlySetFields):
 419    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"})
 420
 421    id: AxisId
 422    """An axis id unique across all axes of one tensor."""
 423
 424    description: Annotated[str, MaxLen(128)] = ""
 425
 426
 427class WithHalo(Node):
 428    halo: Annotated[int, Ge(1)]
 429    """The halo should be cropped from the output tensor to avoid boundary effects.
 430    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 431    To document a halo that is already cropped by the model use `size.offset` instead."""
 432
 433    size: Annotated[
 434        SizeReference,
 435        Field(
 436            examples=[
 437                10,
 438                SizeReference(
 439                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 440                ).model_dump(mode="json"),
 441            ]
 442        ),
 443    ]
 444    """reference to another axis with an optional offset (see `SizeReference`)"""
 445
 446
 447BATCH_AXIS_ID = AxisId("batch")
 448
 449
 450class BatchAxis(AxisBase):
 451    type: Literal["batch"] = "batch"
 452    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 453    size: Optional[Literal[1]] = None
 454    """The batch size may be fixed to 1,
 455    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
 456
 457    @property
 458    def scale(self):
 459        return 1.0
 460
 461    @property
 462    def concatenable(self):
 463        return True
 464
 465    @property
 466    def unit(self):
 467        return None
 468
 469
 470class ChannelAxis(AxisBase):
 471    type: Literal["channel"] = "channel"
 472    id: NonBatchAxisId = AxisId("channel")
 473    channel_names: NotEmpty[List[Identifier]]
 474
 475    @property
 476    def size(self) -> int:
 477        return len(self.channel_names)
 478
 479    @property
 480    def concatenable(self):
 481        return False
 482
 483    @property
 484    def scale(self) -> float:
 485        return 1.0
 486
 487    @property
 488    def unit(self):
 489        return None
 490
 491
 492class IndexAxisBase(AxisBase):
 493    type: Literal["index"] = "index"
 494    id: NonBatchAxisId = AxisId("index")
 495
 496    @property
 497    def scale(self) -> float:
 498        return 1.0
 499
 500    @property
 501    def unit(self):
 502        return None
 503
 504
 505class _WithInputAxisSize(Node):
 506    size: Annotated[
 507        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 508        Field(
 509            examples=[
 510                10,
 511                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 512                SizeReference(
 513                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 514                ).model_dump(mode="json"),
 515            ]
 516        ),
 517    ]
 518    """The size/length of this axis can be specified as
 519    - fixed integer
 520    - parameterized series of valid sizes (`ParameterizedSize`)
 521    - reference to another axis with an optional offset (`SizeReference`)
 522    """
 523
 524
 525class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 526    concatenable: bool = False
 527    """If a model has a `concatenable` input axis, it can be processed blockwise,
 528    splitting a longer sample axis into blocks matching its input tensor description.
 529    Output axes are concatenable if they have a `SizeReference` to a concatenable
 530    input axis.
 531    """
 532
 533
 534class IndexOutputAxis(IndexAxisBase):
 535    size: Annotated[
 536        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 537        Field(
 538            examples=[
 539                10,
 540                SizeReference(
 541                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 542                ).model_dump(mode="json"),
 543            ]
 544        ),
 545    ]
 546    """The size/length of this axis can be specified as
 547    - fixed integer
 548    - reference to another axis with an optional offset (`SizeReference`)
 549    - data dependent size using `DataDependentSize` (size is only known after model inference)
 550    """
 551
 552
 553class TimeAxisBase(AxisBase):
 554    type: Literal["time"] = "time"
 555    id: NonBatchAxisId = AxisId("time")
 556    unit: Optional[TimeUnit] = None
 557    scale: Annotated[float, Gt(0)] = 1.0
 558
 559
 560class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 561    concatenable: bool = False
 562    """If a model has a `concatenable` input axis, it can be processed blockwise,
 563    splitting a longer sample axis into blocks matching its input tensor description.
 564    Output axes are concatenable if they have a `SizeReference` to a concatenable
 565    input axis.
 566    """
 567
 568
 569class SpaceAxisBase(AxisBase):
 570    type: Literal["space"] = "space"
 571    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 572    unit: Optional[SpaceUnit] = None
 573    scale: Annotated[float, Gt(0)] = 1.0
 574
 575
 576class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 577    concatenable: bool = False
 578    """If a model has a `concatenable` input axis, it can be processed blockwise,
 579    splitting a longer sample axis into blocks matching its input tensor description.
 580    Output axes are concatenable if they have a `SizeReference` to a concatenable
 581    input axis.
 582    """
 583
 584
 585_InputAxisUnion = Union[
 586    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 587]
 588InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 589
 590
 591class _WithOutputAxisSize(Node):
 592    size: Annotated[
 593        Union[Annotated[int, Gt(0)], SizeReference],
 594        Field(
 595            examples=[
 596                10,
 597                SizeReference(
 598                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 599                ).model_dump(mode="json"),
 600            ]
 601        ),
 602    ]
 603    """The size/length of this axis can be specified as
 604    - fixed integer
 605    - reference to another axis with an optional offset (see `SizeReference`)
 606    """
 607
 608
 609class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 610    pass
 611
 612
 613class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 614    pass
 615
 616
 617def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 618    if isinstance(v, dict):
 619        return "with_halo" if "halo" in v else "wo_halo"
 620    else:
 621        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 622
 623
 624_TimeOutputAxisUnion = Annotated[
 625    Union[
 626        Annotated[TimeOutputAxis, Tag("wo_halo")],
 627        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 628    ],
 629    Discriminator(_get_halo_axis_discriminator_value),
 630]
 631
 632
 633class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 634    pass
 635
 636
 637class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 638    pass
 639
 640
 641_SpaceOutputAxisUnion = Annotated[
 642    Union[
 643        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 644        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 645    ],
 646    Discriminator(_get_halo_axis_discriminator_value),
 647]
 648
 649
 650_OutputAxisUnion = Union[
 651    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 652]
 653OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 654
 655AnyAxis = Union[InputAxis, OutputAxis]
 656
 657TVs = Union[
 658    NotEmpty[List[int]],
 659    NotEmpty[List[float]],
 660    NotEmpty[List[bool]],
 661    NotEmpty[List[str]],
 662]
 663
 664
 665NominalOrOrdinalDType = Literal[
 666    "float32",
 667    "float64",
 668    "uint8",
 669    "int8",
 670    "uint16",
 671    "int16",
 672    "uint32",
 673    "int32",
 674    "uint64",
 675    "int64",
 676    "bool",
 677]
 678
 679
 680class NominalOrOrdinalDataDescr(Node):
 681    values: TVs
 682    """A fixed set of nominal or an ascending sequence of ordinal values.
 683    In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'.
 684    String `values` are interpreted as labels for tensor values 0, ..., N.
 685    Note: as YAML 1.2 does not natively support a "set" datatype,
 686    nominal values should be given as a sequence (aka list/array) as well.
 687    """
 688
 689    type: Annotated[
 690        NominalOrOrdinalDType,
 691        Field(
 692            examples=[
 693                "float32",
 694                "uint8",
 695                "uint16",
 696                "int64",
 697                "bool",
 698            ],
 699        ),
 700    ] = "uint8"
 701
 702    @model_validator(mode="after")
 703    def _validate_values_match_type(
 704        self,
 705    ) -> Self:
 706        incompatible: List[Any] = []
 707        for v in self.values:
 708            if self.type == "bool":
 709                if not isinstance(v, bool):
 710                    incompatible.append(v)
 711            elif self.type in DTYPE_LIMITS:
 712                if (
 713                    isinstance(v, (int, float))
 714                    and (
 715                        v < DTYPE_LIMITS[self.type].min
 716                        or v > DTYPE_LIMITS[self.type].max
 717                    )
 718                    or (isinstance(v, str) and "uint" not in self.type)
 719                    or (isinstance(v, float) and "int" in self.type)
 720                ):
 721                    incompatible.append(v)
 722            else:
 723                incompatible.append(v)
 724
 725            if len(incompatible) == 5:
 726                incompatible.append("...")
 727                break
 728
 729        if incompatible:
 730            raise ValueError(
 731                f"data type '{self.type}' incompatible with values {incompatible}"
 732            )
 733
 734        return self
 735
 736    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 737
 738    @property
 739    def range(self):
 740        if isinstance(self.values[0], str):
 741            return 0, len(self.values) - 1
 742        else:
 743            return min(self.values), max(self.values)
 744
 745
 746IntervalOrRatioDType = Literal[
 747    "float32",
 748    "float64",
 749    "uint8",
 750    "int8",
 751    "uint16",
 752    "int16",
 753    "uint32",
 754    "int32",
 755    "uint64",
 756    "int64",
 757]
 758
 759
 760class IntervalOrRatioDataDescr(Node):
 761    type: Annotated[  # todo: rename to dtype
 762        IntervalOrRatioDType,
 763        Field(
 764            examples=["float32", "float64", "uint8", "uint16"],
 765        ),
 766    ] = "float32"
 767    range: Tuple[Optional[float], Optional[float]] = (
 768        None,
 769        None,
 770    )
 771    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 772    `None` corresponds to min/max of what can be expressed by `data_type`."""
 773    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 774    scale: float = 1.0
 775    """Scale for data on an interval (or ratio) scale."""
 776    offset: Optional[float] = None
 777    """Offset for data on a ratio scale."""
 778
 779
 780TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 781
 782
 783class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 784    """processing base class"""
 785
 786    # id: Literal[PreprocessingId, PostprocessingId]  # make abstract field
 787    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
 788
 789
 790class BinarizeKwargs(ProcessingKwargs):
 791    """key word arguments for `BinarizeDescr`"""
 792
 793    threshold: float
 794    """The fixed threshold"""
 795
 796
 797class BinarizeAlongAxisKwargs(ProcessingKwargs):
 798    """key word arguments for `BinarizeDescr`"""
 799
 800    threshold: NotEmpty[List[float]]
 801    """The fixed threshold values along `axis`"""
 802
 803    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 804    """The `threshold` axis"""
 805
 806
 807class BinarizeDescr(ProcessingDescrBase):
 808    """Binarize the tensor with a fixed threshold.
 809
 810    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 811    will be set to one, values below the threshold to zero.
 812
 813    Examples:
 814    - in YAML
 815        ```yaml
 816        postprocessing:
 817          - id: binarize
 818            kwargs:
 819              axis: 'channel'
 820              threshold: [0.25, 0.5, 0.75]
 821        ```
 822    - in Python:
 823        >>> postprocessing = [BinarizeDescr(
 824        ...   kwargs=BinarizeAlongAxisKwargs(
 825        ...       axis=AxisId('channel'),
 826        ...       threshold=[0.25, 0.5, 0.75],
 827        ...   )
 828        ... )]
 829    """
 830
 831    id: Literal["binarize"] = "binarize"
 832    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 833
 834
 835class ClipDescr(ProcessingDescrBase):
 836    """Set tensor values below min to min and above max to max.
 837
 838    See `ScaleRangeDescr` for examples.
 839    """
 840
 841    id: Literal["clip"] = "clip"
 842    kwargs: ClipKwargs
 843
 844
 845class EnsureDtypeKwargs(ProcessingKwargs):
 846    """key word arguments for `EnsureDtypeDescr`"""
 847
 848    dtype: Literal[
 849        "float32",
 850        "float64",
 851        "uint8",
 852        "int8",
 853        "uint16",
 854        "int16",
 855        "uint32",
 856        "int32",
 857        "uint64",
 858        "int64",
 859        "bool",
 860    ]
 861
 862
 863class EnsureDtypeDescr(ProcessingDescrBase):
 864    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 865
 866    This can for example be used to ensure the inner neural network model gets a
 867    different input tensor data type than the fully described bioimage.io model does.
 868
 869    Examples:
 870        The described bioimage.io model (incl. preprocessing) accepts any
 871        float32-compatible tensor, normalizes it with percentiles and clipping and then
 872        casts it to uint8, which is what the neural network in this example expects.
 873        - in YAML
 874            ```yaml
 875            inputs:
 876            - data:
 877                type: float32  # described bioimage.io model is compatible with any float32 input tensor
 878            preprocessing:
 879            - id: scale_range
 880                kwargs:
 881                axes: ['y', 'x']
 882                max_percentile: 99.8
 883                min_percentile: 5.0
 884            - id: clip
 885                kwargs:
 886                min: 0.0
 887                max: 1.0
 888            - id: ensure_dtype
 889                kwargs:
 890                dtype: uint8
 891            ```
 892        - in Python:
 893            >>> preprocessing = [
 894            ...     ScaleRangeDescr(
 895            ...         kwargs=ScaleRangeKwargs(
 896            ...           axes= (AxisId('y'), AxisId('x')),
 897            ...           max_percentile= 99.8,
 898            ...           min_percentile= 5.0,
 899            ...         )
 900            ...     ),
 901            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
 902            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
 903            ... ]
 904    """
 905
 906    id: Literal["ensure_dtype"] = "ensure_dtype"
 907    kwargs: EnsureDtypeKwargs
 908
 909
 910class ScaleLinearKwargs(ProcessingKwargs):
 911    """Key word arguments for `ScaleLinearDescr`"""
 912
 913    gain: float = 1.0
 914    """multiplicative factor"""
 915
 916    offset: float = 0.0
 917    """additive term"""
 918
 919    @model_validator(mode="after")
 920    def _validate(self) -> Self:
 921        if self.gain == 1.0 and self.offset == 0.0:
 922            raise ValueError(
 923                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
 924                + " != 0.0."
 925            )
 926
 927        return self
 928
 929
 930class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
 931    """Key word arguments for `ScaleLinearDescr`"""
 932
 933    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 934    """The axis of of gains/offsets values."""
 935
 936    gain: Union[float, NotEmpty[List[float]]] = 1.0
 937    """multiplicative factor"""
 938
 939    offset: Union[float, NotEmpty[List[float]]] = 0.0
 940    """additive term"""
 941
 942    @model_validator(mode="after")
 943    def _validate(self) -> Self:
 944
 945        if isinstance(self.gain, list):
 946            if isinstance(self.offset, list):
 947                if len(self.gain) != len(self.offset):
 948                    raise ValueError(
 949                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
 950                    )
 951            else:
 952                self.offset = [float(self.offset)] * len(self.gain)
 953        elif isinstance(self.offset, list):
 954            self.gain = [float(self.gain)] * len(self.offset)
 955        else:
 956            raise ValueError(
 957                "Do not specify an `axis` for scalar gain and offset values."
 958            )
 959
 960        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
 961            raise ValueError(
 962                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
 963                + " != 0.0."
 964            )
 965
 966        return self
 967
 968
 969class ScaleLinearDescr(ProcessingDescrBase):
 970    """Fixed linear scaling.
 971
 972    Examples:
 973      1. Scale with scalar gain and offset
 974        - in YAML
 975        ```yaml
 976        preprocessing:
 977          - id: scale_linear
 978            kwargs:
 979              gain: 2.0
 980              offset: 3.0
 981        ```
 982        - in Python:
 983        >>> preprocessing = [
 984        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
 985        ... ]
 986
 987      2. Independent scaling along an axis
 988        - in YAML
 989        ```yaml
 990        preprocessing:
 991          - id: scale_linear
 992            kwargs:
 993              axis: 'channel'
 994              gain: [1.0, 2.0, 3.0]
 995        ```
 996        - in Python:
 997        >>> preprocessing = [
 998        ...     ScaleLinearDescr(
 999        ...         kwargs=ScaleLinearAlongAxisKwargs(
1000        ...             axis=AxisId("channel"),
1001        ...             gain=[1.0, 2.0, 3.0],
1002        ...         )
1003        ...     )
1004        ... ]
1005
1006    """
1007
1008    id: Literal["scale_linear"] = "scale_linear"
1009    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1010
1011
1012class SigmoidDescr(ProcessingDescrBase):
1013    """The logistic sigmoid funciton, a.k.a. expit function.
1014
1015    Examples:
1016    - in YAML
1017        ```yaml
1018        postprocessing:
1019          - id: sigmoid
1020        ```
1021    - in Python:
1022        >>> postprocessing = [SigmoidDescr()]
1023    """
1024
1025    id: Literal["sigmoid"] = "sigmoid"
1026
1027    @property
1028    def kwargs(self) -> ProcessingKwargs:
1029        """empty kwargs"""
1030        return ProcessingKwargs()
1031
1032
1033class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1034    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1035
1036    mean: float
1037    """The mean value to normalize with."""
1038
1039    std: Annotated[float, Ge(1e-6)]
1040    """The standard deviation value to normalize with."""
1041
1042
1043class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1044    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1045
1046    mean: NotEmpty[List[float]]
1047    """The mean value(s) to normalize with."""
1048
1049    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1050    """The standard deviation value(s) to normalize with.
1051    Size must match `mean` values."""
1052
1053    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1054    """The axis of the mean/std values to normalize each entry along that dimension
1055    separately."""
1056
1057    @model_validator(mode="after")
1058    def _mean_and_std_match(self) -> Self:
1059        if len(self.mean) != len(self.std):
1060            raise ValueError(
1061                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1062                + " must match."
1063            )
1064
1065        return self
1066
1067
1068class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1069    """Subtract a given mean and divide by the standard deviation.
1070
1071    Normalize with fixed, precomputed values for
1072    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1073    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1074    axes.
1075
1076    Examples:
1077    1. scalar value for whole tensor
1078        - in YAML
1079        ```yaml
1080        preprocessing:
1081          - id: fixed_zero_mean_unit_variance
1082            kwargs:
1083              mean: 103.5
1084              std: 13.7
1085        ```
1086        - in Python
1087        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1088        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1089        ... )]
1090
1091    2. independently along an axis
1092        - in YAML
1093        ```yaml
1094        preprocessing:
1095          - id: fixed_zero_mean_unit_variance
1096            kwargs:
1097              axis: channel
1098              mean: [101.5, 102.5, 103.5]
1099              std: [11.7, 12.7, 13.7]
1100        ```
1101        - in Python
1102        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1103        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1104        ...     axis=AxisId("channel"),
1105        ...     mean=[101.5, 102.5, 103.5],
1106        ...     std=[11.7, 12.7, 13.7],
1107        ...   )
1108        ... )]
1109    """
1110
1111    id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1112    kwargs: Union[
1113        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1114    ]
1115
1116
1117class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1118    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1119
1120    axes: Annotated[
1121        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1122    ] = None
1123    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1124    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1125    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1126    To normalize each sample independently leave out the 'batch' axis.
1127    Default: Scale all axes jointly."""
1128
1129    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1130    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1131
1132
1133class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1134    """Subtract mean and divide by variance.
1135
1136    Examples:
1137        Subtract tensor mean and variance
1138        - in YAML
1139        ```yaml
1140        preprocessing:
1141          - id: zero_mean_unit_variance
1142        ```
1143        - in Python
1144        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1145    """
1146
1147    id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1148    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1149        default_factory=ZeroMeanUnitVarianceKwargs
1150    )
1151
1152
1153class ScaleRangeKwargs(ProcessingKwargs):
1154    """key word arguments for `ScaleRangeDescr`
1155
1156    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1157    this processing step normalizes data to the [0, 1] intervall.
1158    For other percentiles the normalized values will partially be outside the [0, 1]
1159    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1160    normalized values to a range.
1161    """
1162
1163    axes: Annotated[
1164        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1165    ] = None
1166    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1167    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1168    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1169    To normalize samples independently, leave out the "batch" axis.
1170    Default: Scale all axes jointly."""
1171
1172    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1173    """The lower percentile used to determine the value to align with zero."""
1174
1175    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1176    """The upper percentile used to determine the value to align with one.
1177    Has to be bigger than `min_percentile`.
1178    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1179    accepting percentiles specified in the range 0.0 to 1.0."""
1180
1181    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1182    """Epsilon for numeric stability.
1183    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1184    with `v_lower,v_upper` values at the respective percentiles."""
1185
1186    reference_tensor: Optional[TensorId] = None
1187    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1188    For any tensor in `inputs` only input tensor references are allowed."""
1189
1190    @field_validator("max_percentile", mode="after")
1191    @classmethod
1192    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1193        if (min_p := info.data["min_percentile"]) >= value:
1194            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1195
1196        return value
1197
1198
1199class ScaleRangeDescr(ProcessingDescrBase):
1200    """Scale with percentiles.
1201
1202    Examples:
1203    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1204        - in YAML
1205        ```yaml
1206        preprocessing:
1207          - id: scale_range
1208            kwargs:
1209              axes: ['y', 'x']
1210              max_percentile: 99.8
1211              min_percentile: 5.0
1212        ```
1213        - in Python
1214        >>> preprocessing = [
1215        ...     ScaleRangeDescr(
1216        ...         kwargs=ScaleRangeKwargs(
1217        ...           axes= (AxisId('y'), AxisId('x')),
1218        ...           max_percentile= 99.8,
1219        ...           min_percentile= 5.0,
1220        ...         )
1221        ...     ),
1222        ...     ClipDescr(
1223        ...         kwargs=ClipKwargs(
1224        ...             min=0.0,
1225        ...             max=1.0,
1226        ...         )
1227        ...     ),
1228        ... ]
1229
1230      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1231        - in YAML
1232        ```yaml
1233        preprocessing:
1234          - id: scale_range
1235            kwargs:
1236              axes: ['y', 'x']
1237              max_percentile: 99.8
1238              min_percentile: 5.0
1239                  - id: scale_range
1240           - id: clip
1241             kwargs:
1242              min: 0.0
1243              max: 1.0
1244        ```
1245        - in Python
1246        >>> preprocessing = [ScaleRangeDescr(
1247        ...   kwargs=ScaleRangeKwargs(
1248        ...       axes= (AxisId('y'), AxisId('x')),
1249        ...       max_percentile= 99.8,
1250        ...       min_percentile= 5.0,
1251        ...   )
1252        ... )]
1253
1254    """
1255
1256    id: Literal["scale_range"] = "scale_range"
1257    kwargs: ScaleRangeKwargs
1258
1259
1260class ScaleMeanVarianceKwargs(ProcessingKwargs):
1261    """key word arguments for `ScaleMeanVarianceKwargs`"""
1262
1263    reference_tensor: TensorId
1264    """Name of tensor to match."""
1265
1266    axes: Annotated[
1267        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1268    ] = None
1269    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1270    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1271    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1272    To normalize samples independently, leave out the 'batch' axis.
1273    Default: Scale all axes jointly."""
1274
1275    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1276    """Epsilon for numeric stability:
1277    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1278
1279
1280class ScaleMeanVarianceDescr(ProcessingDescrBase):
1281    """Scale a tensor's data distribution to match another tensor's mean/std.
1282    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1283    """
1284
1285    id: Literal["scale_mean_variance"] = "scale_mean_variance"
1286    kwargs: ScaleMeanVarianceKwargs
1287
1288
1289PreprocessingDescr = Annotated[
1290    Union[
1291        BinarizeDescr,
1292        ClipDescr,
1293        EnsureDtypeDescr,
1294        ScaleLinearDescr,
1295        SigmoidDescr,
1296        FixedZeroMeanUnitVarianceDescr,
1297        ZeroMeanUnitVarianceDescr,
1298        ScaleRangeDescr,
1299    ],
1300    Discriminator("id"),
1301]
1302PostprocessingDescr = Annotated[
1303    Union[
1304        BinarizeDescr,
1305        ClipDescr,
1306        EnsureDtypeDescr,
1307        ScaleLinearDescr,
1308        SigmoidDescr,
1309        FixedZeroMeanUnitVarianceDescr,
1310        ZeroMeanUnitVarianceDescr,
1311        ScaleRangeDescr,
1312        ScaleMeanVarianceDescr,
1313    ],
1314    Discriminator("id"),
1315]
1316
1317IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1318
1319
1320class TensorDescrBase(Node, Generic[IO_AxisT]):
1321    id: TensorId
1322    """Tensor id. No duplicates are allowed."""
1323
1324    description: Annotated[str, MaxLen(128)] = ""
1325    """free text description"""
1326
1327    axes: NotEmpty[Sequence[IO_AxisT]]
1328    """tensor axes"""
1329
1330    @property
1331    def shape(self):
1332        return tuple(a.size for a in self.axes)
1333
1334    @field_validator("axes", mode="after", check_fields=False)
1335    @classmethod
1336    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1337        batch_axes = [a for a in axes if a.type == "batch"]
1338        if len(batch_axes) > 1:
1339            raise ValueError(
1340                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1341            )
1342
1343        seen_ids: Set[AxisId] = set()
1344        duplicate_axes_ids: Set[AxisId] = set()
1345        for a in axes:
1346            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1347
1348        if duplicate_axes_ids:
1349            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1350
1351        return axes
1352
1353    test_tensor: FileDescr
1354    """An example tensor to use for testing.
1355    Using the model with the test input tensors is expected to yield the test output tensors.
1356    Each test tensor has be a an ndarray in the
1357    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1358    The file extension must be '.npy'."""
1359
1360    sample_tensor: Optional[FileDescr] = None
1361    """A sample tensor to illustrate a possible input/output for the model,
1362    The sample image primarily serves to inform a human user about an example use case
1363    and is typically stored as .hdf5, .png or .tiff.
1364    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1365    (numpy's `.npy` format is not supported).
1366    The image dimensionality has to match the number of axes specified in this tensor description.
1367    """
1368
1369    @model_validator(mode="after")
1370    def _validate_sample_tensor(self) -> Self:
1371        if (
1372            self.sample_tensor is None
1373            or not validation_context_var.get().perform_io_checks
1374        ):
1375            return self
1376
1377        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1378        tensor: NDArray[Any] = imread(
1379            local.path.read_bytes(),
1380            extension=PurePosixPath(local.original_file_name).suffix,
1381        )
1382        n_dims = len(tensor.squeeze().shape)
1383        n_dims_min = n_dims_max = len(self.axes)
1384
1385        for a in self.axes:
1386            if isinstance(a, BatchAxis):
1387                n_dims_min -= 1
1388            elif isinstance(a.size, int):
1389                if a.size == 1:
1390                    n_dims_min -= 1
1391            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1392                if a.size.min == 1:
1393                    n_dims_min -= 1
1394            elif isinstance(a.size, SizeReference):
1395                if a.size.offset < 2:
1396                    # size reference may result in singleton axis
1397                    n_dims_min -= 1
1398            else:
1399                assert_never(a.size)
1400
1401        n_dims_min = max(0, n_dims_min)
1402        if n_dims < n_dims_min or n_dims > n_dims_max:
1403            raise ValueError(
1404                f"Expected sample tensor to have {n_dims_min} to"
1405                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1406            )
1407
1408        return self
1409
1410    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1411        IntervalOrRatioDataDescr()
1412    )
1413    """Description of the tensor's data values, optionally per channel.
1414    If specified per channel, the data `type` needs to match across channels."""
1415
1416    @property
1417    def dtype(
1418        self,
1419    ) -> Literal[
1420        "float32",
1421        "float64",
1422        "uint8",
1423        "int8",
1424        "uint16",
1425        "int16",
1426        "uint32",
1427        "int32",
1428        "uint64",
1429        "int64",
1430        "bool",
1431    ]:
1432        """dtype as specified under `data.type` or `data[i].type`"""
1433        if isinstance(self.data, collections.abc.Sequence):
1434            return self.data[0].type
1435        else:
1436            return self.data.type
1437
1438    @field_validator("data", mode="after")
1439    @classmethod
1440    def _check_data_type_across_channels(
1441        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1442    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1443        if not isinstance(value, list):
1444            return value
1445
1446        dtypes = {t.type for t in value}
1447        if len(dtypes) > 1:
1448            raise ValueError(
1449                "Tensor data descriptions per channel need to agree in their data"
1450                + f" `type`, but found {dtypes}."
1451            )
1452
1453        return value
1454
1455    @model_validator(mode="after")
1456    def _check_data_matches_channelaxis(self) -> Self:
1457        if not isinstance(self.data, (list, tuple)):
1458            return self
1459
1460        for a in self.axes:
1461            if isinstance(a, ChannelAxis):
1462                size = a.size
1463                assert isinstance(size, int)
1464                break
1465        else:
1466            return self
1467
1468        if len(self.data) != size:
1469            raise ValueError(
1470                f"Got tensor data descriptions for {len(self.data)} channels, but"
1471                + f" '{a.id}' axis has size {size}."
1472            )
1473
1474        return self
1475
1476    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1477        if len(array.shape) != len(self.axes):
1478            raise ValueError(
1479                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1480                + f" incompatible with {len(self.axes)} axes."
1481            )
1482        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1483
1484
1485class InputTensorDescr(TensorDescrBase[InputAxis]):
1486    id: TensorId = TensorId("input")
1487    """Input tensor id.
1488    No duplicates are allowed across all inputs and outputs."""
1489
1490    optional: bool = False
1491    """indicates that this tensor may be `None`"""
1492
1493    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1494    """Description of how this input should be preprocessed.
1495
1496    notes:
1497    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1498      to ensure an input tensor's data type matches the input tensor's data description.
1499    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1500      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1501      changing the data type.
1502    """
1503
1504    @model_validator(mode="after")
1505    def _validate_preprocessing_kwargs(self) -> Self:
1506        axes_ids = [a.id for a in self.axes]
1507        for p in self.preprocessing:
1508            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1509            if kwargs_axes is None:
1510                continue
1511
1512            if not isinstance(kwargs_axes, collections.abc.Sequence):
1513                raise ValueError(
1514                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1515                )
1516
1517            if any(a not in axes_ids for a in kwargs_axes):
1518                raise ValueError(
1519                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1520                )
1521
1522        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1523            dtype = self.data.type
1524        else:
1525            dtype = self.data[0].type
1526
1527        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1528        if not self.preprocessing or not isinstance(
1529            self.preprocessing[0], EnsureDtypeDescr
1530        ):
1531            self.preprocessing.insert(
1532                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1533            )
1534
1535        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1536        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1537            self.preprocessing.append(
1538                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1539            )
1540
1541        return self
1542
1543
1544def convert_axes(
1545    axes: str,
1546    *,
1547    shape: Union[
1548        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1549    ],
1550    tensor_type: Literal["input", "output"],
1551    halo: Optional[Sequence[int]],
1552    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1553):
1554    ret: List[AnyAxis] = []
1555    for i, a in enumerate(axes):
1556        axis_type = _AXIS_TYPE_MAP.get(a, a)
1557        if axis_type == "batch":
1558            ret.append(BatchAxis())
1559            continue
1560
1561        scale = 1.0
1562        if isinstance(shape, _ParameterizedInputShape_v0_4):
1563            if shape.step[i] == 0:
1564                size = shape.min[i]
1565            else:
1566                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1567        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1568            ref_t = str(shape.reference_tensor)
1569            if ref_t.count(".") == 1:
1570                t_id, orig_a_id = ref_t.split(".")
1571            else:
1572                t_id = ref_t
1573                orig_a_id = a
1574
1575            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1576            if not (orig_scale := shape.scale[i]):
1577                # old way to insert a new axis dimension
1578                size = int(2 * shape.offset[i])
1579            else:
1580                scale = 1 / orig_scale
1581                if axis_type in ("channel", "index"):
1582                    # these axes no longer have a scale
1583                    offset_from_scale = orig_scale * size_refs.get(
1584                        _TensorName_v0_4(t_id), {}
1585                    ).get(orig_a_id, 0)
1586                else:
1587                    offset_from_scale = 0
1588                size = SizeReference(
1589                    tensor_id=TensorId(t_id),
1590                    axis_id=AxisId(a_id),
1591                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1592                )
1593        else:
1594            size = shape[i]
1595
1596        if axis_type == "time":
1597            if tensor_type == "input":
1598                ret.append(TimeInputAxis(size=size, scale=scale))
1599            else:
1600                assert not isinstance(size, ParameterizedSize)
1601                if halo is None:
1602                    ret.append(TimeOutputAxis(size=size, scale=scale))
1603                else:
1604                    assert not isinstance(size, int)
1605                    ret.append(
1606                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1607                    )
1608
1609        elif axis_type == "index":
1610            if tensor_type == "input":
1611                ret.append(IndexInputAxis(size=size))
1612            else:
1613                if isinstance(size, ParameterizedSize):
1614                    size = DataDependentSize(min=size.min)
1615
1616                ret.append(IndexOutputAxis(size=size))
1617        elif axis_type == "channel":
1618            assert not isinstance(size, ParameterizedSize)
1619            if isinstance(size, SizeReference):
1620                warnings.warn(
1621                    "Conversion of channel size from an implicit output shape may be"
1622                    + " wrong"
1623                )
1624                ret.append(
1625                    ChannelAxis(
1626                        channel_names=[
1627                            Identifier(f"channel{i}") for i in range(size.offset)
1628                        ]
1629                    )
1630                )
1631            else:
1632                ret.append(
1633                    ChannelAxis(
1634                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1635                    )
1636                )
1637        elif axis_type == "space":
1638            if tensor_type == "input":
1639                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1640            else:
1641                assert not isinstance(size, ParameterizedSize)
1642                if halo is None or halo[i] == 0:
1643                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1644                elif isinstance(size, int):
1645                    raise NotImplementedError(
1646                        f"output axis with halo and fixed size (here {size}) not allowed"
1647                    )
1648                else:
1649                    ret.append(
1650                        SpaceOutputAxisWithHalo(
1651                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1652                        )
1653                    )
1654
1655    return ret
1656
1657
1658_AXIS_TYPE_MAP = {
1659    "b": "batch",
1660    "t": "time",
1661    "i": "index",
1662    "c": "channel",
1663    "x": "space",
1664    "y": "space",
1665    "z": "space",
1666}
1667
1668_AXIS_ID_MAP = {
1669    "b": "batch",
1670    "t": "time",
1671    "i": "index",
1672    "c": "channel",
1673}
1674
1675
1676def _axes_letters_to_ids(
1677    axes: Optional[str],
1678) -> Optional[List[AxisId]]:
1679    if axes is None:
1680        return None
1681    return [AxisId(_AXIS_ID_MAP.get(a, a)) for a in map(str, axes)]
1682
1683
1684def _get_complement_v04_axis(
1685    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1686) -> Optional[AxisId]:
1687    if axes is None:
1688        return None
1689
1690    non_complement_axes = set(axes) | {"b"}
1691    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1692    if len(complement_axes) > 1:
1693        raise ValueError(
1694            f"Expected none or a single complement axis, but axes '{axes}' "
1695            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1696        )
1697
1698    return None if not complement_axes else AxisId(complement_axes[0])
1699
1700
1701def _convert_proc(
1702    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1703    tensor_axes: Sequence[str],
1704) -> Union[PreprocessingDescr, PostprocessingDescr]:
1705    if isinstance(p, _BinarizeDescr_v0_4):
1706        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1707    elif isinstance(p, _ClipDescr_v0_4):
1708        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1709    elif isinstance(p, _SigmoidDescr_v0_4):
1710        return SigmoidDescr()
1711    elif isinstance(p, _ScaleLinearDescr_v0_4):
1712        axes = _axes_letters_to_ids(p.kwargs.axes)
1713        if p.kwargs.axes is None:
1714            axis = None
1715        else:
1716            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1717
1718        if axis is None:
1719            assert not isinstance(p.kwargs.gain, list)
1720            assert not isinstance(p.kwargs.offset, list)
1721            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1722        else:
1723            kwargs = ScaleLinearAlongAxisKwargs(
1724                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1725            )
1726        return ScaleLinearDescr(kwargs=kwargs)
1727    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1728        return ScaleMeanVarianceDescr(
1729            kwargs=ScaleMeanVarianceKwargs(
1730                axes=_axes_letters_to_ids(p.kwargs.axes),
1731                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1732                eps=p.kwargs.eps,
1733            )
1734        )
1735    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1736        if p.kwargs.mode == "fixed":
1737            mean = p.kwargs.mean
1738            std = p.kwargs.std
1739            assert mean is not None
1740            assert std is not None
1741
1742            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1743
1744            if axis is None:
1745                return FixedZeroMeanUnitVarianceDescr(
1746                    kwargs=FixedZeroMeanUnitVarianceKwargs(
1747                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
1748                    )
1749                )
1750            else:
1751                if not isinstance(mean, list):
1752                    mean = [float(mean)]
1753                if not isinstance(std, list):
1754                    std = [float(std)]
1755
1756                return FixedZeroMeanUnitVarianceDescr(
1757                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1758                        axis=axis, mean=mean, std=std
1759                    )
1760                )
1761
1762        else:
1763            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1764            if p.kwargs.mode == "per_dataset":
1765                axes = [AxisId("batch")] + axes
1766            if not axes:
1767                axes = None
1768            return ZeroMeanUnitVarianceDescr(
1769                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1770            )
1771
1772    elif isinstance(p, _ScaleRangeDescr_v0_4):
1773        return ScaleRangeDescr(
1774            kwargs=ScaleRangeKwargs(
1775                axes=_axes_letters_to_ids(p.kwargs.axes),
1776                min_percentile=p.kwargs.min_percentile,
1777                max_percentile=p.kwargs.max_percentile,
1778                eps=p.kwargs.eps,
1779            )
1780        )
1781    else:
1782        assert_never(p)
1783
1784
1785class _InputTensorConv(
1786    Converter[
1787        _InputTensorDescr_v0_4,
1788        InputTensorDescr,
1789        ImportantFileSource,
1790        Optional[ImportantFileSource],
1791        Mapping[_TensorName_v0_4, Mapping[str, int]],
1792    ]
1793):
1794    def _convert(
1795        self,
1796        src: _InputTensorDescr_v0_4,
1797        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1798        test_tensor: ImportantFileSource,
1799        sample_tensor: Optional[ImportantFileSource],
1800        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1801    ) -> "InputTensorDescr | dict[str, Any]":
1802        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1803            src.axes,
1804            shape=src.shape,
1805            tensor_type="input",
1806            halo=None,
1807            size_refs=size_refs,
1808        )
1809        prep: List[PreprocessingDescr] = []
1810        for p in src.preprocessing:
1811            cp = _convert_proc(p, src.axes)
1812            assert not isinstance(cp, ScaleMeanVarianceDescr)
1813            prep.append(cp)
1814
1815        return tgt(
1816            axes=axes,
1817            id=TensorId(str(src.name)),
1818            test_tensor=FileDescr(source=test_tensor),
1819            sample_tensor=(
1820                None if sample_tensor is None else FileDescr(source=sample_tensor)
1821            ),
1822            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
1823            preprocessing=prep,
1824        )
1825
1826
1827_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1828
1829
1830class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1831    id: TensorId = TensorId("output")
1832    """Output tensor id.
1833    No duplicates are allowed across all inputs and outputs."""
1834
1835    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1836    """Description of how this output should be postprocessed.
1837
1838    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1839          If not given this is added to cast to this tensor's `data.type`.
1840    """
1841
1842    @model_validator(mode="after")
1843    def _validate_postprocessing_kwargs(self) -> Self:
1844        axes_ids = [a.id for a in self.axes]
1845        for p in self.postprocessing:
1846            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1847            if kwargs_axes is None:
1848                continue
1849
1850            if not isinstance(kwargs_axes, collections.abc.Sequence):
1851                raise ValueError(
1852                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1853                )
1854
1855            if any(a not in axes_ids for a in kwargs_axes):
1856                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1857
1858        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1859            dtype = self.data.type
1860        else:
1861            dtype = self.data[0].type
1862
1863        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1864        if not self.postprocessing or not isinstance(
1865            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1866        ):
1867            self.postprocessing.append(
1868                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1869            )
1870        return self
1871
1872
1873class _OutputTensorConv(
1874    Converter[
1875        _OutputTensorDescr_v0_4,
1876        OutputTensorDescr,
1877        ImportantFileSource,
1878        Optional[ImportantFileSource],
1879        Mapping[_TensorName_v0_4, Mapping[str, int]],
1880    ]
1881):
1882    def _convert(
1883        self,
1884        src: _OutputTensorDescr_v0_4,
1885        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
1886        test_tensor: ImportantFileSource,
1887        sample_tensor: Optional[ImportantFileSource],
1888        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1889    ) -> "OutputTensorDescr | dict[str, Any]":
1890        # TODO: split convert_axes into convert_output_axes and convert_input_axes
1891        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1892            src.axes,
1893            shape=src.shape,
1894            tensor_type="output",
1895            halo=src.halo,
1896            size_refs=size_refs,
1897        )
1898        data_descr: Dict[str, Any] = dict(type=src.data_type)
1899        if data_descr["type"] == "bool":
1900            data_descr["values"] = [False, True]
1901
1902        return tgt(
1903            axes=axes,
1904            id=TensorId(str(src.name)),
1905            test_tensor=FileDescr(source=test_tensor),
1906            sample_tensor=(
1907                None if sample_tensor is None else FileDescr(source=sample_tensor)
1908            ),
1909            data=data_descr,  # pyright: ignore[reportArgumentType]
1910            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
1911        )
1912
1913
1914_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
1915
1916
1917TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
1918
1919
1920def validate_tensors(
1921    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
1922    tensor_origin: str,  # for more precise error messages, e.g. 'test_tensor'
1923):
1924    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
1925
1926    def e_msg(d: TensorDescr):
1927        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
1928
1929    for descr, array in tensors.values():
1930        try:
1931            axis_sizes = descr.get_axis_sizes_for_array(array)
1932        except ValueError as e:
1933            raise ValueError(f"{e_msg(descr)} {e}")
1934        else:
1935            all_tensor_axes[descr.id] = {
1936                a.id: (a, axis_sizes[a.id]) for a in descr.axes
1937            }
1938
1939    for descr, array in tensors.values():
1940        if array.dtype.name != descr.dtype:
1941            raise ValueError(
1942                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
1943                + f" match described dtype '{descr.dtype}'"
1944            )
1945
1946        for a in descr.axes:
1947            actual_size = all_tensor_axes[descr.id][a.id][1]
1948            if a.size is None:
1949                continue
1950
1951            if isinstance(a.size, int):
1952                if actual_size != a.size:
1953                    raise ValueError(
1954                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
1955                        + f"has incompatible size {actual_size}, expected {a.size}"
1956                    )
1957            elif isinstance(a.size, ParameterizedSize):
1958                _ = a.size.validate_size(actual_size)
1959            elif isinstance(a.size, DataDependentSize):
1960                _ = a.size.validate_size(actual_size)
1961            elif isinstance(a.size, SizeReference):
1962                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
1963                if ref_tensor_axes is None:
1964                    raise ValueError(
1965                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
1966                        + f" reference '{a.size.tensor_id}'"
1967                    )
1968
1969                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
1970                if ref_axis is None or ref_size is None:
1971                    raise ValueError(
1972                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
1973                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
1974                    )
1975
1976                if a.unit != ref_axis.unit:
1977                    raise ValueError(
1978                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
1979                        + " axis and reference axis to have the same `unit`, but"
1980                        + f" {a.unit}!={ref_axis.unit}"
1981                    )
1982
1983                if actual_size != (
1984                    expected_size := (
1985                        ref_size * ref_axis.scale / a.scale + a.size.offset
1986                    )
1987                ):
1988                    raise ValueError(
1989                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
1990                        + f" {actual_size} invalid for referenced size {ref_size};"
1991                        + f" expected {expected_size}"
1992                    )
1993            else:
1994                assert_never(a.size)
1995
1996
1997class EnvironmentFileDescr(FileDescr):
1998    source: Annotated[
1999        ImportantFileSource,
2000        WithSuffix((".yaml", ".yml"), case_sensitive=True),
2001        Field(
2002            examples=["environment.yaml"],
2003        ),
2004    ]
2005    """∈📦 Conda environment file.
2006    Allows to specify custom dependencies, see conda docs:
2007    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2008    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2009    """
2010
2011
2012class _ArchitectureCallableDescr(Node):
2013    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2014    """Identifier of the callable that returns a torch.nn.Module instance."""
2015
2016    kwargs: Dict[str, YamlValue] = Field(default_factory=dict)
2017    """key word arguments for the `callable`"""
2018
2019
2020class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2021    pass
2022
2023
2024class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2025    import_from: str
2026    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2027
2028
2029ArchitectureDescr = Annotated[
2030    Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr],
2031    Field(union_mode="left_to_right"),
2032]
2033
2034
2035class _ArchFileConv(
2036    Converter[
2037        _CallableFromFile_v0_4,
2038        ArchitectureFromFileDescr,
2039        Optional[Sha256],
2040        Dict[str, Any],
2041    ]
2042):
2043    def _convert(
2044        self,
2045        src: _CallableFromFile_v0_4,
2046        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2047        sha256: Optional[Sha256],
2048        kwargs: Dict[str, Any],
2049    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2050        if src.startswith("http") and src.count(":") == 2:
2051            http, source, callable_ = src.split(":")
2052            source = ":".join((http, source))
2053        elif not src.startswith("http") and src.count(":") == 1:
2054            source, callable_ = src.split(":")
2055        else:
2056            source = str(src)
2057            callable_ = str(src)
2058        return tgt(
2059            callable=Identifier(callable_),
2060            source=cast(ImportantFileSource, source),
2061            sha256=sha256,
2062            kwargs=kwargs,
2063        )
2064
2065
2066_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2067
2068
2069class _ArchLibConv(
2070    Converter[
2071        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2072    ]
2073):
2074    def _convert(
2075        self,
2076        src: _CallableFromDepencency_v0_4,
2077        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2078        kwargs: Dict[str, Any],
2079    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2080        *mods, callable_ = src.split(".")
2081        import_from = ".".join(mods)
2082        return tgt(
2083            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2084        )
2085
2086
2087_arch_lib_conv = _ArchLibConv(
2088    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2089)
2090
2091
2092class WeightsEntryDescrBase(FileDescr):
2093    type: ClassVar[WeightsFormat]
2094    weights_format_name: ClassVar[str]  # human readable
2095
2096    source: ImportantFileSource
2097    """∈📦 The weights file."""
2098
2099    authors: Optional[List[Author]] = None
2100    """Authors
2101    Either the person(s) that have trained this model resulting in the original weights file.
2102        (If this is the initial weights entry, i.e. it does not have a `parent`)
2103    Or the person(s) who have converted the weights to this weights format.
2104        (If this is a child weight, i.e. it has a `parent` field)
2105    """
2106
2107    parent: Annotated[
2108        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2109    ] = None
2110    """The source weights these weights were converted from.
2111    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2112    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2113    All weight entries except one (the initial set of weights resulting from training the model),
2114    need to have this field."""
2115
2116    @model_validator(mode="after")
2117    def check_parent_is_not_self(self) -> Self:
2118        if self.type == self.parent:
2119            raise ValueError("Weights entry can't be it's own parent.")
2120
2121        return self
2122
2123
2124class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2125    type = "keras_hdf5"
2126    weights_format_name: ClassVar[str] = "Keras HDF5"
2127    tensorflow_version: Version
2128    """TensorFlow version used to create these weights."""
2129
2130
2131class OnnxWeightsDescr(WeightsEntryDescrBase):
2132    type = "onnx"
2133    weights_format_name: ClassVar[str] = "ONNX"
2134    opset_version: Annotated[int, Ge(7)]
2135    """ONNX opset version"""
2136
2137
2138class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2139    type = "pytorch_state_dict"
2140    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2141    architecture: ArchitectureDescr
2142    pytorch_version: Version
2143    """Version of the PyTorch library used.
2144    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
2145    """
2146    dependencies: Optional[EnvironmentFileDescr] = None
2147    """Custom depencies beyond pytorch.
2148    The conda environment file should include pytorch and any version pinning has to be compatible with
2149    `pytorch_version`.
2150    """
2151
2152
2153class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2154    type = "tensorflow_js"
2155    weights_format_name: ClassVar[str] = "Tensorflow.js"
2156    tensorflow_version: Version
2157    """Version of the TensorFlow library used."""
2158
2159    source: ImportantFileSource
2160    """∈📦 The multi-file weights.
2161    All required files/folders should be a zip archive."""
2162
2163
2164class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2165    type = "tensorflow_saved_model_bundle"
2166    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2167    tensorflow_version: Version
2168    """Version of the TensorFlow library used."""
2169
2170    dependencies: Optional[EnvironmentFileDescr] = None
2171    """Custom dependencies beyond tensorflow.
2172    Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""
2173
2174    source: ImportantFileSource
2175    """∈📦 The multi-file weights.
2176    All required files/folders should be a zip archive."""
2177
2178
2179class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2180    type = "torchscript"
2181    weights_format_name: ClassVar[str] = "TorchScript"
2182    pytorch_version: Version
2183    """Version of the PyTorch library used."""
2184
2185
2186class WeightsDescr(Node):
2187    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2188    onnx: Optional[OnnxWeightsDescr] = None
2189    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2190    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2191    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2192        None
2193    )
2194    torchscript: Optional[TorchscriptWeightsDescr] = None
2195
2196    @model_validator(mode="after")
2197    def check_entries(self) -> Self:
2198        entries = {wtype for wtype, entry in self if entry is not None}
2199
2200        if not entries:
2201            raise ValueError("Missing weights entry")
2202
2203        entries_wo_parent = {
2204            wtype
2205            for wtype, entry in self
2206            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2207        }
2208        if len(entries_wo_parent) != 1:
2209            issue_warning(
2210                "Exactly one weights entry may not specify the `parent` field (got"
2211                + " {value}). That entry is considered the original set of model weights."
2212                + " Other weight formats are created through conversion of the orignal or"
2213                + " already converted weights. They have to reference the weights format"
2214                + " they were converted from as their `parent`.",
2215                value=len(entries_wo_parent),
2216                field="weights",
2217            )
2218
2219        for wtype, entry in self:
2220            if entry is None:
2221                continue
2222
2223            assert hasattr(entry, "type")
2224            assert hasattr(entry, "parent")
2225            assert wtype == entry.type
2226            if (
2227                entry.parent is not None and entry.parent not in entries
2228            ):  # self reference checked for `parent` field
2229                raise ValueError(
2230                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2231                    + f" formats: {entries}"
2232                )
2233
2234        return self
2235
2236    def __getitem__(
2237        self,
2238        key: Literal[
2239            "keras_hdf5",
2240            "onnx",
2241            "pytorch_state_dict",
2242            "tensorflow_js",
2243            "tensorflow_saved_model_bundle",
2244            "torchscript",
2245        ],
2246    ):
2247        if key == "keras_hdf5":
2248            ret = self.keras_hdf5
2249        elif key == "onnx":
2250            ret = self.onnx
2251        elif key == "pytorch_state_dict":
2252            ret = self.pytorch_state_dict
2253        elif key == "tensorflow_js":
2254            ret = self.tensorflow_js
2255        elif key == "tensorflow_saved_model_bundle":
2256            ret = self.tensorflow_saved_model_bundle
2257        elif key == "torchscript":
2258            ret = self.torchscript
2259        else:
2260            raise KeyError(key)
2261
2262        if ret is None:
2263            raise KeyError(key)
2264
2265        return ret
2266
2267    @property
2268    def available_formats(self):
2269        return {
2270            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2271            **({} if self.onnx is None else {"onnx": self.onnx}),
2272            **(
2273                {}
2274                if self.pytorch_state_dict is None
2275                else {"pytorch_state_dict": self.pytorch_state_dict}
2276            ),
2277            **(
2278                {}
2279                if self.tensorflow_js is None
2280                else {"tensorflow_js": self.tensorflow_js}
2281            ),
2282            **(
2283                {}
2284                if self.tensorflow_saved_model_bundle is None
2285                else {
2286                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2287                }
2288            ),
2289            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2290        }
2291
2292    @property
2293    def missing_formats(self):
2294        return {
2295            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2296        }
2297
2298
2299class ModelId(ResourceId):
2300    pass
2301
2302
2303class LinkedModel(LinkedResourceBase):
2304    """Reference to a bioimage.io model."""
2305
2306    id: ModelId
2307    """A valid model `id` from the bioimage.io collection."""
2308
2309
2310class _DataDepSize(NamedTuple):
2311    min: int
2312    max: Optional[int]
2313
2314
2315class _AxisSizes(NamedTuple):
2316    """the lenghts of all axes of model inputs and outputs"""
2317
2318    inputs: Dict[Tuple[TensorId, AxisId], int]
2319    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2320
2321
2322class _TensorSizes(NamedTuple):
2323    """_AxisSizes as nested dicts"""
2324
2325    inputs: Dict[TensorId, Dict[AxisId, int]]
2326    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2327
2328
2329class ModelDescr(GenericModelDescrBase):
2330    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2331    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2332    """
2333
2334    format_version: Literal["0.5.3"] = "0.5.3"
2335    """Version of the bioimage.io model description specification used.
2336    When creating a new model always use the latest micro/patch version described here.
2337    The `format_version` is important for any consumer software to understand how to parse the fields.
2338    """
2339
2340    type: Literal["model"] = "model"
2341    """Specialized resource type 'model'"""
2342
2343    id: Optional[ModelId] = None
2344    """bioimage.io-wide unique resource identifier
2345    assigned by bioimage.io; version **un**specific."""
2346
2347    authors: NotEmpty[List[Author]]
2348    """The authors are the creators of the model RDF and the primary points of contact."""
2349
2350    documentation: Annotated[
2351        DocumentationSource,
2352        Field(
2353            examples=[
2354                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2355                "README.md",
2356            ],
2357        ),
2358    ]
2359    """∈📦 URL or relative path to a markdown file with additional documentation.
2360    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2361    The documentation should include a '#[#] Validation' (sub)section
2362    with details on how to quantitatively validate the model on unseen data."""
2363
2364    @field_validator("documentation", mode="after")
2365    @classmethod
2366    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2367        if not validation_context_var.get().perform_io_checks:
2368            return value
2369
2370        doc_path = download(value).path
2371        doc_content = doc_path.read_text(encoding="utf-8")
2372        assert isinstance(doc_content, str)
2373        if not re.match("#.*[vV]alidation", doc_content):
2374            issue_warning(
2375                "No '# Validation' (sub)section found in {value}.",
2376                value=value,
2377                field="documentation",
2378            )
2379
2380        return value
2381
2382    inputs: NotEmpty[Sequence[InputTensorDescr]]
2383    """Describes the input tensors expected by this model."""
2384
2385    @field_validator("inputs", mode="after")
2386    @classmethod
2387    def _validate_input_axes(
2388        cls, inputs: Sequence[InputTensorDescr]
2389    ) -> Sequence[InputTensorDescr]:
2390        input_size_refs = cls._get_axes_with_independent_size(inputs)
2391
2392        for i, ipt in enumerate(inputs):
2393            valid_independent_refs: Dict[
2394                Tuple[TensorId, AxisId],
2395                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2396            ] = {
2397                **{
2398                    (ipt.id, a.id): (ipt, a, a.size)
2399                    for a in ipt.axes
2400                    if not isinstance(a, BatchAxis)
2401                    and isinstance(a.size, (int, ParameterizedSize))
2402                },
2403                **input_size_refs,
2404            }
2405            for a, ax in enumerate(ipt.axes):
2406                cls._validate_axis(
2407                    "inputs",
2408                    i=i,
2409                    tensor_id=ipt.id,
2410                    a=a,
2411                    axis=ax,
2412                    valid_independent_refs=valid_independent_refs,
2413                )
2414        return inputs
2415
2416    @staticmethod
2417    def _validate_axis(
2418        field_name: str,
2419        i: int,
2420        tensor_id: TensorId,
2421        a: int,
2422        axis: AnyAxis,
2423        valid_independent_refs: Dict[
2424            Tuple[TensorId, AxisId],
2425            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2426        ],
2427    ):
2428        if isinstance(axis, BatchAxis) or isinstance(
2429            axis.size, (int, ParameterizedSize, DataDependentSize)
2430        ):
2431            return
2432        elif not isinstance(axis.size, SizeReference):
2433            assert_never(axis.size)
2434
2435        # validate axis.size SizeReference
2436        ref = (axis.size.tensor_id, axis.size.axis_id)
2437        if ref not in valid_independent_refs:
2438            raise ValueError(
2439                "Invalid tensor axis reference at"
2440                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2441            )
2442        if ref == (tensor_id, axis.id):
2443            raise ValueError(
2444                "Self-referencing not allowed for"
2445                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2446            )
2447        if axis.type == "channel":
2448            if valid_independent_refs[ref][1].type != "channel":
2449                raise ValueError(
2450                    "A channel axis' size may only reference another fixed size"
2451                    + " channel axis."
2452                )
2453            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2454                ref_size = valid_independent_refs[ref][2]
2455                assert isinstance(ref_size, int), (
2456                    "channel axis ref (another channel axis) has to specify fixed"
2457                    + " size"
2458                )
2459                generated_channel_names = [
2460                    Identifier(axis.channel_names.format(i=i))
2461                    for i in range(1, ref_size + 1)
2462                ]
2463                axis.channel_names = generated_channel_names
2464
2465        if (ax_unit := getattr(axis, "unit", None)) != (
2466            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2467        ):
2468            raise ValueError(
2469                "The units of an axis and its reference axis need to match, but"
2470                + f" '{ax_unit}' != '{ref_unit}'."
2471            )
2472        ref_axis = valid_independent_refs[ref][1]
2473        if isinstance(ref_axis, BatchAxis):
2474            raise ValueError(
2475                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2476                + " (a batch axis is not allowed as reference)."
2477            )
2478
2479        if isinstance(axis, WithHalo):
2480            min_size = axis.size.get_size(axis, ref_axis, n=0)
2481            if (min_size - 2 * axis.halo) < 1:
2482                raise ValueError(
2483                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2484                    + f" {axis.halo}."
2485                )
2486
2487            input_halo = axis.halo * axis.scale / ref_axis.scale
2488            if input_halo != int(input_halo) or input_halo % 2 == 1:
2489                raise ValueError(
2490                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2491                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2492                    + f" is not an even integer for {tensor_id}.{axis.id}."
2493                )
2494
2495    @model_validator(mode="after")
2496    def _validate_test_tensors(self) -> Self:
2497        if not validation_context_var.get().perform_io_checks:
2498            return self
2499
2500        test_arrays = [
2501            load_array(descr.test_tensor.download().path)
2502            for descr in chain(self.inputs, self.outputs)
2503        ]
2504        tensors = {
2505            descr.id: (descr, array)
2506            for descr, array in zip(chain(self.inputs, self.outputs), test_arrays)
2507        }
2508        validate_tensors(tensors, tensor_origin="test_tensor")
2509        return self
2510
2511    @model_validator(mode="after")
2512    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2513        ipt_refs = {t.id for t in self.inputs}
2514        out_refs = {t.id for t in self.outputs}
2515        for ipt in self.inputs:
2516            for p in ipt.preprocessing:
2517                ref = p.kwargs.get("reference_tensor")
2518                if ref is None:
2519                    continue
2520                if ref not in ipt_refs:
2521                    raise ValueError(
2522                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2523                        + f" references are: {ipt_refs}."
2524                    )
2525
2526        for out in self.outputs:
2527            for p in out.postprocessing:
2528                ref = p.kwargs.get("reference_tensor")
2529                if ref is None:
2530                    continue
2531
2532                if ref not in ipt_refs and ref not in out_refs:
2533                    raise ValueError(
2534                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2535                        + f" are: {ipt_refs | out_refs}."
2536                    )
2537
2538        return self
2539
2540    # TODO: use validate funcs in validate_test_tensors
2541    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2542
2543    name: Annotated[
2544        Annotated[
2545            str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()")
2546        ],
2547        MinLen(5),
2548        MaxLen(128),
2549        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2550    ]
2551    """A human-readable name of this model.
2552    It should be no longer than 64 characters
2553    and may only contain letter, number, underscore, minus, parentheses and spaces.
2554    We recommend to chose a name that refers to the model's task and image modality.
2555    """
2556
2557    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2558    """Describes the output tensors."""
2559
2560    @field_validator("outputs", mode="after")
2561    @classmethod
2562    def _validate_tensor_ids(
2563        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2564    ) -> Sequence[OutputTensorDescr]:
2565        tensor_ids = [
2566            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2567        ]
2568        duplicate_tensor_ids: List[str] = []
2569        seen: Set[str] = set()
2570        for t in tensor_ids:
2571            if t in seen:
2572                duplicate_tensor_ids.append(t)
2573
2574            seen.add(t)
2575
2576        if duplicate_tensor_ids:
2577            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2578
2579        return outputs
2580
2581    @staticmethod
2582    def _get_axes_with_parameterized_size(
2583        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2584    ):
2585        return {
2586            f"{t.id}.{a.id}": (t, a, a.size)
2587            for t in io
2588            for a in t.axes
2589            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2590        }
2591
2592    @staticmethod
2593    def _get_axes_with_independent_size(
2594        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2595    ):
2596        return {
2597            (t.id, a.id): (t, a, a.size)
2598            for t in io
2599            for a in t.axes
2600            if not isinstance(a, BatchAxis)
2601            and isinstance(a.size, (int, ParameterizedSize))
2602        }
2603
2604    @field_validator("outputs", mode="after")
2605    @classmethod
2606    def _validate_output_axes(
2607        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2608    ) -> List[OutputTensorDescr]:
2609        input_size_refs = cls._get_axes_with_independent_size(
2610            info.data.get("inputs", [])
2611        )
2612        output_size_refs = cls._get_axes_with_independent_size(outputs)
2613
2614        for i, out in enumerate(outputs):
2615            valid_independent_refs: Dict[
2616                Tuple[TensorId, AxisId],
2617                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2618            ] = {
2619                **{
2620                    (out.id, a.id): (out, a, a.size)
2621                    for a in out.axes
2622                    if not isinstance(a, BatchAxis)
2623                    and isinstance(a.size, (int, ParameterizedSize))
2624                },
2625                **input_size_refs,
2626                **output_size_refs,
2627            }
2628            for a, ax in enumerate(out.axes):
2629                cls._validate_axis(
2630                    "outputs",
2631                    i,
2632                    out.id,
2633                    a,
2634                    ax,
2635                    valid_independent_refs=valid_independent_refs,
2636                )
2637
2638        return outputs
2639
2640    packaged_by: List[Author] = Field(default_factory=list)
2641    """The persons that have packaged and uploaded this model.
2642    Only required if those persons differ from the `authors`."""
2643
2644    parent: Optional[LinkedModel] = None
2645    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2646
2647    # todo: add parent self check once we have `id`
2648    # @model_validator(mode="after")
2649    # def validate_parent_is_not_self(self) -> Self:
2650    #     if self.parent is not None and self.parent == self.id:
2651    #         raise ValueError("The model may not reference itself as parent model")
2652
2653    #     return self
2654
2655    run_mode: Annotated[
2656        Optional[RunMode],
2657        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2658    ] = None
2659    """Custom run mode for this model: for more complex prediction procedures like test time
2660    data augmentation that currently cannot be expressed in the specification.
2661    No standard run modes are defined yet."""
2662
2663    timestamp: Datetime = Datetime(datetime.now())
2664    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2665    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2666    (In Python a datetime object is valid, too)."""
2667
2668    training_data: Annotated[
2669        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2670        Field(union_mode="left_to_right"),
2671    ] = None
2672    """The dataset used to train this model"""
2673
2674    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2675    """The weights for this model.
2676    Weights can be given for different formats, but should otherwise be equivalent.
2677    The available weight formats determine which consumers can use this model."""
2678
2679    @model_validator(mode="after")
2680    def _add_default_cover(self) -> Self:
2681        if not validation_context_var.get().perform_io_checks or self.covers:
2682            return self
2683
2684        try:
2685            generated_covers = generate_covers(
2686                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2687                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2688            )
2689        except Exception as e:
2690            issue_warning(
2691                "Failed to generate cover image(s): {e}",
2692                value=self.covers,
2693                msg_context=dict(e=e),
2694                field="covers",
2695            )
2696        else:
2697            self.covers.extend(generated_covers)
2698
2699        return self
2700
2701    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2702        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2703        assert all(isinstance(d, np.ndarray) for d in data)
2704        return data
2705
2706    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2707        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2708        assert all(isinstance(d, np.ndarray) for d in data)
2709        return data
2710
2711    @staticmethod
2712    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2713        batch_size = 1
2714        tensor_with_batchsize: Optional[TensorId] = None
2715        for tid in tensor_sizes:
2716            for aid, s in tensor_sizes[tid].items():
2717                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2718                    continue
2719
2720                if batch_size != 1:
2721                    assert tensor_with_batchsize is not None
2722                    raise ValueError(
2723                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2724                    )
2725
2726                batch_size = s
2727                tensor_with_batchsize = tid
2728
2729        return batch_size
2730
2731    def get_output_tensor_sizes(
2732        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2733    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2734        """Returns the tensor output sizes for given **input_sizes**.
2735        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2736        Otherwise it might be larger than the actual (valid) output"""
2737        batch_size = self.get_batch_size(input_sizes)
2738        ns = self.get_ns(input_sizes)
2739
2740        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2741        return tensor_sizes.outputs
2742
2743    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2744        """get parameter `n` for each parameterized axis
2745        such that the valid input size is >= the given input size"""
2746        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2747        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2748        for tid in input_sizes:
2749            for aid, s in input_sizes[tid].items():
2750                size_descr = axes[tid][aid].size
2751                if isinstance(size_descr, ParameterizedSize):
2752                    ret[(tid, aid)] = size_descr.get_n(s)
2753                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2754                    pass
2755                else:
2756                    assert_never(size_descr)
2757
2758        return ret
2759
2760    def get_tensor_sizes(
2761        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2762    ) -> _TensorSizes:
2763        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2764        return _TensorSizes(
2765            {
2766                t: {
2767                    aa: axis_sizes.inputs[(tt, aa)]
2768                    for tt, aa in axis_sizes.inputs
2769                    if tt == t
2770                }
2771                for t in {tt for tt, _ in axis_sizes.inputs}
2772            },
2773            {
2774                t: {
2775                    aa: axis_sizes.outputs[(tt, aa)]
2776                    for tt, aa in axis_sizes.outputs
2777                    if tt == t
2778                }
2779                for t in {tt for tt, _ in axis_sizes.outputs}
2780            },
2781        )
2782
2783    def get_axis_sizes(
2784        self,
2785        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2786        batch_size: Optional[int] = None,
2787        *,
2788        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2789    ) -> _AxisSizes:
2790        """Determine input and output block shape for scale factors **ns**
2791        of parameterized input sizes.
2792
2793        Args:
2794            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2795                that is parameterized as `size = min + n * step`.
2796            batch_size: The desired size of the batch dimension.
2797                If given **batch_size** overwrites any batch size present in
2798                **max_input_shape**. Default 1.
2799            max_input_shape: Limits the derived block shapes.
2800                Each axis for which the input size, parameterized by `n`, is larger
2801                than **max_input_shape** is set to the minimal value `n_min` for which
2802                this is still true.
2803                Use this for small input samples or large values of **ns**.
2804                Or simply whenever you know the full input shape.
2805
2806        Returns:
2807            Resolved axis sizes for model inputs and outputs.
2808        """
2809        max_input_shape = max_input_shape or {}
2810        if batch_size is None:
2811            for (_t_id, a_id), s in max_input_shape.items():
2812                if a_id == BATCH_AXIS_ID:
2813                    batch_size = s
2814                    break
2815            else:
2816                batch_size = 1
2817
2818        all_axes = {
2819            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2820        }
2821
2822        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2823        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2824
2825        def get_axis_size(a: Union[InputAxis, OutputAxis]):
2826            if isinstance(a, BatchAxis):
2827                if (t_descr.id, a.id) in ns:
2828                    logger.warning(
2829                        "Ignoring unexpected size increment factor (n) for batch axis"
2830                        + " of tensor '{}'.",
2831                        t_descr.id,
2832                    )
2833                return batch_size
2834            elif isinstance(a.size, int):
2835                if (t_descr.id, a.id) in ns:
2836                    logger.warning(
2837                        "Ignoring unexpected size increment factor (n) for fixed size"
2838                        + " axis '{}' of tensor '{}'.",
2839                        a.id,
2840                        t_descr.id,
2841                    )
2842                return a.size
2843            elif isinstance(a.size, ParameterizedSize):
2844                if (t_descr.id, a.id) not in ns:
2845                    raise ValueError(
2846                        "Size increment factor (n) missing for parametrized axis"
2847                        + f" '{a.id}' of tensor '{t_descr.id}'."
2848                    )
2849                n = ns[(t_descr.id, a.id)]
2850                s_max = max_input_shape.get((t_descr.id, a.id))
2851                if s_max is not None:
2852                    n = min(n, a.size.get_n(s_max))
2853
2854                return a.size.get_size(n)
2855
2856            elif isinstance(a.size, SizeReference):
2857                if (t_descr.id, a.id) in ns:
2858                    logger.warning(
2859                        "Ignoring unexpected size increment factor (n) for axis '{}'"
2860                        + " of tensor '{}' with size reference.",
2861                        a.id,
2862                        t_descr.id,
2863                    )
2864                assert not isinstance(a, BatchAxis)
2865                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2866                assert not isinstance(ref_axis, BatchAxis)
2867                ref_key = (a.size.tensor_id, a.size.axis_id)
2868                ref_size = inputs.get(ref_key, outputs.get(ref_key))
2869                assert ref_size is not None, ref_key
2870                assert not isinstance(ref_size, _DataDepSize), ref_key
2871                return a.size.get_size(
2872                    axis=a,
2873                    ref_axis=ref_axis,
2874                    ref_size=ref_size,
2875                )
2876            elif isinstance(a.size, DataDependentSize):
2877                if (t_descr.id, a.id) in ns:
2878                    logger.warning(
2879                        "Ignoring unexpected increment factor (n) for data dependent"
2880                        + " size axis '{}' of tensor '{}'.",
2881                        a.id,
2882                        t_descr.id,
2883                    )
2884                return _DataDepSize(a.size.min, a.size.max)
2885            else:
2886                assert_never(a.size)
2887
2888        # first resolve all , but the `SizeReference` input sizes
2889        for t_descr in self.inputs:
2890            for a in t_descr.axes:
2891                if not isinstance(a.size, SizeReference):
2892                    s = get_axis_size(a)
2893                    assert not isinstance(s, _DataDepSize)
2894                    inputs[t_descr.id, a.id] = s
2895
2896        # resolve all other input axis sizes
2897        for t_descr in self.inputs:
2898            for a in t_descr.axes:
2899                if isinstance(a.size, SizeReference):
2900                    s = get_axis_size(a)
2901                    assert not isinstance(s, _DataDepSize)
2902                    inputs[t_descr.id, a.id] = s
2903
2904        # resolve all output axis sizes
2905        for t_descr in self.outputs:
2906            for a in t_descr.axes:
2907                assert not isinstance(a.size, ParameterizedSize)
2908                s = get_axis_size(a)
2909                outputs[t_descr.id, a.id] = s
2910
2911        return _AxisSizes(inputs=inputs, outputs=outputs)
2912
2913    @model_validator(mode="before")
2914    @classmethod
2915    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
2916        if (
2917            data.get("type") == "model"
2918            and isinstance(fv := data.get("format_version"), str)
2919            and fv.count(".") == 2
2920        ):
2921            fv_parts = fv.split(".")
2922            if any(not p.isdigit() for p in fv_parts):
2923                return data
2924
2925            fv_tuple = tuple(map(int, fv_parts))
2926
2927            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
2928            if fv_tuple[:2] in ((0, 3), (0, 4)):
2929                m04 = _ModelDescr_v0_4.load(data)
2930                if not isinstance(m04, InvalidDescr):
2931                    return _model_conv.convert_as_dict(m04)
2932            elif fv_tuple[:2] == (0, 5):
2933                # bump patch version
2934                data["format_version"] = cls.implemented_format_version
2935
2936        return data
2937
2938
2939class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
2940    def _convert(
2941        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
2942    ) -> "ModelDescr | dict[str, Any]":
2943        name = "".join(
2944            c if c in string.ascii_letters + string.digits + "_- ()" else " "
2945            for c in src.name
2946        )
2947
2948        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
2949            conv = (
2950                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
2951            )
2952            return None if auths is None else [conv(a) for a in auths]
2953
2954        if TYPE_CHECKING:
2955            arch_file_conv = _arch_file_conv.convert
2956            arch_lib_conv = _arch_lib_conv.convert
2957        else:
2958            arch_file_conv = _arch_file_conv.convert_as_dict
2959            arch_lib_conv = _arch_lib_conv.convert_as_dict
2960
2961        input_size_refs = {
2962            ipt.name: {
2963                a: s
2964                for a, s in zip(
2965                    ipt.axes,
2966                    (
2967                        ipt.shape.min
2968                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
2969                        else ipt.shape
2970                    ),
2971                )
2972            }
2973            for ipt in src.inputs
2974            if ipt.shape
2975        }
2976        output_size_refs = {
2977            **{
2978                out.name: {a: s for a, s in zip(out.axes, out.shape)}
2979                for out in src.outputs
2980                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
2981            },
2982            **input_size_refs,
2983        }
2984
2985        return tgt(
2986            attachments=(
2987                []
2988                if src.attachments is None
2989                else [FileDescr(source=f) for f in src.attachments.files]
2990            ),
2991            authors=[
2992                _author_conv.convert_as_dict(a) for a in src.authors
2993            ],  # pyright: ignore[reportArgumentType]
2994            cite=[
2995                {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
2996            ],  # pyright: ignore[reportArgumentType]
2997            config=src.config,
2998            covers=src.covers,
2999            description=src.description,
3000            documentation=src.documentation,
3001            format_version="0.5.3",
3002            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3003            icon=src.icon,
3004            id=None if src.id is None else ModelId(src.id),
3005            id_emoji=src.id_emoji,
3006            license=src.license,  # type: ignore
3007            links=src.links,
3008            maintainers=[
3009                _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3010            ],  # pyright: ignore[reportArgumentType]
3011            name=name,
3012            tags=src.tags,
3013            type=src.type,
3014            uploader=src.uploader,
3015            version=src.version,
3016            inputs=[  # pyright: ignore[reportArgumentType]
3017                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3018                for ipt, tt, st, in zip(
3019                    src.inputs,
3020                    src.test_inputs,
3021                    src.sample_inputs or [None] * len(src.test_inputs),
3022                )
3023            ],
3024            outputs=[  # pyright: ignore[reportArgumentType]
3025                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3026                for out, tt, st, in zip(
3027                    src.outputs,
3028                    src.test_outputs,
3029                    src.sample_outputs or [None] * len(src.test_outputs),
3030                )
3031            ],
3032            parent=(
3033                None
3034                if src.parent is None
3035                else LinkedModel(
3036                    id=ModelId(
3037                        str(src.parent.id)
3038                        + (
3039                            ""
3040                            if src.parent.version_number is None
3041                            else f"/{src.parent.version_number}"
3042                        )
3043                    )
3044                )
3045            ),
3046            training_data=(
3047                None
3048                if src.training_data is None
3049                else (
3050                    LinkedDataset(
3051                        id=DatasetId(
3052                            str(src.training_data.id)
3053                            + (
3054                                ""
3055                                if src.training_data.version_number is None
3056                                else f"/{src.training_data.version_number}"
3057                            )
3058                        )
3059                    )
3060                    if isinstance(src.training_data, LinkedDataset02)
3061                    else src.training_data
3062                )
3063            ),
3064            packaged_by=[
3065                _author_conv.convert_as_dict(a) for a in src.packaged_by
3066            ],  # pyright: ignore[reportArgumentType]
3067            run_mode=src.run_mode,
3068            timestamp=src.timestamp,
3069            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3070                keras_hdf5=(w := src.weights.keras_hdf5)
3071                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3072                    authors=conv_authors(w.authors),
3073                    source=w.source,
3074                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3075                    parent=w.parent,
3076                ),
3077                onnx=(w := src.weights.onnx)
3078                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3079                    source=w.source,
3080                    authors=conv_authors(w.authors),
3081                    parent=w.parent,
3082                    opset_version=w.opset_version or 15,
3083                ),
3084                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3085                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3086                    source=w.source,
3087                    authors=conv_authors(w.authors),
3088                    parent=w.parent,
3089                    architecture=(
3090                        arch_file_conv(
3091                            w.architecture,
3092                            w.architecture_sha256,
3093                            w.kwargs,
3094                        )
3095                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3096                        else arch_lib_conv(w.architecture, w.kwargs)
3097                    ),
3098                    pytorch_version=w.pytorch_version or Version("1.10"),
3099                    dependencies=(
3100                        None
3101                        if w.dependencies is None
3102                        else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3103                            source=cast(
3104                                ImportantFileSource,
3105                                str(deps := w.dependencies)[
3106                                    (
3107                                        len("conda:")
3108                                        if str(deps).startswith("conda:")
3109                                        else 0
3110                                    ) :
3111                                ],
3112                            )
3113                        )
3114                    ),
3115                ),
3116                tensorflow_js=(w := src.weights.tensorflow_js)
3117                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3118                    source=w.source,
3119                    authors=conv_authors(w.authors),
3120                    parent=w.parent,
3121                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3122                ),
3123                tensorflow_saved_model_bundle=(
3124                    w := src.weights.tensorflow_saved_model_bundle
3125                )
3126                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3127                    authors=conv_authors(w.authors),
3128                    parent=w.parent,
3129                    source=w.source,
3130                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3131                    dependencies=(
3132                        None
3133                        if w.dependencies is None
3134                        else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3135                            source=cast(
3136                                ImportantFileSource,
3137                                (
3138                                    str(w.dependencies)[len("conda:") :]
3139                                    if str(w.dependencies).startswith("conda:")
3140                                    else str(w.dependencies)
3141                                ),
3142                            )
3143                        )
3144                    ),
3145                ),
3146                torchscript=(w := src.weights.torchscript)
3147                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3148                    source=w.source,
3149                    authors=conv_authors(w.authors),
3150                    parent=w.parent,
3151                    pytorch_version=w.pytorch_version or Version("1.10"),
3152                ),
3153            ),
3154        )
3155
3156
3157_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3158
3159
3160# create better cover images for 3d data and non-image outputs
3161def generate_covers(
3162    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3163    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3164) -> List[Path]:
3165    def squeeze(
3166        data: NDArray[Any], axes: Sequence[AnyAxis]
3167    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3168        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3169        if data.ndim != len(axes):
3170            raise ValueError(
3171                f"tensor shape {data.shape} does not match described axes"
3172                + f" {[a.id for a in axes]}"
3173            )
3174
3175        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3176        return data.squeeze(), axes
3177
3178    def normalize(
3179        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3180    ) -> NDArray[np.float32]:
3181        data = data.astype("float32")
3182        data -= data.min(axis=axis, keepdims=True)
3183        data /= data.max(axis=axis, keepdims=True) + eps
3184        return data
3185
3186    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3187        original_shape = data.shape
3188        data, axes = squeeze(data, axes)
3189
3190        # take slice fom any batch or index axis if needed
3191        # and convert the first channel axis and take a slice from any additional channel axes
3192        slices: Tuple[slice, ...] = ()
3193        ndim = data.ndim
3194        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3195        has_c_axis = False
3196        for i, a in enumerate(axes):
3197            s = data.shape[i]
3198            assert s > 1
3199            if (
3200                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3201                and ndim > ndim_need
3202            ):
3203                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3204                ndim -= 1
3205            elif isinstance(a, ChannelAxis):
3206                if has_c_axis:
3207                    # second channel axis
3208                    data = data[slices + (slice(0, 1),)]
3209                    ndim -= 1
3210                else:
3211                    has_c_axis = True
3212                    if s == 2:
3213                        # visualize two channels with cyan and magenta
3214                        data = np.concatenate(
3215                            [
3216                                data[slices + (slice(1, 2),)],
3217                                data[slices + (slice(0, 1),)],
3218                                (
3219                                    data[slices + (slice(0, 1),)]
3220                                    + data[slices + (slice(1, 2),)]
3221                                )
3222                                / 2,  # TODO: take maximum instead?
3223                            ],
3224                            axis=i,
3225                        )
3226                    elif data.shape[i] == 3:
3227                        pass  # visualize 3 channels as RGB
3228                    else:
3229                        # visualize first 3 channels as RGB
3230                        data = data[slices + (slice(3),)]
3231
3232                    assert data.shape[i] == 3
3233
3234            slices += (slice(None),)
3235
3236        data, axes = squeeze(data, axes)
3237        assert len(axes) == ndim
3238        # take slice from z axis if needed
3239        slices = ()
3240        if ndim > ndim_need:
3241            for i, a in enumerate(axes):
3242                s = data.shape[i]
3243                if a.id == AxisId("z"):
3244                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3245                    data, axes = squeeze(data, axes)
3246                    ndim -= 1
3247                    break
3248
3249            slices += (slice(None),)
3250
3251        # take slice from any space or time axis
3252        slices = ()
3253
3254        for i, a in enumerate(axes):
3255            if ndim <= ndim_need:
3256                break
3257
3258            s = data.shape[i]
3259            assert s > 1
3260            if isinstance(
3261                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3262            ):
3263                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3264                ndim -= 1
3265
3266            slices += (slice(None),)
3267
3268        del slices
3269        data, axes = squeeze(data, axes)
3270        assert len(axes) == ndim
3271
3272        if (has_c_axis and ndim != 3) or ndim != 2:
3273            raise ValueError(
3274                f"Failed to construct cover image from shape {original_shape}"
3275            )
3276
3277        if not has_c_axis:
3278            assert ndim == 2
3279            data = np.repeat(data[:, :, None], 3, axis=2)
3280            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3281            ndim += 1
3282
3283        assert ndim == 3
3284
3285        # transpose axis order such that longest axis comes first...
3286        axis_order = list(np.argsort(list(data.shape)))
3287        axis_order.reverse()
3288        # ... and channel axis is last
3289        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3290        axis_order.append(axis_order.pop(c))
3291        axes = [axes[ao] for ao in axis_order]
3292        data = data.transpose(axis_order)
3293
3294        # h, w = data.shape[:2]
3295        # if h / w  in (1.0 or 2.0):
3296        #     pass
3297        # elif h / w < 2:
3298        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3299
3300        norm_along = (
3301            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3302        )
3303        # normalize the data and map to 8 bit
3304        data = normalize(data, norm_along)
3305        data = (data * 255).astype("uint8")
3306
3307        return data
3308
3309    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3310        assert im0.dtype == im1.dtype == np.uint8
3311        assert im0.shape == im1.shape
3312        assert im0.ndim == 3
3313        N, M, C = im0.shape
3314        assert C == 3
3315        out = np.ones((N, M, C), dtype="uint8")
3316        for c in range(C):
3317            outc = np.tril(im0[..., c])
3318            mask = outc == 0
3319            outc[mask] = np.triu(im1[..., c])[mask]
3320            out[..., c] = outc
3321
3322        return out
3323
3324    ipt_descr, ipt = inputs[0]
3325    out_descr, out = outputs[0]
3326
3327    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3328    out_img = to_2d_image(out, out_descr.axes)
3329
3330    cover_folder = Path(mkdtemp())
3331    if ipt_img.shape == out_img.shape:
3332        covers = [cover_folder / "cover.png"]
3333        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3334    else:
3335        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3336        imwrite(covers[0], ipt_img)
3337        imwrite(covers[1], out_img)
3338
3339    return covers
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible to the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible to the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
194class TensorId(LowerCaseIdentifier):
195    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
196        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
197    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

200class AxisId(LowerCaseIdentifier):
201    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
202        Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
203    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'zero_mean_unit_variance']
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'scale_linear', 'sigmoid', 'zero_mean_unit_variance', 'scale_range']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>
class ParameterizedSize(bioimageio.spec._internal.node.Node):
244class ParameterizedSize(Node):
245    """Describes a range of valid tensor axis sizes as `size = min + n*step`."""
246
247    N: ClassVar[Type[int]] = ParameterizedSize_N
248    """integer to parameterize this axis"""
249
250    min: Annotated[int, Gt(0)]
251    step: Annotated[int, Gt(0)]
252
253    def validate_size(self, size: int) -> int:
254        if size < self.min:
255            raise ValueError(f"size {size} < {self.min}")
256        if (size - self.min) % self.step != 0:
257            raise ValueError(
258                f"axis of size {size} is not parameterized by `min + n*step` ="
259                + f" `{self.min} + n*{self.step}`"
260            )
261
262        return size
263
264    def get_size(self, n: ParameterizedSize_N) -> int:
265        return self.min + self.step * n
266
267    def get_n(self, s: int) -> ParameterizedSize_N:
268        """return smallest n parameterizing a size greater or equal than `s`"""
269        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

N: ClassVar[Type[int]] = <class 'int'>

integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
253    def validate_size(self, size: int) -> int:
254        if size < self.min:
255            raise ValueError(f"size {size} < {self.min}")
256        if (size - self.min) % self.step != 0:
257            raise ValueError(
258                f"axis of size {size} is not parameterized by `min + n*step` ="
259                + f" `{self.min} + n*{self.step}`"
260            )
261
262        return size
def get_size(self, n: int) -> int:
264    def get_size(self, n: ParameterizedSize_N) -> int:
265        return self.min + self.step * n
def get_n(self, s: int) -> int:
267    def get_n(self, s: int) -> ParameterizedSize_N:
268        """return smallest n parameterizing a size greater or equal than `s`"""
269        return ceil((s - self.min) / self.step)

return smallest n parameterizing a size greater or equal than s

class DataDependentSize(bioimageio.spec._internal.node.Node):
272class DataDependentSize(Node):
273    min: Annotated[int, Gt(0)] = 1
274    max: Annotated[Optional[int], Gt(1)] = None
275
276    @model_validator(mode="after")
277    def _validate_max_gt_min(self):
278        if self.max is not None and self.min >= self.max:
279            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
280
281        return self
282
283    def validate_size(self, size: int) -> int:
284        if size < self.min:
285            raise ValueError(f"size {size} < {self.min}")
286
287        if self.max is not None and size > self.max:
288            raise ValueError(f"size {size} > {self.max}")
289
290        return size
min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
283    def validate_size(self, size: int) -> int:
284        if size < self.min:
285            raise ValueError(f"size {size} < {self.min}")
286
287        if self.max is not None and size > self.max:
288            raise ValueError(f"size {size} > {self.max}")
289
290        return size
class SizeReference(bioimageio.spec._internal.node.Node):
293class SizeReference(Node):
294    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
295
296    `axis.size = reference.size * reference.scale / axis.scale + offset`
297
298    Note:
299    1. The axis and the referenced axis need to have the same unit (or no unit).
300    2. Batch axes may not be referenced.
301    3. Fractions are rounded down.
302    4. If the reference axis is `concatenable` the referencing axis is assumed to be
303        `concatenable` as well with the same block order.
304
305    Example:
306    An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
307    Let's assume that we want to express the image height h in relation to its width w
308    instead of only accepting input images of exactly 100*49 pixels
309    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
310
311    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
312    >>> h = SpaceInputAxis(
313    ...     id=AxisId("h"),
314    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
315    ...     unit="millimeter",
316    ...     scale=4,
317    ... )
318    >>> print(h.size.get_size(h, w))
319    49
320
321    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
322    """
323
324    tensor_id: TensorId
325    """tensor id of the reference axis"""
326
327    axis_id: AxisId
328    """axis id of the reference axis"""
329
330    offset: int = 0
331
332    def get_size(
333        self,
334        axis: Union[
335            ChannelAxis,
336            IndexInputAxis,
337            IndexOutputAxis,
338            TimeInputAxis,
339            SpaceInputAxis,
340            TimeOutputAxis,
341            TimeOutputAxisWithHalo,
342            SpaceOutputAxis,
343            SpaceOutputAxisWithHalo,
344        ],
345        ref_axis: Union[
346            ChannelAxis,
347            IndexInputAxis,
348            IndexOutputAxis,
349            TimeInputAxis,
350            SpaceInputAxis,
351            TimeOutputAxis,
352            TimeOutputAxisWithHalo,
353            SpaceOutputAxis,
354            SpaceOutputAxisWithHalo,
355        ],
356        n: ParameterizedSize_N = 0,
357        ref_size: Optional[int] = None,
358    ):
359        """Compute the concrete size for a given axis and its reference axis.
360
361        Args:
362            axis: The axis this `SizeReference` is the size of.
363            ref_axis: The reference axis to compute the size from.
364            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
365                and no fixed **ref_size** is given,
366                **n** is used to compute the size of the parameterized **ref_axis**.
367            ref_size: Overwrite the reference size instead of deriving it from
368                **ref_axis**
369                (**ref_axis.scale** is still used; any given **n** is ignored).
370        """
371        assert (
372            axis.size == self
373        ), "Given `axis.size` is not defined by this `SizeReference`"
374
375        assert (
376            ref_axis.id == self.axis_id
377        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
378
379        assert axis.unit == ref_axis.unit, (
380            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
381            f" but {axis.unit}!={ref_axis.unit}"
382        )
383        if ref_size is None:
384            if isinstance(ref_axis.size, (int, float)):
385                ref_size = ref_axis.size
386            elif isinstance(ref_axis.size, ParameterizedSize):
387                ref_size = ref_axis.size.get_size(n)
388            elif isinstance(ref_axis.size, DataDependentSize):
389                raise ValueError(
390                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
391                )
392            elif isinstance(ref_axis.size, SizeReference):
393                raise ValueError(
394                    "Reference axis referenced in `SizeReference` may not be sized by a"
395                    + " `SizeReference` itself."
396                )
397            else:
398                assert_never(ref_axis.size)
399
400        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
401
402    @staticmethod
403    def _get_unit(
404        axis: Union[
405            ChannelAxis,
406            IndexInputAxis,
407            IndexOutputAxis,
408            TimeInputAxis,
409            SpaceInputAxis,
410            TimeOutputAxis,
411            TimeOutputAxisWithHalo,
412            SpaceOutputAxis,
413            SpaceOutputAxisWithHalo,
414        ],
415    ):
416        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An unisotropic input image of wh=10049 pixels depicts a phsical space of 200196mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 10049 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: int
332    def get_size(
333        self,
334        axis: Union[
335            ChannelAxis,
336            IndexInputAxis,
337            IndexOutputAxis,
338            TimeInputAxis,
339            SpaceInputAxis,
340            TimeOutputAxis,
341            TimeOutputAxisWithHalo,
342            SpaceOutputAxis,
343            SpaceOutputAxisWithHalo,
344        ],
345        ref_axis: Union[
346            ChannelAxis,
347            IndexInputAxis,
348            IndexOutputAxis,
349            TimeInputAxis,
350            SpaceInputAxis,
351            TimeOutputAxis,
352            TimeOutputAxisWithHalo,
353            SpaceOutputAxis,
354            SpaceOutputAxisWithHalo,
355        ],
356        n: ParameterizedSize_N = 0,
357        ref_size: Optional[int] = None,
358    ):
359        """Compute the concrete size for a given axis and its reference axis.
360
361        Args:
362            axis: The axis this `SizeReference` is the size of.
363            ref_axis: The reference axis to compute the size from.
364            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
365                and no fixed **ref_size** is given,
366                **n** is used to compute the size of the parameterized **ref_axis**.
367            ref_size: Overwrite the reference size instead of deriving it from
368                **ref_axis**
369                (**ref_axis.scale** is still used; any given **n** is ignored).
370        """
371        assert (
372            axis.size == self
373        ), "Given `axis.size` is not defined by this `SizeReference`"
374
375        assert (
376            ref_axis.id == self.axis_id
377        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
378
379        assert axis.unit == ref_axis.unit, (
380            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
381            f" but {axis.unit}!={ref_axis.unit}"
382        )
383        if ref_size is None:
384            if isinstance(ref_axis.size, (int, float)):
385                ref_size = ref_axis.size
386            elif isinstance(ref_axis.size, ParameterizedSize):
387                ref_size = ref_axis.size.get_size(n)
388            elif isinstance(ref_axis.size, DataDependentSize):
389                raise ValueError(
390                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
391                )
392            elif isinstance(ref_axis.size, SizeReference):
393                raise ValueError(
394                    "Reference axis referenced in `SizeReference` may not be sized by a"
395                    + " `SizeReference` itself."
396                )
397            else:
398                assert_never(ref_axis.size)
399
400        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
419class AxisBase(NodeWithExplicitlySetFields):
420    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"})
421
422    id: AxisId
423    """An axis id unique across all axes of one tensor."""
424
425    description: Annotated[str, MaxLen(128)] = ""
fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({'type'})

set set these fields explicitly with their default value if they are not set, such that they are always included even when dumping with 'exlude_unset'

id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]
class WithHalo(bioimageio.spec._internal.node.Node):
428class WithHalo(Node):
429    halo: Annotated[int, Ge(1)]
430    """The halo should be cropped from the output tensor to avoid boundary effects.
431    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
432    To document a halo that is already cropped by the model use `size.offset` instead."""
433
434    size: Annotated[
435        SizeReference,
436        Field(
437            examples=[
438                10,
439                SizeReference(
440                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
441                ).model_dump(mode="json"),
442            ]
443        ),
444    ]
445    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
451class BatchAxis(AxisBase):
452    type: Literal["batch"] = "batch"
453    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
454    size: Optional[Literal[1]] = None
455    """The batch size may be fixed to 1,
456    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
457
458    @property
459    def scale(self):
460        return 1.0
461
462    @property
463    def concatenable(self):
464        return True
465
466    @property
467    def unit(self):
468        return None
type: Literal['batch']
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
458    @property
459    def scale(self):
460        return 1.0
concatenable
462    @property
463    def concatenable(self):
464        return True
unit
466    @property
467    def unit(self):
468        return None
class ChannelAxis(AxisBase):
471class ChannelAxis(AxisBase):
472    type: Literal["channel"] = "channel"
473    id: NonBatchAxisId = AxisId("channel")
474    channel_names: NotEmpty[List[Identifier]]
475
476    @property
477    def size(self) -> int:
478        return len(self.channel_names)
479
480    @property
481    def concatenable(self):
482        return False
483
484    @property
485    def scale(self) -> float:
486        return 1.0
487
488    @property
489    def unit(self):
490        return None
type: Literal['channel']
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
476    @property
477    def size(self) -> int:
478        return len(self.channel_names)
concatenable
480    @property
481    def concatenable(self):
482        return False
scale: float
484    @property
485    def scale(self) -> float:
486        return 1.0
unit
488    @property
489    def unit(self):
490        return None
class IndexAxisBase(AxisBase):
493class IndexAxisBase(AxisBase):
494    type: Literal["index"] = "index"
495    id: NonBatchAxisId = AxisId("index")
496
497    @property
498    def scale(self) -> float:
499        return 1.0
500
501    @property
502    def unit(self):
503        return None
type: Literal['index']
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
497    @property
498    def scale(self) -> float:
499        return 1.0
unit
501    @property
502    def unit(self):
503        return None
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
526class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
527    concatenable: bool = False
528    """If a model has a `concatenable` input axis, it can be processed blockwise,
529    splitting a longer sample axis into blocks matching its input tensor description.
530    Output axes are concatenable if they have a `SizeReference` to a concatenable
531    input axis.
532    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

class IndexOutputAxis(IndexAxisBase):
535class IndexOutputAxis(IndexAxisBase):
536    size: Annotated[
537        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
538        Field(
539            examples=[
540                10,
541                SizeReference(
542                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
543                ).model_dump(mode="json"),
544            ]
545        ),
546    ]
547    """The size/length of this axis can be specified as
548    - fixed integer
549    - reference to another axis with an optional offset (`SizeReference`)
550    - data dependent size using `DataDependentSize` (size is only known after model inference)
551    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
class TimeAxisBase(AxisBase):
554class TimeAxisBase(AxisBase):
555    type: Literal["time"] = "time"
556    id: NonBatchAxisId = AxisId("time")
557    unit: Optional[TimeUnit] = None
558    scale: Annotated[float, Gt(0)] = 1.0
type: Literal['time']
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
561class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
562    concatenable: bool = False
563    """If a model has a `concatenable` input axis, it can be processed blockwise,
564    splitting a longer sample axis into blocks matching its input tensor description.
565    Output axes are concatenable if they have a `SizeReference` to a concatenable
566    input axis.
567    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

class SpaceAxisBase(AxisBase):
570class SpaceAxisBase(AxisBase):
571    type: Literal["space"] = "space"
572    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
573    unit: Optional[SpaceUnit] = None
574    scale: Annotated[float, Gt(0)] = 1.0
type: Literal['space']
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
577class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
578    concatenable: bool = False
579    """If a model has a `concatenable` input axis, it can be processed blockwise,
580    splitting a longer sample axis into blocks matching its input tensor description.
581    Output axes are concatenable if they have a `SizeReference` to a concatenable
582    input axis.
583    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
610class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
611    pass
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
614class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
615    pass
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
634class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
635    pass
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
638class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
639    pass
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
681class NominalOrOrdinalDataDescr(Node):
682    values: TVs
683    """A fixed set of nominal or an ascending sequence of ordinal values.
684    In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'.
685    String `values` are interpreted as labels for tensor values 0, ..., N.
686    Note: as YAML 1.2 does not natively support a "set" datatype,
687    nominal values should be given as a sequence (aka list/array) as well.
688    """
689
690    type: Annotated[
691        NominalOrOrdinalDType,
692        Field(
693            examples=[
694                "float32",
695                "uint8",
696                "uint16",
697                "int64",
698                "bool",
699            ],
700        ),
701    ] = "uint8"
702
703    @model_validator(mode="after")
704    def _validate_values_match_type(
705        self,
706    ) -> Self:
707        incompatible: List[Any] = []
708        for v in self.values:
709            if self.type == "bool":
710                if not isinstance(v, bool):
711                    incompatible.append(v)
712            elif self.type in DTYPE_LIMITS:
713                if (
714                    isinstance(v, (int, float))
715                    and (
716                        v < DTYPE_LIMITS[self.type].min
717                        or v > DTYPE_LIMITS[self.type].max
718                    )
719                    or (isinstance(v, str) and "uint" not in self.type)
720                    or (isinstance(v, float) and "int" in self.type)
721                ):
722                    incompatible.append(v)
723            else:
724                incompatible.append(v)
725
726            if len(incompatible) == 5:
727                incompatible.append("...")
728                break
729
730        if incompatible:
731            raise ValueError(
732                f"data type '{self.type}' incompatible with values {incompatible}"
733            )
734
735        return self
736
737    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
738
739    @property
740    def range(self):
741        if isinstance(self.values[0], str):
742            return 0, len(self.values) - 1
743        else:
744            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data_type is required to be an unsigend integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
739    @property
740    def range(self):
741        if isinstance(self.values[0], str):
742            return 0, len(self.values) - 1
743        else:
744            return min(self.values), max(self.values)
IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
761class IntervalOrRatioDataDescr(Node):
762    type: Annotated[  # todo: rename to dtype
763        IntervalOrRatioDType,
764        Field(
765            examples=["float32", "float64", "uint8", "uint16"],
766        ),
767    ] = "float32"
768    range: Tuple[Optional[float], Optional[float]] = (
769        None,
770        None,
771    )
772    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
773    `None` corresponds to min/max of what can be expressed by `data_type`."""
774    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
775    scale: float = 1.0
776    """Scale for data on an interval (or ratio) scale."""
777    offset: Optional[float] = None
778    """Offset for data on a ratio scale."""
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by data_type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
784class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
785    """processing base class"""
786
787    # id: Literal[PreprocessingId, PostprocessingId]  # make abstract field
788    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})

processing base class

fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({'id'})

set set these fields explicitly with their default value if they are not set, such that they are always included even when dumping with 'exlude_unset'

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
791class BinarizeKwargs(ProcessingKwargs):
792    """key word arguments for `BinarizeDescr`"""
793
794    threshold: float
795    """The fixed threshold"""

key word arguments for BinarizeDescr

threshold: float

The fixed threshold

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
798class BinarizeAlongAxisKwargs(ProcessingKwargs):
799    """key word arguments for `BinarizeDescr`"""
800
801    threshold: NotEmpty[List[float]]
802    """The fixed threshold values along `axis`"""
803
804    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
805    """The `threshold` axis"""

key word arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

class BinarizeDescr(ProcessingDescrBase):
808class BinarizeDescr(ProcessingDescrBase):
809    """Binarize the tensor with a fixed threshold.
810
811    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
812    will be set to one, values below the threshold to zero.
813
814    Examples:
815    - in YAML
816        ```yaml
817        postprocessing:
818          - id: binarize
819            kwargs:
820              axis: 'channel'
821              threshold: [0.25, 0.5, 0.75]
822        ```
823    - in Python:
824        >>> postprocessing = [BinarizeDescr(
825        ...   kwargs=BinarizeAlongAxisKwargs(
826        ...       axis=AxisId('channel'),
827        ...       threshold=[0.25, 0.5, 0.75],
828        ...   )
829        ... )]
830    """
831
832    id: Literal["binarize"] = "binarize"
833    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
id: Literal['binarize']
class ClipDescr(ProcessingDescrBase):
836class ClipDescr(ProcessingDescrBase):
837    """Set tensor values below min to min and above max to max.
838
839    See `ScaleRangeDescr` for examples.
840    """
841
842    id: Literal["clip"] = "clip"
843    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

id: Literal['clip']
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
846class EnsureDtypeKwargs(ProcessingKwargs):
847    """key word arguments for `EnsureDtypeDescr`"""
848
849    dtype: Literal[
850        "float32",
851        "float64",
852        "uint8",
853        "int8",
854        "uint16",
855        "int16",
856        "uint32",
857        "int32",
858        "uint64",
859        "int64",
860        "bool",
861    ]

key word arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class EnsureDtypeDescr(ProcessingDescrBase):
864class EnsureDtypeDescr(ProcessingDescrBase):
865    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
866
867    This can for example be used to ensure the inner neural network model gets a
868    different input tensor data type than the fully described bioimage.io model does.
869
870    Examples:
871        The described bioimage.io model (incl. preprocessing) accepts any
872        float32-compatible tensor, normalizes it with percentiles and clipping and then
873        casts it to uint8, which is what the neural network in this example expects.
874        - in YAML
875            ```yaml
876            inputs:
877            - data:
878                type: float32  # described bioimage.io model is compatible with any float32 input tensor
879            preprocessing:
880            - id: scale_range
881                kwargs:
882                axes: ['y', 'x']
883                max_percentile: 99.8
884                min_percentile: 5.0
885            - id: clip
886                kwargs:
887                min: 0.0
888                max: 1.0
889            - id: ensure_dtype
890                kwargs:
891                dtype: uint8
892            ```
893        - in Python:
894            >>> preprocessing = [
895            ...     ScaleRangeDescr(
896            ...         kwargs=ScaleRangeKwargs(
897            ...           axes= (AxisId('y'), AxisId('x')),
898            ...           max_percentile= 99.8,
899            ...           min_percentile= 5.0,
900            ...         )
901            ...     ),
902            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
903            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
904            ... ]
905    """
906
907    id: Literal["ensure_dtype"] = "ensure_dtype"
908    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
preprocessing:
- id: scale_range
    kwargs:
    axes: ['y', 'x']
    max_percentile: 99.8
    min_percentile: 5.0
- id: clip
    kwargs:
    min: 0.0
    max: 1.0
- id: ensure_dtype
    kwargs:
    dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
911class ScaleLinearKwargs(ProcessingKwargs):
912    """Key word arguments for `ScaleLinearDescr`"""
913
914    gain: float = 1.0
915    """multiplicative factor"""
916
917    offset: float = 0.0
918    """additive term"""
919
920    @model_validator(mode="after")
921    def _validate(self) -> Self:
922        if self.gain == 1.0 and self.offset == 0.0:
923            raise ValueError(
924                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
925                + " != 0.0."
926            )
927
928        return self

Key word arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
931class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
932    """Key word arguments for `ScaleLinearDescr`"""
933
934    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
935    """The axis of of gains/offsets values."""
936
937    gain: Union[float, NotEmpty[List[float]]] = 1.0
938    """multiplicative factor"""
939
940    offset: Union[float, NotEmpty[List[float]]] = 0.0
941    """additive term"""
942
943    @model_validator(mode="after")
944    def _validate(self) -> Self:
945
946        if isinstance(self.gain, list):
947            if isinstance(self.offset, list):
948                if len(self.gain) != len(self.offset):
949                    raise ValueError(
950                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
951                    )
952            else:
953                self.offset = [float(self.offset)] * len(self.gain)
954        elif isinstance(self.offset, list):
955            self.gain = [float(self.gain)] * len(self.offset)
956        else:
957            raise ValueError(
958                "Do not specify an `axis` for scalar gain and offset values."
959            )
960
961        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
962            raise ValueError(
963                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
964                + " != 0.0."
965            )
966
967        return self

Key word arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of of gains/offsets values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

class ScaleLinearDescr(ProcessingDescrBase):
 970class ScaleLinearDescr(ProcessingDescrBase):
 971    """Fixed linear scaling.
 972
 973    Examples:
 974      1. Scale with scalar gain and offset
 975        - in YAML
 976        ```yaml
 977        preprocessing:
 978          - id: scale_linear
 979            kwargs:
 980              gain: 2.0
 981              offset: 3.0
 982        ```
 983        - in Python:
 984        >>> preprocessing = [
 985        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
 986        ... ]
 987
 988      2. Independent scaling along an axis
 989        - in YAML
 990        ```yaml
 991        preprocessing:
 992          - id: scale_linear
 993            kwargs:
 994              axis: 'channel'
 995              gain: [1.0, 2.0, 3.0]
 996        ```
 997        - in Python:
 998        >>> preprocessing = [
 999        ...     ScaleLinearDescr(
1000        ...         kwargs=ScaleLinearAlongAxisKwargs(
1001        ...             axis=AxisId("channel"),
1002        ...             gain=[1.0, 2.0, 3.0],
1003        ...         )
1004        ...     )
1005        ... ]
1006
1007    """
1008
1009    id: Literal["scale_linear"] = "scale_linear"
1010    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
1013class SigmoidDescr(ProcessingDescrBase):
1014    """The logistic sigmoid funciton, a.k.a. expit function.
1015
1016    Examples:
1017    - in YAML
1018        ```yaml
1019        postprocessing:
1020          - id: sigmoid
1021        ```
1022    - in Python:
1023        >>> postprocessing = [SigmoidDescr()]
1024    """
1025
1026    id: Literal["sigmoid"] = "sigmoid"
1027
1028    @property
1029    def kwargs(self) -> ProcessingKwargs:
1030        """empty kwargs"""
1031        return ProcessingKwargs()

The logistic sigmoid funciton, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
id: Literal['sigmoid']
1028    @property
1029    def kwargs(self) -> ProcessingKwargs:
1030        """empty kwargs"""
1031        return ProcessingKwargs()

empty kwargs

class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1034class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1035    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1036
1037    mean: float
1038    """The mean value to normalize with."""
1039
1040    std: Annotated[float, Ge(1e-6)]
1041    """The standard deviation value to normalize with."""

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1044class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1045    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1046
1047    mean: NotEmpty[List[float]]
1048    """The mean value(s) to normalize with."""
1049
1050    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1051    """The standard deviation value(s) to normalize with.
1052    Size must match `mean` values."""
1053
1054    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1055    """The axis of the mean/std values to normalize each entry along that dimension
1056    separately."""
1057
1058    @model_validator(mode="after")
1059    def _mean_and_std_match(self) -> Self:
1060        if len(self.mean) != len(self.std):
1061            raise ValueError(
1062                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1063                + " must match."
1064            )
1065
1066        return self

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1069class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1070    """Subtract a given mean and divide by the standard deviation.
1071
1072    Normalize with fixed, precomputed values for
1073    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1074    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1075    axes.
1076
1077    Examples:
1078    1. scalar value for whole tensor
1079        - in YAML
1080        ```yaml
1081        preprocessing:
1082          - id: fixed_zero_mean_unit_variance
1083            kwargs:
1084              mean: 103.5
1085              std: 13.7
1086        ```
1087        - in Python
1088        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1089        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1090        ... )]
1091
1092    2. independently along an axis
1093        - in YAML
1094        ```yaml
1095        preprocessing:
1096          - id: fixed_zero_mean_unit_variance
1097            kwargs:
1098              axis: channel
1099              mean: [101.5, 102.5, 103.5]
1100              std: [11.7, 12.7, 13.7]
1101        ```
1102        - in Python
1103        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1104        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1105        ...     axis=AxisId("channel"),
1106        ...     mean=[101.5, 102.5, 103.5],
1107        ...     std=[11.7, 12.7, 13.7],
1108        ...   )
1109        ... )]
1110    """
1111
1112    id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1113    kwargs: Union[
1114        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1115    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor
    • in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      mean: 103.5
      std: 13.7
- in Python >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) ... )]
  1. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1118class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1119    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1120
1121    axes: Annotated[
1122        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1123    ] = None
1124    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1125    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1126    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1127    To normalize each sample independently leave out the 'batch' axis.
1128    Default: Scale all axes jointly."""
1129
1130    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1131    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

key word arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1134class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1135    """Subtract mean and divide by variance.
1136
1137    Examples:
1138        Subtract tensor mean and variance
1139        - in YAML
1140        ```yaml
1141        preprocessing:
1142          - id: zero_mean_unit_variance
1143        ```
1144        - in Python
1145        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1146    """
1147
1148    id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1149    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1150        default_factory=ZeroMeanUnitVarianceKwargs
1151    )

Subtract mean and divide by variance.

Examples:

Subtract tensor mean and variance

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
id: Literal['zero_mean_unit_variance']
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1154class ScaleRangeKwargs(ProcessingKwargs):
1155    """key word arguments for `ScaleRangeDescr`
1156
1157    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1158    this processing step normalizes data to the [0, 1] intervall.
1159    For other percentiles the normalized values will partially be outside the [0, 1]
1160    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1161    normalized values to a range.
1162    """
1163
1164    axes: Annotated[
1165        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1166    ] = None
1167    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1168    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1169    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1170    To normalize samples independently, leave out the "batch" axis.
1171    Default: Scale all axes jointly."""
1172
1173    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1174    """The lower percentile used to determine the value to align with zero."""
1175
1176    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1177    """The upper percentile used to determine the value to align with one.
1178    Has to be bigger than `min_percentile`.
1179    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1180    accepting percentiles specified in the range 0.0 to 1.0."""
1181
1182    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1183    """Epsilon for numeric stability.
1184    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1185    with `v_lower,v_upper` values at the respective percentiles."""
1186
1187    reference_tensor: Optional[TensorId] = None
1188    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1189    For any tensor in `inputs` only input tensor references are allowed."""
1190
1191    @field_validator("max_percentile", mode="after")
1192    @classmethod
1193    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1194        if (min_p := info.data["min_percentile"]) >= value:
1195            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1196
1197        return value

key word arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] intervall. For other percentiles the normalized values will partially be outside the [0, 1] intervall. Use ScaleRange followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1191    @field_validator("max_percentile", mode="after")
1192    @classmethod
1193    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1194        if (min_p := info.data["min_percentile"]) >= value:
1195            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1196
1197        return value
class ScaleRangeDescr(ProcessingDescrBase):
1200class ScaleRangeDescr(ProcessingDescrBase):
1201    """Scale with percentiles.
1202
1203    Examples:
1204    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1205        - in YAML
1206        ```yaml
1207        preprocessing:
1208          - id: scale_range
1209            kwargs:
1210              axes: ['y', 'x']
1211              max_percentile: 99.8
1212              min_percentile: 5.0
1213        ```
1214        - in Python
1215        >>> preprocessing = [
1216        ...     ScaleRangeDescr(
1217        ...         kwargs=ScaleRangeKwargs(
1218        ...           axes= (AxisId('y'), AxisId('x')),
1219        ...           max_percentile= 99.8,
1220        ...           min_percentile= 5.0,
1221        ...         )
1222        ...     ),
1223        ...     ClipDescr(
1224        ...         kwargs=ClipKwargs(
1225        ...             min=0.0,
1226        ...             max=1.0,
1227        ...         )
1228        ...     ),
1229        ... ]
1230
1231      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1232        - in YAML
1233        ```yaml
1234        preprocessing:
1235          - id: scale_range
1236            kwargs:
1237              axes: ['y', 'x']
1238              max_percentile: 99.8
1239              min_percentile: 5.0
1240                  - id: scale_range
1241           - id: clip
1242             kwargs:
1243              min: 0.0
1244              max: 1.0
1245        ```
1246        - in Python
1247        >>> preprocessing = [ScaleRangeDescr(
1248        ...   kwargs=ScaleRangeKwargs(
1249        ...       axes= (AxisId('y'), AxisId('x')),
1250        ...       max_percentile= 99.8,
1251        ...       min_percentile= 5.0,
1252        ...   )
1253        ... )]
1254
1255    """
1256
1257    id: Literal["scale_range"] = "scale_range"
1258    kwargs: ScaleRangeKwargs

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
    • in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
- in Python >>> preprocessing = [ ... ScaleRangeDescr( ... kwargs=ScaleRangeKwargs( ... axes= (AxisId('y'), AxisId('x')), ... max_percentile= 99.8, ... min_percentile= 5.0, ... ) ... ), ... ClipDescr( ... kwargs=ClipKwargs( ... min=0.0, ... max=1.0, ... ) ... ), ... ]
  1. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
              - id: scale_range
       - id: clip
         kwargs:
          min: 0.0
          max: 1.0
    
    • in Python
      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]
      
id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1261class ScaleMeanVarianceKwargs(ProcessingKwargs):
1262    """key word arguments for `ScaleMeanVarianceKwargs`"""
1263
1264    reference_tensor: TensorId
1265    """Name of tensor to match."""
1266
1267    axes: Annotated[
1268        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1269    ] = None
1270    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1271    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1272    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1273    To normalize samples independently, leave out the 'batch' axis.
1274    Default: Scale all axes jointly."""
1275
1276    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1277    """Epsilon for numeric stability:
1278    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

key word arguments for ScaleMeanVarianceKwargs

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1281class ScaleMeanVarianceDescr(ProcessingDescrBase):
1282    """Scale a tensor's data distribution to match another tensor's mean/std.
1283    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1284    """
1285
1286    id: Literal["scale_mean_variance"] = "scale_mean_variance"
1287    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

id: Literal['scale_mean_variance']
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1321class TensorDescrBase(Node, Generic[IO_AxisT]):
1322    id: TensorId
1323    """Tensor id. No duplicates are allowed."""
1324
1325    description: Annotated[str, MaxLen(128)] = ""
1326    """free text description"""
1327
1328    axes: NotEmpty[Sequence[IO_AxisT]]
1329    """tensor axes"""
1330
1331    @property
1332    def shape(self):
1333        return tuple(a.size for a in self.axes)
1334
1335    @field_validator("axes", mode="after", check_fields=False)
1336    @classmethod
1337    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1338        batch_axes = [a for a in axes if a.type == "batch"]
1339        if len(batch_axes) > 1:
1340            raise ValueError(
1341                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1342            )
1343
1344        seen_ids: Set[AxisId] = set()
1345        duplicate_axes_ids: Set[AxisId] = set()
1346        for a in axes:
1347            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1348
1349        if duplicate_axes_ids:
1350            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1351
1352        return axes
1353
1354    test_tensor: FileDescr
1355    """An example tensor to use for testing.
1356    Using the model with the test input tensors is expected to yield the test output tensors.
1357    Each test tensor has be a an ndarray in the
1358    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1359    The file extension must be '.npy'."""
1360
1361    sample_tensor: Optional[FileDescr] = None
1362    """A sample tensor to illustrate a possible input/output for the model,
1363    The sample image primarily serves to inform a human user about an example use case
1364    and is typically stored as .hdf5, .png or .tiff.
1365    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1366    (numpy's `.npy` format is not supported).
1367    The image dimensionality has to match the number of axes specified in this tensor description.
1368    """
1369
1370    @model_validator(mode="after")
1371    def _validate_sample_tensor(self) -> Self:
1372        if (
1373            self.sample_tensor is None
1374            or not validation_context_var.get().perform_io_checks
1375        ):
1376            return self
1377
1378        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1379        tensor: NDArray[Any] = imread(
1380            local.path.read_bytes(),
1381            extension=PurePosixPath(local.original_file_name).suffix,
1382        )
1383        n_dims = len(tensor.squeeze().shape)
1384        n_dims_min = n_dims_max = len(self.axes)
1385
1386        for a in self.axes:
1387            if isinstance(a, BatchAxis):
1388                n_dims_min -= 1
1389            elif isinstance(a.size, int):
1390                if a.size == 1:
1391                    n_dims_min -= 1
1392            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1393                if a.size.min == 1:
1394                    n_dims_min -= 1
1395            elif isinstance(a.size, SizeReference):
1396                if a.size.offset < 2:
1397                    # size reference may result in singleton axis
1398                    n_dims_min -= 1
1399            else:
1400                assert_never(a.size)
1401
1402        n_dims_min = max(0, n_dims_min)
1403        if n_dims < n_dims_min or n_dims > n_dims_max:
1404            raise ValueError(
1405                f"Expected sample tensor to have {n_dims_min} to"
1406                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1407            )
1408
1409        return self
1410
1411    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1412        IntervalOrRatioDataDescr()
1413    )
1414    """Description of the tensor's data values, optionally per channel.
1415    If specified per channel, the data `type` needs to match across channels."""
1416
1417    @property
1418    def dtype(
1419        self,
1420    ) -> Literal[
1421        "float32",
1422        "float64",
1423        "uint8",
1424        "int8",
1425        "uint16",
1426        "int16",
1427        "uint32",
1428        "int32",
1429        "uint64",
1430        "int64",
1431        "bool",
1432    ]:
1433        """dtype as specified under `data.type` or `data[i].type`"""
1434        if isinstance(self.data, collections.abc.Sequence):
1435            return self.data[0].type
1436        else:
1437            return self.data.type
1438
1439    @field_validator("data", mode="after")
1440    @classmethod
1441    def _check_data_type_across_channels(
1442        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1443    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1444        if not isinstance(value, list):
1445            return value
1446
1447        dtypes = {t.type for t in value}
1448        if len(dtypes) > 1:
1449            raise ValueError(
1450                "Tensor data descriptions per channel need to agree in their data"
1451                + f" `type`, but found {dtypes}."
1452            )
1453
1454        return value
1455
1456    @model_validator(mode="after")
1457    def _check_data_matches_channelaxis(self) -> Self:
1458        if not isinstance(self.data, (list, tuple)):
1459            return self
1460
1461        for a in self.axes:
1462            if isinstance(a, ChannelAxis):
1463                size = a.size
1464                assert isinstance(size, int)
1465                break
1466        else:
1467            return self
1468
1469        if len(self.data) != size:
1470            raise ValueError(
1471                f"Got tensor data descriptions for {len(self.data)} channels, but"
1472                + f" '{a.id}' axis has size {size}."
1473            )
1474
1475        return self
1476
1477    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1478        if len(array.shape) != len(self.axes):
1479            raise ValueError(
1480                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1481                + f" incompatible with {len(self.axes)} axes."
1482            )
1483        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1331    @property
1332    def shape(self):
1333        return tuple(a.size for a in self.axes)

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has be a an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Optional[bioimageio.spec._internal.io.FileDescr]

A sample tensor to illustrate a possible input/output for the model, The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1417    @property
1418    def dtype(
1419        self,
1420    ) -> Literal[
1421        "float32",
1422        "float64",
1423        "uint8",
1424        "int8",
1425        "uint16",
1426        "int16",
1427        "uint32",
1428        "int32",
1429        "uint64",
1430        "int64",
1431        "bool",
1432    ]:
1433        """dtype as specified under `data.type` or `data[i].type`"""
1434        if isinstance(self.data, collections.abc.Sequence):
1435            return self.data[0].type
1436        else:
1437            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[tuple[int, ...], numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1477    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1478        if len(array.shape) != len(self.axes):
1479            raise ValueError(
1480                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1481                + f" incompatible with {len(self.axes)} axes."
1482            )
1483        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1486class InputTensorDescr(TensorDescrBase[InputAxis]):
1487    id: TensorId = TensorId("input")
1488    """Input tensor id.
1489    No duplicates are allowed across all inputs and outputs."""
1490
1491    optional: bool = False
1492    """indicates that this tensor may be `None`"""
1493
1494    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1495    """Description of how this input should be preprocessed.
1496
1497    notes:
1498    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1499      to ensure an input tensor's data type matches the input tensor's data description.
1500    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1501      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1502      changing the data type.
1503    """
1504
1505    @model_validator(mode="after")
1506    def _validate_preprocessing_kwargs(self) -> Self:
1507        axes_ids = [a.id for a in self.axes]
1508        for p in self.preprocessing:
1509            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1510            if kwargs_axes is None:
1511                continue
1512
1513            if not isinstance(kwargs_axes, collections.abc.Sequence):
1514                raise ValueError(
1515                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1516                )
1517
1518            if any(a not in axes_ids for a in kwargs_axes):
1519                raise ValueError(
1520                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1521                )
1522
1523        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1524            dtype = self.data.type
1525        else:
1526            dtype = self.data[0].type
1527
1528        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1529        if not self.preprocessing or not isinstance(
1530            self.preprocessing[0], EnsureDtypeDescr
1531        ):
1532            self.preprocessing.insert(
1533                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1534            )
1535
1536        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1537        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1538            self.preprocessing.append(
1539                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1540            )
1541
1542        return self
id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1545def convert_axes(
1546    axes: str,
1547    *,
1548    shape: Union[
1549        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1550    ],
1551    tensor_type: Literal["input", "output"],
1552    halo: Optional[Sequence[int]],
1553    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1554):
1555    ret: List[AnyAxis] = []
1556    for i, a in enumerate(axes):
1557        axis_type = _AXIS_TYPE_MAP.get(a, a)
1558        if axis_type == "batch":
1559            ret.append(BatchAxis())
1560            continue
1561
1562        scale = 1.0
1563        if isinstance(shape, _ParameterizedInputShape_v0_4):
1564            if shape.step[i] == 0:
1565                size = shape.min[i]
1566            else:
1567                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1568        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1569            ref_t = str(shape.reference_tensor)
1570            if ref_t.count(".") == 1:
1571                t_id, orig_a_id = ref_t.split(".")
1572            else:
1573                t_id = ref_t
1574                orig_a_id = a
1575
1576            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1577            if not (orig_scale := shape.scale[i]):
1578                # old way to insert a new axis dimension
1579                size = int(2 * shape.offset[i])
1580            else:
1581                scale = 1 / orig_scale
1582                if axis_type in ("channel", "index"):
1583                    # these axes no longer have a scale
1584                    offset_from_scale = orig_scale * size_refs.get(
1585                        _TensorName_v0_4(t_id), {}
1586                    ).get(orig_a_id, 0)
1587                else:
1588                    offset_from_scale = 0
1589                size = SizeReference(
1590                    tensor_id=TensorId(t_id),
1591                    axis_id=AxisId(a_id),
1592                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1593                )
1594        else:
1595            size = shape[i]
1596
1597        if axis_type == "time":
1598            if tensor_type == "input":
1599                ret.append(TimeInputAxis(size=size, scale=scale))
1600            else:
1601                assert not isinstance(size, ParameterizedSize)
1602                if halo is None:
1603                    ret.append(TimeOutputAxis(size=size, scale=scale))
1604                else:
1605                    assert not isinstance(size, int)
1606                    ret.append(
1607                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1608                    )
1609
1610        elif axis_type == "index":
1611            if tensor_type == "input":
1612                ret.append(IndexInputAxis(size=size))
1613            else:
1614                if isinstance(size, ParameterizedSize):
1615                    size = DataDependentSize(min=size.min)
1616
1617                ret.append(IndexOutputAxis(size=size))
1618        elif axis_type == "channel":
1619            assert not isinstance(size, ParameterizedSize)
1620            if isinstance(size, SizeReference):
1621                warnings.warn(
1622                    "Conversion of channel size from an implicit output shape may be"
1623                    + " wrong"
1624                )
1625                ret.append(
1626                    ChannelAxis(
1627                        channel_names=[
1628                            Identifier(f"channel{i}") for i in range(size.offset)
1629                        ]
1630                    )
1631                )
1632            else:
1633                ret.append(
1634                    ChannelAxis(
1635                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1636                    )
1637                )
1638        elif axis_type == "space":
1639            if tensor_type == "input":
1640                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1641            else:
1642                assert not isinstance(size, ParameterizedSize)
1643                if halo is None or halo[i] == 0:
1644                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1645                elif isinstance(size, int):
1646                    raise NotImplementedError(
1647                        f"output axis with halo and fixed size (here {size}) not allowed"
1648                    )
1649                else:
1650                    ret.append(
1651                        SpaceOutputAxisWithHalo(
1652                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1653                        )
1654                    )
1655
1656    return ret
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1831class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1832    id: TensorId = TensorId("output")
1833    """Output tensor id.
1834    No duplicates are allowed across all inputs and outputs."""
1835
1836    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1837    """Description of how this output should be postprocessed.
1838
1839    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1840          If not given this is added to cast to this tensor's `data.type`.
1841    """
1842
1843    @model_validator(mode="after")
1844    def _validate_postprocessing_kwargs(self) -> Self:
1845        axes_ids = [a.id for a in self.axes]
1846        for p in self.postprocessing:
1847            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1848            if kwargs_axes is None:
1849                continue
1850
1851            if not isinstance(kwargs_axes, collections.abc.Sequence):
1852                raise ValueError(
1853                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1854                )
1855
1856            if any(a not in axes_ids for a in kwargs_axes):
1857                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1858
1859        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1860            dtype = self.data.type
1861        else:
1862            dtype = self.data[0].type
1863
1864        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1865        if not self.postprocessing or not isinstance(
1866            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1867        ):
1868            self.postprocessing.append(
1869                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1870            )
1871        return self
id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]], tensor_origin: str):
1921def validate_tensors(
1922    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
1923    tensor_origin: str,  # for more precise error messages, e.g. 'test_tensor'
1924):
1925    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
1926
1927    def e_msg(d: TensorDescr):
1928        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
1929
1930    for descr, array in tensors.values():
1931        try:
1932            axis_sizes = descr.get_axis_sizes_for_array(array)
1933        except ValueError as e:
1934            raise ValueError(f"{e_msg(descr)} {e}")
1935        else:
1936            all_tensor_axes[descr.id] = {
1937                a.id: (a, axis_sizes[a.id]) for a in descr.axes
1938            }
1939
1940    for descr, array in tensors.values():
1941        if array.dtype.name != descr.dtype:
1942            raise ValueError(
1943                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
1944                + f" match described dtype '{descr.dtype}'"
1945            )
1946
1947        for a in descr.axes:
1948            actual_size = all_tensor_axes[descr.id][a.id][1]
1949            if a.size is None:
1950                continue
1951
1952            if isinstance(a.size, int):
1953                if actual_size != a.size:
1954                    raise ValueError(
1955                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
1956                        + f"has incompatible size {actual_size}, expected {a.size}"
1957                    )
1958            elif isinstance(a.size, ParameterizedSize):
1959                _ = a.size.validate_size(actual_size)
1960            elif isinstance(a.size, DataDependentSize):
1961                _ = a.size.validate_size(actual_size)
1962            elif isinstance(a.size, SizeReference):
1963                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
1964                if ref_tensor_axes is None:
1965                    raise ValueError(
1966                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
1967                        + f" reference '{a.size.tensor_id}'"
1968                    )
1969
1970                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
1971                if ref_axis is None or ref_size is None:
1972                    raise ValueError(
1973                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
1974                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
1975                    )
1976
1977                if a.unit != ref_axis.unit:
1978                    raise ValueError(
1979                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
1980                        + " axis and reference axis to have the same `unit`, but"
1981                        + f" {a.unit}!={ref_axis.unit}"
1982                    )
1983
1984                if actual_size != (
1985                    expected_size := (
1986                        ref_size * ref_axis.scale / a.scale + a.size.offset
1987                    )
1988                ):
1989                    raise ValueError(
1990                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
1991                        + f" {actual_size} invalid for referenced size {ref_size};"
1992                        + f" expected {expected_size}"
1993                    )
1994            else:
1995                assert_never(a.size)
class EnvironmentFileDescr(bioimageio.spec._internal.io.FileDescr):
1998class EnvironmentFileDescr(FileDescr):
1999    source: Annotated[
2000        ImportantFileSource,
2001        WithSuffix((".yaml", ".yml"), case_sensitive=True),
2002        Field(
2003            examples=["environment.yaml"],
2004        ),
2005    ]
2006    """∈📦 Conda environment file.
2007    Allows to specify custom dependencies, see conda docs:
2008    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2009    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2010    """
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fec152e3740>), PlainSerializer(func=<function _package at 0x7fec152e2a20>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['environment.yaml'])]

∈📦 Conda environment file. Allows to specify custom dependencies, see conda docs:

class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2021class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2022    pass
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2025class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2026    import_from: str
2027    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>

ArchitectureDescr = typing.Annotated[typing.Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]
class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2093class WeightsEntryDescrBase(FileDescr):
2094    type: ClassVar[WeightsFormat]
2095    weights_format_name: ClassVar[str]  # human readable
2096
2097    source: ImportantFileSource
2098    """∈📦 The weights file."""
2099
2100    authors: Optional[List[Author]] = None
2101    """Authors
2102    Either the person(s) that have trained this model resulting in the original weights file.
2103        (If this is the initial weights entry, i.e. it does not have a `parent`)
2104    Or the person(s) who have converted the weights to this weights format.
2105        (If this is a child weight, i.e. it has a `parent` field)
2106    """
2107
2108    parent: Annotated[
2109        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2110    ] = None
2111    """The source weights these weights were converted from.
2112    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2113    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2114    All weight entries except one (the initial set of weights resulting from training the model),
2115    need to have this field."""
2116
2117    @model_validator(mode="after")
2118    def check_parent_is_not_self(self) -> Self:
2119        if self.type == self.parent:
2120            raise ValueError("Weights entry can't be it's own parent.")
2121
2122        return self
type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fec152e3740>), PlainSerializer(func=<function _package at 0x7fec152e2a20>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors Either the person(s) that have trained this model resulting in the original weights file. (If this is the initial weights entry, i.e. it does not have a parent) Or the person(s) who have converted the weights to this weights format. (If this is a child weight, i.e. it has a parent field)

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, The pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model), need to have this field.

@model_validator(mode='after')
def check_parent_is_not_self(self) -> Self:
2117    @model_validator(mode="after")
2118    def check_parent_is_not_self(self) -> Self:
2119        if self.type == self.parent:
2120            raise ValueError("Weights entry can't be it's own parent.")
2121
2122        return self
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2125class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2126    type = "keras_hdf5"
2127    weights_format_name: ClassVar[str] = "Keras HDF5"
2128    tensorflow_version: Version
2129    """TensorFlow version used to create these weights."""
type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'

TensorFlow version used to create these weights.

class OnnxWeightsDescr(WeightsEntryDescrBase):
2132class OnnxWeightsDescr(WeightsEntryDescrBase):
2133    type = "onnx"
2134    weights_format_name: ClassVar[str] = "ONNX"
2135    opset_version: Annotated[int, Ge(7)]
2136    """ONNX opset version"""
type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2139class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2140    type = "pytorch_state_dict"
2141    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2142    architecture: ArchitectureDescr
2143    pytorch_version: Version
2144    """Version of the PyTorch library used.
2145    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
2146    """
2147    dependencies: Optional[EnvironmentFileDescr] = None
2148    """Custom depencies beyond pytorch.
2149    The conda environment file should include pytorch and any version pinning has to be compatible with
2150    `pytorch_version`.
2151    """
type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'
architecture: Annotated[Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

Version of the PyTorch library used. If architecture.depencencies is specified it has to include pytorch and any version pinning has to be compatible.

dependencies: Optional[EnvironmentFileDescr]

Custom depencies beyond pytorch. The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2154class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2155    type = "tensorflow_js"
2156    weights_format_name: ClassVar[str] = "Tensorflow.js"
2157    tensorflow_version: Version
2158    """Version of the TensorFlow library used."""
2159
2160    source: ImportantFileSource
2161    """∈📦 The multi-file weights.
2162    All required files/folders should be a zip archive."""
type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fec152e3740>), PlainSerializer(func=<function _package at 0x7fec152e2a20>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The multi-file weights. All required files/folders should be a zip archive.

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2165class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2166    type = "tensorflow_saved_model_bundle"
2167    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2168    tensorflow_version: Version
2169    """Version of the TensorFlow library used."""
2170
2171    dependencies: Optional[EnvironmentFileDescr] = None
2172    """Custom dependencies beyond tensorflow.
2173    Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""
2174
2175    source: ImportantFileSource
2176    """∈📦 The multi-file weights.
2177    All required files/folders should be a zip archive."""
type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'

Version of the TensorFlow library used.

dependencies: Optional[EnvironmentFileDescr]

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7fec152e3740>), PlainSerializer(func=<function _package at 0x7fec152e2a20>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The multi-file weights. All required files/folders should be a zip archive.

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2180class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2181    type = "torchscript"
2182    weights_format_name: ClassVar[str] = "TorchScript"
2183    pytorch_version: Version
2184    """Version of the PyTorch library used."""
type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'

Version of the PyTorch library used.

class WeightsDescr(bioimageio.spec._internal.node.Node):
2187class WeightsDescr(Node):
2188    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2189    onnx: Optional[OnnxWeightsDescr] = None
2190    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2191    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2192    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2193        None
2194    )
2195    torchscript: Optional[TorchscriptWeightsDescr] = None
2196
2197    @model_validator(mode="after")
2198    def check_entries(self) -> Self:
2199        entries = {wtype for wtype, entry in self if entry is not None}
2200
2201        if not entries:
2202            raise ValueError("Missing weights entry")
2203
2204        entries_wo_parent = {
2205            wtype
2206            for wtype, entry in self
2207            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2208        }
2209        if len(entries_wo_parent) != 1:
2210            issue_warning(
2211                "Exactly one weights entry may not specify the `parent` field (got"
2212                + " {value}). That entry is considered the original set of model weights."
2213                + " Other weight formats are created through conversion of the orignal or"
2214                + " already converted weights. They have to reference the weights format"
2215                + " they were converted from as their `parent`.",
2216                value=len(entries_wo_parent),
2217                field="weights",
2218            )
2219
2220        for wtype, entry in self:
2221            if entry is None:
2222                continue
2223
2224            assert hasattr(entry, "type")
2225            assert hasattr(entry, "parent")
2226            assert wtype == entry.type
2227            if (
2228                entry.parent is not None and entry.parent not in entries
2229            ):  # self reference checked for `parent` field
2230                raise ValueError(
2231                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2232                    + f" formats: {entries}"
2233                )
2234
2235        return self
2236
2237    def __getitem__(
2238        self,
2239        key: Literal[
2240            "keras_hdf5",
2241            "onnx",
2242            "pytorch_state_dict",
2243            "tensorflow_js",
2244            "tensorflow_saved_model_bundle",
2245            "torchscript",
2246        ],
2247    ):
2248        if key == "keras_hdf5":
2249            ret = self.keras_hdf5
2250        elif key == "onnx":
2251            ret = self.onnx
2252        elif key == "pytorch_state_dict":
2253            ret = self.pytorch_state_dict
2254        elif key == "tensorflow_js":
2255            ret = self.tensorflow_js
2256        elif key == "tensorflow_saved_model_bundle":
2257            ret = self.tensorflow_saved_model_bundle
2258        elif key == "torchscript":
2259            ret = self.torchscript
2260        else:
2261            raise KeyError(key)
2262
2263        if ret is None:
2264            raise KeyError(key)
2265
2266        return ret
2267
2268    @property
2269    def available_formats(self):
2270        return {
2271            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2272            **({} if self.onnx is None else {"onnx": self.onnx}),
2273            **(
2274                {}
2275                if self.pytorch_state_dict is None
2276                else {"pytorch_state_dict": self.pytorch_state_dict}
2277            ),
2278            **(
2279                {}
2280                if self.tensorflow_js is None
2281                else {"tensorflow_js": self.tensorflow_js}
2282            ),
2283            **(
2284                {}
2285                if self.tensorflow_saved_model_bundle is None
2286                else {
2287                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2288                }
2289            ),
2290            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2291        }
2292
2293    @property
2294    def missing_formats(self):
2295        return {
2296            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2297        }
keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
2197    @model_validator(mode="after")
2198    def check_entries(self) -> Self:
2199        entries = {wtype for wtype, entry in self if entry is not None}
2200
2201        if not entries:
2202            raise ValueError("Missing weights entry")
2203
2204        entries_wo_parent = {
2205            wtype
2206            for wtype, entry in self
2207            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2208        }
2209        if len(entries_wo_parent) != 1:
2210            issue_warning(
2211                "Exactly one weights entry may not specify the `parent` field (got"
2212                + " {value}). That entry is considered the original set of model weights."
2213                + " Other weight formats are created through conversion of the orignal or"
2214                + " already converted weights. They have to reference the weights format"
2215                + " they were converted from as their `parent`.",
2216                value=len(entries_wo_parent),
2217                field="weights",
2218            )
2219
2220        for wtype, entry in self:
2221            if entry is None:
2222                continue
2223
2224            assert hasattr(entry, "type")
2225            assert hasattr(entry, "parent")
2226            assert wtype == entry.type
2227            if (
2228                entry.parent is not None and entry.parent not in entries
2229            ):  # self reference checked for `parent` field
2230                raise ValueError(
2231                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2232                    + f" formats: {entries}"
2233                )
2234
2235        return self
available_formats
2268    @property
2269    def available_formats(self):
2270        return {
2271            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2272            **({} if self.onnx is None else {"onnx": self.onnx}),
2273            **(
2274                {}
2275                if self.pytorch_state_dict is None
2276                else {"pytorch_state_dict": self.pytorch_state_dict}
2277            ),
2278            **(
2279                {}
2280                if self.tensorflow_js is None
2281                else {"tensorflow_js": self.tensorflow_js}
2282            ),
2283            **(
2284                {}
2285                if self.tensorflow_saved_model_bundle is None
2286                else {
2287                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2288                }
2289            ),
2290            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2291        }
missing_formats
2293    @property
2294    def missing_formats(self):
2295        return {
2296            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2297        }
class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2300class ModelId(ResourceId):
2301    pass

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2304class LinkedModel(LinkedResourceBase):
2305    """Reference to a bioimage.io model."""
2306
2307    id: ModelId
2308    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

2330class ModelDescr(GenericModelDescrBase):
2331    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2332    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2333    """
2334
2335    format_version: Literal["0.5.3"] = "0.5.3"
2336    """Version of the bioimage.io model description specification used.
2337    When creating a new model always use the latest micro/patch version described here.
2338    The `format_version` is important for any consumer software to understand how to parse the fields.
2339    """
2340
2341    type: Literal["model"] = "model"
2342    """Specialized resource type 'model'"""
2343
2344    id: Optional[ModelId] = None
2345    """bioimage.io-wide unique resource identifier
2346    assigned by bioimage.io; version **un**specific."""
2347
2348    authors: NotEmpty[List[Author]]
2349    """The authors are the creators of the model RDF and the primary points of contact."""
2350
2351    documentation: Annotated[
2352        DocumentationSource,
2353        Field(
2354            examples=[
2355                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2356                "README.md",
2357            ],
2358        ),
2359    ]
2360    """∈📦 URL or relative path to a markdown file with additional documentation.
2361    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2362    The documentation should include a '#[#] Validation' (sub)section
2363    with details on how to quantitatively validate the model on unseen data."""
2364
2365    @field_validator("documentation", mode="after")
2366    @classmethod
2367    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2368        if not validation_context_var.get().perform_io_checks:
2369            return value
2370
2371        doc_path = download(value).path
2372        doc_content = doc_path.read_text(encoding="utf-8")
2373        assert isinstance(doc_content, str)
2374        if not re.match("#.*[vV]alidation", doc_content):
2375            issue_warning(
2376                "No '# Validation' (sub)section found in {value}.",
2377                value=value,
2378                field="documentation",
2379            )
2380
2381        return value
2382
2383    inputs: NotEmpty[Sequence[InputTensorDescr]]
2384    """Describes the input tensors expected by this model."""
2385
2386    @field_validator("inputs", mode="after")
2387    @classmethod
2388    def _validate_input_axes(
2389        cls, inputs: Sequence[InputTensorDescr]
2390    ) -> Sequence[InputTensorDescr]:
2391        input_size_refs = cls._get_axes_with_independent_size(inputs)
2392
2393        for i, ipt in enumerate(inputs):
2394            valid_independent_refs: Dict[
2395                Tuple[TensorId, AxisId],
2396                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2397            ] = {
2398                **{
2399                    (ipt.id, a.id): (ipt, a, a.size)
2400                    for a in ipt.axes
2401                    if not isinstance(a, BatchAxis)
2402                    and isinstance(a.size, (int, ParameterizedSize))
2403                },
2404                **input_size_refs,
2405            }
2406            for a, ax in enumerate(ipt.axes):
2407                cls._validate_axis(
2408                    "inputs",
2409                    i=i,
2410                    tensor_id=ipt.id,
2411                    a=a,
2412                    axis=ax,
2413                    valid_independent_refs=valid_independent_refs,
2414                )
2415        return inputs
2416
2417    @staticmethod
2418    def _validate_axis(
2419        field_name: str,
2420        i: int,
2421        tensor_id: TensorId,
2422        a: int,
2423        axis: AnyAxis,
2424        valid_independent_refs: Dict[
2425            Tuple[TensorId, AxisId],
2426            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2427        ],
2428    ):
2429        if isinstance(axis, BatchAxis) or isinstance(
2430            axis.size, (int, ParameterizedSize, DataDependentSize)
2431        ):
2432            return
2433        elif not isinstance(axis.size, SizeReference):
2434            assert_never(axis.size)
2435
2436        # validate axis.size SizeReference
2437        ref = (axis.size.tensor_id, axis.size.axis_id)
2438        if ref not in valid_independent_refs:
2439            raise ValueError(
2440                "Invalid tensor axis reference at"
2441                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2442            )
2443        if ref == (tensor_id, axis.id):
2444            raise ValueError(
2445                "Self-referencing not allowed for"
2446                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2447            )
2448        if axis.type == "channel":
2449            if valid_independent_refs[ref][1].type != "channel":
2450                raise ValueError(
2451                    "A channel axis' size may only reference another fixed size"
2452                    + " channel axis."
2453                )
2454            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2455                ref_size = valid_independent_refs[ref][2]
2456                assert isinstance(ref_size, int), (
2457                    "channel axis ref (another channel axis) has to specify fixed"
2458                    + " size"
2459                )
2460                generated_channel_names = [
2461                    Identifier(axis.channel_names.format(i=i))
2462                    for i in range(1, ref_size + 1)
2463                ]
2464                axis.channel_names = generated_channel_names
2465
2466        if (ax_unit := getattr(axis, "unit", None)) != (
2467            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2468        ):
2469            raise ValueError(
2470                "The units of an axis and its reference axis need to match, but"
2471                + f" '{ax_unit}' != '{ref_unit}'."
2472            )
2473        ref_axis = valid_independent_refs[ref][1]
2474        if isinstance(ref_axis, BatchAxis):
2475            raise ValueError(
2476                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2477                + " (a batch axis is not allowed as reference)."
2478            )
2479
2480        if isinstance(axis, WithHalo):
2481            min_size = axis.size.get_size(axis, ref_axis, n=0)
2482            if (min_size - 2 * axis.halo) < 1:
2483                raise ValueError(
2484                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2485                    + f" {axis.halo}."
2486                )
2487
2488            input_halo = axis.halo * axis.scale / ref_axis.scale
2489            if input_halo != int(input_halo) or input_halo % 2 == 1:
2490                raise ValueError(
2491                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2492                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2493                    + f" is not an even integer for {tensor_id}.{axis.id}."
2494                )
2495
2496    @model_validator(mode="after")
2497    def _validate_test_tensors(self) -> Self:
2498        if not validation_context_var.get().perform_io_checks:
2499            return self
2500
2501        test_arrays = [
2502            load_array(descr.test_tensor.download().path)
2503            for descr in chain(self.inputs, self.outputs)
2504        ]
2505        tensors = {
2506            descr.id: (descr, array)
2507            for descr, array in zip(chain(self.inputs, self.outputs), test_arrays)
2508        }
2509        validate_tensors(tensors, tensor_origin="test_tensor")
2510        return self
2511
2512    @model_validator(mode="after")
2513    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2514        ipt_refs = {t.id for t in self.inputs}
2515        out_refs = {t.id for t in self.outputs}
2516        for ipt in self.inputs:
2517            for p in ipt.preprocessing:
2518                ref = p.kwargs.get("reference_tensor")
2519                if ref is None:
2520                    continue
2521                if ref not in ipt_refs:
2522                    raise ValueError(
2523                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2524                        + f" references are: {ipt_refs}."
2525                    )
2526
2527        for out in self.outputs:
2528            for p in out.postprocessing:
2529                ref = p.kwargs.get("reference_tensor")
2530                if ref is None:
2531                    continue
2532
2533                if ref not in ipt_refs and ref not in out_refs:
2534                    raise ValueError(
2535                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2536                        + f" are: {ipt_refs | out_refs}."
2537                    )
2538
2539        return self
2540
2541    # TODO: use validate funcs in validate_test_tensors
2542    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2543
2544    name: Annotated[
2545        Annotated[
2546            str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()")
2547        ],
2548        MinLen(5),
2549        MaxLen(128),
2550        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2551    ]
2552    """A human-readable name of this model.
2553    It should be no longer than 64 characters
2554    and may only contain letter, number, underscore, minus, parentheses and spaces.
2555    We recommend to chose a name that refers to the model's task and image modality.
2556    """
2557
2558    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2559    """Describes the output tensors."""
2560
2561    @field_validator("outputs", mode="after")
2562    @classmethod
2563    def _validate_tensor_ids(
2564        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2565    ) -> Sequence[OutputTensorDescr]:
2566        tensor_ids = [
2567            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2568        ]
2569        duplicate_tensor_ids: List[str] = []
2570        seen: Set[str] = set()
2571        for t in tensor_ids:
2572            if t in seen:
2573                duplicate_tensor_ids.append(t)
2574
2575            seen.add(t)
2576
2577        if duplicate_tensor_ids:
2578            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2579
2580        return outputs
2581
2582    @staticmethod
2583    def _get_axes_with_parameterized_size(
2584        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2585    ):
2586        return {
2587            f"{t.id}.{a.id}": (t, a, a.size)
2588            for t in io
2589            for a in t.axes
2590            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2591        }
2592
2593    @staticmethod
2594    def _get_axes_with_independent_size(
2595        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2596    ):
2597        return {
2598            (t.id, a.id): (t, a, a.size)
2599            for t in io
2600            for a in t.axes
2601            if not isinstance(a, BatchAxis)
2602            and isinstance(a.size, (int, ParameterizedSize))
2603        }
2604
2605    @field_validator("outputs", mode="after")
2606    @classmethod
2607    def _validate_output_axes(
2608        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2609    ) -> List[OutputTensorDescr]:
2610        input_size_refs = cls._get_axes_with_independent_size(
2611            info.data.get("inputs", [])
2612        )
2613        output_size_refs = cls._get_axes_with_independent_size(outputs)
2614
2615        for i, out in enumerate(outputs):
2616            valid_independent_refs: Dict[
2617                Tuple[TensorId, AxisId],
2618                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2619            ] = {
2620                **{
2621                    (out.id, a.id): (out, a, a.size)
2622                    for a in out.axes
2623                    if not isinstance(a, BatchAxis)
2624                    and isinstance(a.size, (int, ParameterizedSize))
2625                },
2626                **input_size_refs,
2627                **output_size_refs,
2628            }
2629            for a, ax in enumerate(out.axes):
2630                cls._validate_axis(
2631                    "outputs",
2632                    i,
2633                    out.id,
2634                    a,
2635                    ax,
2636                    valid_independent_refs=valid_independent_refs,
2637                )
2638
2639        return outputs
2640
2641    packaged_by: List[Author] = Field(default_factory=list)
2642    """The persons that have packaged and uploaded this model.
2643    Only required if those persons differ from the `authors`."""
2644
2645    parent: Optional[LinkedModel] = None
2646    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2647
2648    # todo: add parent self check once we have `id`
2649    # @model_validator(mode="after")
2650    # def validate_parent_is_not_self(self) -> Self:
2651    #     if self.parent is not None and self.parent == self.id:
2652    #         raise ValueError("The model may not reference itself as parent model")
2653
2654    #     return self
2655
2656    run_mode: Annotated[
2657        Optional[RunMode],
2658        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2659    ] = None
2660    """Custom run mode for this model: for more complex prediction procedures like test time
2661    data augmentation that currently cannot be expressed in the specification.
2662    No standard run modes are defined yet."""
2663
2664    timestamp: Datetime = Datetime(datetime.now())
2665    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2666    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2667    (In Python a datetime object is valid, too)."""
2668
2669    training_data: Annotated[
2670        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2671        Field(union_mode="left_to_right"),
2672    ] = None
2673    """The dataset used to train this model"""
2674
2675    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2676    """The weights for this model.
2677    Weights can be given for different formats, but should otherwise be equivalent.
2678    The available weight formats determine which consumers can use this model."""
2679
2680    @model_validator(mode="after")
2681    def _add_default_cover(self) -> Self:
2682        if not validation_context_var.get().perform_io_checks or self.covers:
2683            return self
2684
2685        try:
2686            generated_covers = generate_covers(
2687                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2688                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2689            )
2690        except Exception as e:
2691            issue_warning(
2692                "Failed to generate cover image(s): {e}",
2693                value=self.covers,
2694                msg_context=dict(e=e),
2695                field="covers",
2696            )
2697        else:
2698            self.covers.extend(generated_covers)
2699
2700        return self
2701
2702    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2703        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2704        assert all(isinstance(d, np.ndarray) for d in data)
2705        return data
2706
2707    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2708        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2709        assert all(isinstance(d, np.ndarray) for d in data)
2710        return data
2711
2712    @staticmethod
2713    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2714        batch_size = 1
2715        tensor_with_batchsize: Optional[TensorId] = None
2716        for tid in tensor_sizes:
2717            for aid, s in tensor_sizes[tid].items():
2718                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2719                    continue
2720
2721                if batch_size != 1:
2722                    assert tensor_with_batchsize is not None
2723                    raise ValueError(
2724                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2725                    )
2726
2727                batch_size = s
2728                tensor_with_batchsize = tid
2729
2730        return batch_size
2731
2732    def get_output_tensor_sizes(
2733        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2734    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2735        """Returns the tensor output sizes for given **input_sizes**.
2736        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2737        Otherwise it might be larger than the actual (valid) output"""
2738        batch_size = self.get_batch_size(input_sizes)
2739        ns = self.get_ns(input_sizes)
2740
2741        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2742        return tensor_sizes.outputs
2743
2744    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2745        """get parameter `n` for each parameterized axis
2746        such that the valid input size is >= the given input size"""
2747        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2748        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2749        for tid in input_sizes:
2750            for aid, s in input_sizes[tid].items():
2751                size_descr = axes[tid][aid].size
2752                if isinstance(size_descr, ParameterizedSize):
2753                    ret[(tid, aid)] = size_descr.get_n(s)
2754                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2755                    pass
2756                else:
2757                    assert_never(size_descr)
2758
2759        return ret
2760
2761    def get_tensor_sizes(
2762        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2763    ) -> _TensorSizes:
2764        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2765        return _TensorSizes(
2766            {
2767                t: {
2768                    aa: axis_sizes.inputs[(tt, aa)]
2769                    for tt, aa in axis_sizes.inputs
2770                    if tt == t
2771                }
2772                for t in {tt for tt, _ in axis_sizes.inputs}
2773            },
2774            {
2775                t: {
2776                    aa: axis_sizes.outputs[(tt, aa)]
2777                    for tt, aa in axis_sizes.outputs
2778                    if tt == t
2779                }
2780                for t in {tt for tt, _ in axis_sizes.outputs}
2781            },
2782        )
2783
2784    def get_axis_sizes(
2785        self,
2786        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2787        batch_size: Optional[int] = None,
2788        *,
2789        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2790    ) -> _AxisSizes:
2791        """Determine input and output block shape for scale factors **ns**
2792        of parameterized input sizes.
2793
2794        Args:
2795            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2796                that is parameterized as `size = min + n * step`.
2797            batch_size: The desired size of the batch dimension.
2798                If given **batch_size** overwrites any batch size present in
2799                **max_input_shape**. Default 1.
2800            max_input_shape: Limits the derived block shapes.
2801                Each axis for which the input size, parameterized by `n`, is larger
2802                than **max_input_shape** is set to the minimal value `n_min` for which
2803                this is still true.
2804                Use this for small input samples or large values of **ns**.
2805                Or simply whenever you know the full input shape.
2806
2807        Returns:
2808            Resolved axis sizes for model inputs and outputs.
2809        """
2810        max_input_shape = max_input_shape or {}
2811        if batch_size is None:
2812            for (_t_id, a_id), s in max_input_shape.items():
2813                if a_id == BATCH_AXIS_ID:
2814                    batch_size = s
2815                    break
2816            else:
2817                batch_size = 1
2818
2819        all_axes = {
2820            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2821        }
2822
2823        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2824        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2825
2826        def get_axis_size(a: Union[InputAxis, OutputAxis]):
2827            if isinstance(a, BatchAxis):
2828                if (t_descr.id, a.id) in ns:
2829                    logger.warning(
2830                        "Ignoring unexpected size increment factor (n) for batch axis"
2831                        + " of tensor '{}'.",
2832                        t_descr.id,
2833                    )
2834                return batch_size
2835            elif isinstance(a.size, int):
2836                if (t_descr.id, a.id) in ns:
2837                    logger.warning(
2838                        "Ignoring unexpected size increment factor (n) for fixed size"
2839                        + " axis '{}' of tensor '{}'.",
2840                        a.id,
2841                        t_descr.id,
2842                    )
2843                return a.size
2844            elif isinstance(a.size, ParameterizedSize):
2845                if (t_descr.id, a.id) not in ns:
2846                    raise ValueError(
2847                        "Size increment factor (n) missing for parametrized axis"
2848                        + f" '{a.id}' of tensor '{t_descr.id}'."
2849                    )
2850                n = ns[(t_descr.id, a.id)]
2851                s_max = max_input_shape.get((t_descr.id, a.id))
2852                if s_max is not None:
2853                    n = min(n, a.size.get_n(s_max))
2854
2855                return a.size.get_size(n)
2856
2857            elif isinstance(a.size, SizeReference):
2858                if (t_descr.id, a.id) in ns:
2859                    logger.warning(
2860                        "Ignoring unexpected size increment factor (n) for axis '{}'"
2861                        + " of tensor '{}' with size reference.",
2862                        a.id,
2863                        t_descr.id,
2864                    )
2865                assert not isinstance(a, BatchAxis)
2866                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2867                assert not isinstance(ref_axis, BatchAxis)
2868                ref_key = (a.size.tensor_id, a.size.axis_id)
2869                ref_size = inputs.get(ref_key, outputs.get(ref_key))
2870                assert ref_size is not None, ref_key
2871                assert not isinstance(ref_size, _DataDepSize), ref_key
2872                return a.size.get_size(
2873                    axis=a,
2874                    ref_axis=ref_axis,
2875                    ref_size=ref_size,
2876                )
2877            elif isinstance(a.size, DataDependentSize):
2878                if (t_descr.id, a.id) in ns:
2879                    logger.warning(
2880                        "Ignoring unexpected increment factor (n) for data dependent"
2881                        + " size axis '{}' of tensor '{}'.",
2882                        a.id,
2883                        t_descr.id,
2884                    )
2885                return _DataDepSize(a.size.min, a.size.max)
2886            else:
2887                assert_never(a.size)
2888
2889        # first resolve all , but the `SizeReference` input sizes
2890        for t_descr in self.inputs:
2891            for a in t_descr.axes:
2892                if not isinstance(a.size, SizeReference):
2893                    s = get_axis_size(a)
2894                    assert not isinstance(s, _DataDepSize)
2895                    inputs[t_descr.id, a.id] = s
2896
2897        # resolve all other input axis sizes
2898        for t_descr in self.inputs:
2899            for a in t_descr.axes:
2900                if isinstance(a.size, SizeReference):
2901                    s = get_axis_size(a)
2902                    assert not isinstance(s, _DataDepSize)
2903                    inputs[t_descr.id, a.id] = s
2904
2905        # resolve all output axis sizes
2906        for t_descr in self.outputs:
2907            for a in t_descr.axes:
2908                assert not isinstance(a.size, ParameterizedSize)
2909                s = get_axis_size(a)
2910                outputs[t_descr.id, a.id] = s
2911
2912        return _AxisSizes(inputs=inputs, outputs=outputs)
2913
2914    @model_validator(mode="before")
2915    @classmethod
2916    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
2917        if (
2918            data.get("type") == "model"
2919            and isinstance(fv := data.get("format_version"), str)
2920            and fv.count(".") == 2
2921        ):
2922            fv_parts = fv.split(".")
2923            if any(not p.isdigit() for p in fv_parts):
2924                return data
2925
2926            fv_tuple = tuple(map(int, fv_parts))
2927
2928            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
2929            if fv_tuple[:2] in ((0, 3), (0, 4)):
2930                m04 = _ModelDescr_v0_4.load(data)
2931                if not isinstance(m04, InvalidDescr):
2932                    return _model_conv.convert_as_dict(m04)
2933            elif fv_tuple[:2] == (0, 5):
2934                # bump patch version
2935                data["format_version"] = cls.implemented_format_version
2936
2937        return data

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

format_version: Literal['0.5.3']

Version of the bioimage.io model description specification used. When creating a new model always use the latest micro/patch version described here. The format_version is important for any consumer software to understand how to parse the fields.

type: Literal['model']

Specialized resource type 'model'

id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[Annotated[pathlib.Path, PathType(path_type='file'), Predicate(is_absolute), FieldInfo(annotation=NoneType, required=True, title='AbsoluteFilePath')], bioimageio.spec._internal.io.RelativeFilePath, bioimageio.spec._internal.url.HttpUrl], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function _validate_md_suffix at 0x7fec14a24400>), PlainSerializer(func=<function _package at 0x7fec152e2a20>, return_type=PydanticUndefined, when_used='unless-none'), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fec07ebdf80>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letter, number, underscore, minus, parentheses and spaces. We recommend to chose a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7fec07ebd440>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

Timestamp in ISO 8601 format with a few restrictions listed here. (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights at 0x7fec14ab2d40>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]:
2702    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2703        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2704        assert all(isinstance(d, np.ndarray) for d in data)
2705        return data
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]:
2707    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2708        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2709        assert all(isinstance(d, np.ndarray) for d in data)
2710        return data
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2712    @staticmethod
2713    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2714        batch_size = 1
2715        tensor_with_batchsize: Optional[TensorId] = None
2716        for tid in tensor_sizes:
2717            for aid, s in tensor_sizes[tid].items():
2718                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2719                    continue
2720
2721                if batch_size != 1:
2722                    assert tensor_with_batchsize is not None
2723                    raise ValueError(
2724                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2725                    )
2726
2727                batch_size = s
2728                tensor_with_batchsize = tid
2729
2730        return batch_size
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
2732    def get_output_tensor_sizes(
2733        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2734    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2735        """Returns the tensor output sizes for given **input_sizes**.
2736        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2737        Otherwise it might be larger than the actual (valid) output"""
2738        batch_size = self.get_batch_size(input_sizes)
2739        ns = self.get_ns(input_sizes)
2740
2741        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2742        return tensor_sizes.outputs

Returns the tensor output sizes for given input_sizes. Only if input_sizes has a valid input shape, the tensor output size is exact. Otherwise it might be larger than the actual (valid) output

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2744    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2745        """get parameter `n` for each parameterized axis
2746        such that the valid input size is >= the given input size"""
2747        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2748        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2749        for tid in input_sizes:
2750            for aid, s in input_sizes[tid].items():
2751                size_descr = axes[tid][aid].size
2752                if isinstance(size_descr, ParameterizedSize):
2753                    ret[(tid, aid)] = size_descr.get_n(s)
2754                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2755                    pass
2756                else:
2757                    assert_never(size_descr)
2758
2759        return ret

get parameter n for each parameterized axis such that the valid input size is >= the given input size

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
2761    def get_tensor_sizes(
2762        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2763    ) -> _TensorSizes:
2764        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2765        return _TensorSizes(
2766            {
2767                t: {
2768                    aa: axis_sizes.inputs[(tt, aa)]
2769                    for tt, aa in axis_sizes.inputs
2770                    if tt == t
2771                }
2772                for t in {tt for tt, _ in axis_sizes.inputs}
2773            },
2774            {
2775                t: {
2776                    aa: axis_sizes.outputs[(tt, aa)]
2777                    for tt, aa in axis_sizes.outputs
2778                    if tt == t
2779                }
2780                for t in {tt for tt, _ in axis_sizes.outputs}
2781            },
2782        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
2784    def get_axis_sizes(
2785        self,
2786        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2787        batch_size: Optional[int] = None,
2788        *,
2789        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2790    ) -> _AxisSizes:
2791        """Determine input and output block shape for scale factors **ns**
2792        of parameterized input sizes.
2793
2794        Args:
2795            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2796                that is parameterized as `size = min + n * step`.
2797            batch_size: The desired size of the batch dimension.
2798                If given **batch_size** overwrites any batch size present in
2799                **max_input_shape**. Default 1.
2800            max_input_shape: Limits the derived block shapes.
2801                Each axis for which the input size, parameterized by `n`, is larger
2802                than **max_input_shape** is set to the minimal value `n_min` for which
2803                this is still true.
2804                Use this for small input samples or large values of **ns**.
2805                Or simply whenever you know the full input shape.
2806
2807        Returns:
2808            Resolved axis sizes for model inputs and outputs.
2809        """
2810        max_input_shape = max_input_shape or {}
2811        if batch_size is None:
2812            for (_t_id, a_id), s in max_input_shape.items():
2813                if a_id == BATCH_AXIS_ID:
2814                    batch_size = s
2815                    break
2816            else:
2817                batch_size = 1
2818
2819        all_axes = {
2820            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2821        }
2822
2823        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2824        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2825
2826        def get_axis_size(a: Union[InputAxis, OutputAxis]):
2827            if isinstance(a, BatchAxis):
2828                if (t_descr.id, a.id) in ns:
2829                    logger.warning(
2830                        "Ignoring unexpected size increment factor (n) for batch axis"
2831                        + " of tensor '{}'.",
2832                        t_descr.id,
2833                    )
2834                return batch_size
2835            elif isinstance(a.size, int):
2836                if (t_descr.id, a.id) in ns:
2837                    logger.warning(
2838                        "Ignoring unexpected size increment factor (n) for fixed size"
2839                        + " axis '{}' of tensor '{}'.",
2840                        a.id,
2841                        t_descr.id,
2842                    )
2843                return a.size
2844            elif isinstance(a.size, ParameterizedSize):
2845                if (t_descr.id, a.id) not in ns:
2846                    raise ValueError(
2847                        "Size increment factor (n) missing for parametrized axis"
2848                        + f" '{a.id}' of tensor '{t_descr.id}'."
2849                    )
2850                n = ns[(t_descr.id, a.id)]
2851                s_max = max_input_shape.get((t_descr.id, a.id))
2852                if s_max is not None:
2853                    n = min(n, a.size.get_n(s_max))
2854
2855                return a.size.get_size(n)
2856
2857            elif isinstance(a.size, SizeReference):
2858                if (t_descr.id, a.id) in ns:
2859                    logger.warning(
2860                        "Ignoring unexpected size increment factor (n) for axis '{}'"
2861                        + " of tensor '{}' with size reference.",
2862                        a.id,
2863                        t_descr.id,
2864                    )
2865                assert not isinstance(a, BatchAxis)
2866                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2867                assert not isinstance(ref_axis, BatchAxis)
2868                ref_key = (a.size.tensor_id, a.size.axis_id)
2869                ref_size = inputs.get(ref_key, outputs.get(ref_key))
2870                assert ref_size is not None, ref_key
2871                assert not isinstance(ref_size, _DataDepSize), ref_key
2872                return a.size.get_size(
2873                    axis=a,
2874                    ref_axis=ref_axis,
2875                    ref_size=ref_size,
2876                )
2877            elif isinstance(a.size, DataDependentSize):
2878                if (t_descr.id, a.id) in ns:
2879                    logger.warning(
2880                        "Ignoring unexpected increment factor (n) for data dependent"
2881                        + " size axis '{}' of tensor '{}'.",
2882                        a.id,
2883                        t_descr.id,
2884                    )
2885                return _DataDepSize(a.size.min, a.size.max)
2886            else:
2887                assert_never(a.size)
2888
2889        # first resolve all , but the `SizeReference` input sizes
2890        for t_descr in self.inputs:
2891            for a in t_descr.axes:
2892                if not isinstance(a.size, SizeReference):
2893                    s = get_axis_size(a)
2894                    assert not isinstance(s, _DataDepSize)
2895                    inputs[t_descr.id, a.id] = s
2896
2897        # resolve all other input axis sizes
2898        for t_descr in self.inputs:
2899            for a in t_descr.axes:
2900                if isinstance(a.size, SizeReference):
2901                    s = get_axis_size(a)
2902                    assert not isinstance(s, _DataDepSize)
2903                    inputs[t_descr.id, a.id] = s
2904
2905        # resolve all output axis sizes
2906        for t_descr in self.outputs:
2907            for a in t_descr.axes:
2908                assert not isinstance(a.size, ParameterizedSize)
2909                s = get_axis_size(a)
2910                outputs[t_descr.id, a.id] = s
2911
2912        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.

implemented_type: ClassVar[str] = 'model'
implemented_format_version: ClassVar[str] = '0.5.3'
implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 3)
def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
124                    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
125                        """We need to both initialize private attributes and call the user-defined model_post_init
126                        method.
127                        """
128                        init_private_attributes(self, context)
129                        original_model_post_init(self, context)

We need to both initialize private attributes and call the user-defined model_post_init method.

def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3162def generate_covers(
3163    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3164    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3165) -> List[Path]:
3166    def squeeze(
3167        data: NDArray[Any], axes: Sequence[AnyAxis]
3168    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3169        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3170        if data.ndim != len(axes):
3171            raise ValueError(
3172                f"tensor shape {data.shape} does not match described axes"
3173                + f" {[a.id for a in axes]}"
3174            )
3175
3176        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3177        return data.squeeze(), axes
3178
3179    def normalize(
3180        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3181    ) -> NDArray[np.float32]:
3182        data = data.astype("float32")
3183        data -= data.min(axis=axis, keepdims=True)
3184        data /= data.max(axis=axis, keepdims=True) + eps
3185        return data
3186
3187    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3188        original_shape = data.shape
3189        data, axes = squeeze(data, axes)
3190
3191        # take slice fom any batch or index axis if needed
3192        # and convert the first channel axis and take a slice from any additional channel axes
3193        slices: Tuple[slice, ...] = ()
3194        ndim = data.ndim
3195        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3196        has_c_axis = False
3197        for i, a in enumerate(axes):
3198            s = data.shape[i]
3199            assert s > 1
3200            if (
3201                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3202                and ndim > ndim_need
3203            ):
3204                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3205                ndim -= 1
3206            elif isinstance(a, ChannelAxis):
3207                if has_c_axis:
3208                    # second channel axis
3209                    data = data[slices + (slice(0, 1),)]
3210                    ndim -= 1
3211                else:
3212                    has_c_axis = True
3213                    if s == 2:
3214                        # visualize two channels with cyan and magenta
3215                        data = np.concatenate(
3216                            [
3217                                data[slices + (slice(1, 2),)],
3218                                data[slices + (slice(0, 1),)],
3219                                (
3220                                    data[slices + (slice(0, 1),)]
3221                                    + data[slices + (slice(1, 2),)]
3222                                )
3223                                / 2,  # TODO: take maximum instead?
3224                            ],
3225                            axis=i,
3226                        )
3227                    elif data.shape[i] == 3:
3228                        pass  # visualize 3 channels as RGB
3229                    else:
3230                        # visualize first 3 channels as RGB
3231                        data = data[slices + (slice(3),)]
3232
3233                    assert data.shape[i] == 3
3234
3235            slices += (slice(None),)
3236
3237        data, axes = squeeze(data, axes)
3238        assert len(axes) == ndim
3239        # take slice from z axis if needed
3240        slices = ()
3241        if ndim > ndim_need:
3242            for i, a in enumerate(axes):
3243                s = data.shape[i]
3244                if a.id == AxisId("z"):
3245                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3246                    data, axes = squeeze(data, axes)
3247                    ndim -= 1
3248                    break
3249
3250            slices += (slice(None),)
3251
3252        # take slice from any space or time axis
3253        slices = ()
3254
3255        for i, a in enumerate(axes):
3256            if ndim <= ndim_need:
3257                break
3258
3259            s = data.shape[i]
3260            assert s > 1
3261            if isinstance(
3262                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3263            ):
3264                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3265                ndim -= 1
3266
3267            slices += (slice(None),)
3268
3269        del slices
3270        data, axes = squeeze(data, axes)
3271        assert len(axes) == ndim
3272
3273        if (has_c_axis and ndim != 3) or ndim != 2:
3274            raise ValueError(
3275                f"Failed to construct cover image from shape {original_shape}"
3276            )
3277
3278        if not has_c_axis:
3279            assert ndim == 2
3280            data = np.repeat(data[:, :, None], 3, axis=2)
3281            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3282            ndim += 1
3283
3284        assert ndim == 3
3285
3286        # transpose axis order such that longest axis comes first...
3287        axis_order = list(np.argsort(list(data.shape)))
3288        axis_order.reverse()
3289        # ... and channel axis is last
3290        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3291        axis_order.append(axis_order.pop(c))
3292        axes = [axes[ao] for ao in axis_order]
3293        data = data.transpose(axis_order)
3294
3295        # h, w = data.shape[:2]
3296        # if h / w  in (1.0 or 2.0):
3297        #     pass
3298        # elif h / w < 2:
3299        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3300
3301        norm_along = (
3302            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3303        )
3304        # normalize the data and map to 8 bit
3305        data = normalize(data, norm_along)
3306        data = (data * 255).astype("uint8")
3307
3308        return data
3309
3310    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3311        assert im0.dtype == im1.dtype == np.uint8
3312        assert im0.shape == im1.shape
3313        assert im0.ndim == 3
3314        N, M, C = im0.shape
3315        assert C == 3
3316        out = np.ones((N, M, C), dtype="uint8")
3317        for c in range(C):
3318            outc = np.tril(im0[..., c])
3319            mask = outc == 0
3320            outc[mask] = np.triu(im1[..., c])[mask]
3321            out[..., c] = outc
3322
3323        return out
3324
3325    ipt_descr, ipt = inputs[0]
3326    out_descr, out = outputs[0]
3327
3328    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3329    out_img = to_2d_image(out, out_descr.axes)
3330
3331    cover_folder = Path(mkdtemp())
3332    if ipt_img.shape == out_img.shape:
3333        covers = [cover_folder / "cover.png"]
3334        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3335    else:
3336        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3337        imwrite(covers[0], ipt_img)
3338        imwrite(covers[1], out_img)
3339
3340    return covers
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):