bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from datetime import datetime
  10from itertools import chain
  11from math import ceil
  12from pathlib import Path, PurePosixPath
  13from tempfile import mkdtemp
  14from typing import (
  15    TYPE_CHECKING,
  16    Any,
  17    ClassVar,
  18    Dict,
  19    FrozenSet,
  20    Generic,
  21    List,
  22    Literal,
  23    Mapping,
  24    NamedTuple,
  25    Optional,
  26    Sequence,
  27    Set,
  28    Tuple,
  29    Type,
  30    TypeVar,
  31    Union,
  32    cast,
  33)
  34
  35import numpy as np
  36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  37from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  38from loguru import logger
  39from numpy.typing import NDArray
  40from pydantic import (
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    Tag,
  45    ValidationInfo,
  46    WrapSerializer,
  47    field_validator,
  48    model_validator,
  49)
  50from typing_extensions import Annotated, LiteralString, Self, assert_never
  51
  52from .._internal.common_nodes import (
  53    Converter,
  54    InvalidDescr,
  55    Node,
  56    NodeWithExplicitlySetFields,
  57)
  58from .._internal.constants import DTYPE_LIMITS
  59from .._internal.field_warning import issue_warning, warn
  60from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  61from .._internal.io import FileDescr as FileDescr
  62from .._internal.io import WithSuffix, YamlValue, download
  63from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
  64from .._internal.io_basics import Sha256 as Sha256
  65from .._internal.io_utils import load_array
  66from .._internal.types import Datetime as Datetime
  67from .._internal.types import Identifier as Identifier
  68from .._internal.types import (
  69    ImportantFileSource,
  70    LowerCaseIdentifier,
  71    LowerCaseIdentifierAnno,
  72    SiUnit,
  73)
  74from .._internal.types import NotEmpty as NotEmpty
  75from .._internal.url import HttpUrl as HttpUrl
  76from .._internal.validation_context import validation_context_var
  77from .._internal.validator_annotations import RestrictCharacters
  78from .._internal.version_type import Version as Version
  79from .._internal.warning_levels import INFO
  80from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  81from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
  82from ..dataset.v0_3 import DatasetDescr as DatasetDescr
  83from ..dataset.v0_3 import DatasetId as DatasetId
  84from ..dataset.v0_3 import LinkedDataset as LinkedDataset
  85from ..dataset.v0_3 import Uploader as Uploader
  86from ..generic.v0_3 import (
  87    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
  88)
  89from ..generic.v0_3 import Author as Author
  90from ..generic.v0_3 import BadgeDescr as BadgeDescr
  91from ..generic.v0_3 import CiteEntry as CiteEntry
  92from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
  93from ..generic.v0_3 import (
  94    DocumentationSource,
  95    GenericModelDescrBase,
  96    LinkedResourceNode,
  97    _author_conv,  # pyright: ignore[reportPrivateUsage]
  98    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
  99)
 100from ..generic.v0_3 import Doi as Doi
 101from ..generic.v0_3 import LicenseId as LicenseId
 102from ..generic.v0_3 import LinkedResource as LinkedResource
 103from ..generic.v0_3 import Maintainer as Maintainer
 104from ..generic.v0_3 import OrcidId as OrcidId
 105from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 106from ..generic.v0_3 import ResourceId as ResourceId
 107from .v0_4 import Author as _Author_v0_4
 108from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 109from .v0_4 import CallableFromDepencency as CallableFromDepencency
 110from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 111from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 112from .v0_4 import ClipDescr as _ClipDescr_v0_4
 113from .v0_4 import ClipKwargs as ClipKwargs
 114from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 115from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 116from .v0_4 import KnownRunMode as KnownRunMode
 117from .v0_4 import ModelDescr as _ModelDescr_v0_4
 118from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 119from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 120from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 121from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 122from .v0_4 import ProcessingKwargs as ProcessingKwargs
 123from .v0_4 import RunMode as RunMode
 124from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 125from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 126from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 127from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 128from .v0_4 import TensorName as _TensorName_v0_4
 129from .v0_4 import WeightsFormat as WeightsFormat
 130from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 131from .v0_4 import package_weights
 132
 133# unit names from https://ngff.openmicroscopy.org/latest/#axes-md
 134SpaceUnit = Literal[
 135    "attometer",
 136    "angstrom",
 137    "centimeter",
 138    "decimeter",
 139    "exameter",
 140    "femtometer",
 141    "foot",
 142    "gigameter",
 143    "hectometer",
 144    "inch",
 145    "kilometer",
 146    "megameter",
 147    "meter",
 148    "micrometer",
 149    "mile",
 150    "millimeter",
 151    "nanometer",
 152    "parsec",
 153    "petameter",
 154    "picometer",
 155    "terameter",
 156    "yard",
 157    "yoctometer",
 158    "yottameter",
 159    "zeptometer",
 160    "zettameter",
 161]
 162
 163TimeUnit = Literal[
 164    "attosecond",
 165    "centisecond",
 166    "day",
 167    "decisecond",
 168    "exasecond",
 169    "femtosecond",
 170    "gigasecond",
 171    "hectosecond",
 172    "hour",
 173    "kilosecond",
 174    "megasecond",
 175    "microsecond",
 176    "millisecond",
 177    "minute",
 178    "nanosecond",
 179    "petasecond",
 180    "picosecond",
 181    "second",
 182    "terasecond",
 183    "yoctosecond",
 184    "yottasecond",
 185    "zeptosecond",
 186    "zettasecond",
 187]
 188
 189AxisType = Literal["batch", "channel", "index", "time", "space"]
 190
 191
 192class TensorId(LowerCaseIdentifier):
 193    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 194        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 195    ]
 196
 197
 198class AxisId(LowerCaseIdentifier):
 199    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 200        Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
 201    ]
 202
 203
 204def _is_batch(a: str) -> bool:
 205    return a == BATCH_AXIS_ID
 206
 207
 208def _is_not_batch(a: str) -> bool:
 209    return not _is_batch(a)
 210
 211
 212NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
 213
 214PostprocessingId = Literal[
 215    "binarize",
 216    "clip",
 217    "ensure_dtype",
 218    "fixed_zero_mean_unit_variance",
 219    "scale_linear",
 220    "scale_mean_variance",
 221    "scale_range",
 222    "sigmoid",
 223    "zero_mean_unit_variance",
 224]
 225PreprocessingId = Literal[
 226    "binarize",
 227    "clip",
 228    "ensure_dtype",
 229    "scale_linear",
 230    "sigmoid",
 231    "zero_mean_unit_variance",
 232    "scale_range",
 233]
 234
 235
 236SAME_AS_TYPE = "<same as type>"
 237
 238
 239ParameterizedSize_N = int
 240
 241
 242class ParameterizedSize(Node):
 243    """Describes a range of valid tensor axis sizes as `size = min + n*step`."""
 244
 245    N: ClassVar[Type[int]] = ParameterizedSize_N
 246    """integer to parameterize this axis"""
 247
 248    min: Annotated[int, Gt(0)]
 249    step: Annotated[int, Gt(0)]
 250
 251    def validate_size(self, size: int) -> int:
 252        if size < self.min:
 253            raise ValueError(f"size {size} < {self.min}")
 254        if (size - self.min) % self.step != 0:
 255            raise ValueError(
 256                f"axis of size {size} is not parameterized by `min + n*step` ="
 257                + f" `{self.min} + n*{self.step}`"
 258            )
 259
 260        return size
 261
 262    def get_size(self, n: ParameterizedSize_N) -> int:
 263        return self.min + self.step * n
 264
 265    def get_n(self, s: int) -> ParameterizedSize_N:
 266        """return smallest n parameterizing a size greater or equal than `s`"""
 267        return ceil((s - self.min) / self.step)
 268
 269
 270class DataDependentSize(Node):
 271    min: Annotated[int, Gt(0)] = 1
 272    max: Annotated[Optional[int], Gt(1)] = None
 273
 274    @model_validator(mode="after")
 275    def _validate_max_gt_min(self):
 276        if self.max is not None and self.min >= self.max:
 277            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 278
 279        return self
 280
 281    def validate_size(self, size: int) -> int:
 282        if size < self.min:
 283            raise ValueError(f"size {size} < {self.min}")
 284
 285        if self.max is not None and size > self.max:
 286            raise ValueError(f"size {size} > {self.max}")
 287
 288        return size
 289
 290
 291class SizeReference(Node):
 292    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 293
 294    `axis.size = reference.size * reference.scale / axis.scale + offset`
 295
 296    note:
 297    1. The axis and the referenced axis need to have the same unit (or no unit).
 298    2. Batch axes may not be referenced.
 299    3. Fractions are rounded down.
 300    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 301        `concatenable` as well with the same block order.
 302
 303    example:
 304    An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
 305    Let's assume that we want to express the image height h in relation to its width w
 306    instead of only accepting input images of exactly 100*49 pixels
 307    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
 308
 309    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 310    >>> h = SpaceInputAxis(
 311    ...     id=AxisId("h"),
 312    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 313    ...     unit="millimeter",
 314    ...     scale=4,
 315    ... )
 316    >>> print(h.size.compute(h, w))
 317    49
 318
 319    -> h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 320    """
 321
 322    tensor_id: TensorId
 323    """tensor id of the reference axis"""
 324
 325    axis_id: AxisId
 326    """axis id of the reference axis"""
 327
 328    offset: int = 0
 329
 330    def get_size(
 331        self,
 332        axis: Union[
 333            ChannelAxis,
 334            IndexInputAxis,
 335            IndexOutputAxis,
 336            TimeInputAxis,
 337            SpaceInputAxis,
 338            TimeOutputAxis,
 339            TimeOutputAxisWithHalo,
 340            SpaceOutputAxis,
 341            SpaceOutputAxisWithHalo,
 342        ],
 343        ref_axis: Union[
 344            ChannelAxis,
 345            IndexInputAxis,
 346            IndexOutputAxis,
 347            TimeInputAxis,
 348            SpaceInputAxis,
 349            TimeOutputAxis,
 350            TimeOutputAxisWithHalo,
 351            SpaceOutputAxis,
 352            SpaceOutputAxisWithHalo,
 353        ],
 354        n: ParameterizedSize_N = 0,
 355        ref_size: Optional[int] = None,
 356    ):
 357        """Compute the concrete size for a given axis and its reference axis.
 358
 359        Args:
 360            axis: The axis this `SizeReference` is the size of.
 361            ref_axis: The reference axis to compute the size from.
 362            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 363                and no fixed **ref_size** is given,
 364                **n** is used to compute the size of the parameterized **ref_axis**.
 365            ref_size: Overwrite the reference size instead of deriving it from
 366                **ref_axis**
 367                (**ref_axis.scale** is still used; any given **n** is ignored).
 368        """
 369        assert (
 370            axis.size == self
 371        ), "Given `axis.size` is not defined by this `SizeReference`"
 372
 373        assert (
 374            ref_axis.id == self.axis_id
 375        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 376
 377        assert axis.unit == ref_axis.unit, (
 378            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 379            f" but {axis.unit}!={ref_axis.unit}"
 380        )
 381        if ref_size is None:
 382            if isinstance(ref_axis.size, (int, float)):
 383                ref_size = ref_axis.size
 384            elif isinstance(ref_axis.size, ParameterizedSize):
 385                ref_size = ref_axis.size.get_size(n)
 386            elif isinstance(ref_axis.size, DataDependentSize):
 387                raise ValueError(
 388                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 389                )
 390            elif isinstance(ref_axis.size, SizeReference):
 391                raise ValueError(
 392                    "Reference axis referenced in `SizeReference` may not be sized by a"
 393                    + " `SizeReference` itself."
 394                )
 395            else:
 396                assert_never(ref_axis.size)
 397
 398        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 399
 400    @staticmethod
 401    def _get_unit(
 402        axis: Union[
 403            ChannelAxis,
 404            IndexInputAxis,
 405            IndexOutputAxis,
 406            TimeInputAxis,
 407            SpaceInputAxis,
 408            TimeOutputAxis,
 409            TimeOutputAxisWithHalo,
 410            SpaceOutputAxis,
 411            SpaceOutputAxisWithHalo,
 412        ],
 413    ):
 414        return axis.unit
 415
 416
 417# this Axis definition is compatible with the NGFF draft from July 10, 2023
 418# https://ngff.openmicroscopy.org/latest/#axes-md
 419class AxisBase(NodeWithExplicitlySetFields):
 420    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"})
 421
 422    id: AxisId
 423    """An axis id unique across all axes of one tensor."""
 424
 425    description: Annotated[str, MaxLen(128)] = ""
 426
 427
 428class WithHalo(Node):
 429    halo: Annotated[int, Ge(1)]
 430    """The halo should be cropped from the output tensor to avoid boundary effects.
 431    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 432    To document a halo that is already cropped by the model use `size.offset` instead."""
 433
 434    size: Annotated[
 435        SizeReference,
 436        Field(
 437            examples=[
 438                10,
 439                SizeReference(
 440                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 441                ).model_dump(mode="json"),
 442            ]
 443        ),
 444    ]
 445    """reference to another axis with an optional offset (see `SizeReference`)"""
 446
 447
 448BATCH_AXIS_ID = AxisId("batch")
 449
 450
 451class BatchAxis(AxisBase):
 452    type: Literal["batch"] = "batch"
 453    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 454    size: Optional[Literal[1]] = None
 455    """The batch size may be fixed to 1,
 456    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
 457
 458    @property
 459    def scale(self):
 460        return 1.0
 461
 462    @property
 463    def concatenable(self):
 464        return True
 465
 466    @property
 467    def unit(self):
 468        return None
 469
 470
 471class ChannelAxis(AxisBase):
 472    type: Literal["channel"] = "channel"
 473    id: NonBatchAxisId = AxisId("channel")
 474    channel_names: NotEmpty[List[Identifier]]
 475
 476    @property
 477    def size(self) -> int:
 478        return len(self.channel_names)
 479
 480    @property
 481    def concatenable(self):
 482        return False
 483
 484    @property
 485    def scale(self) -> float:
 486        return 1.0
 487
 488    @property
 489    def unit(self):
 490        return None
 491
 492
 493class IndexAxisBase(AxisBase):
 494    type: Literal["index"] = "index"
 495    id: NonBatchAxisId = AxisId("index")
 496
 497    @property
 498    def scale(self) -> float:
 499        return 1.0
 500
 501    @property
 502    def unit(self):
 503        return None
 504
 505
 506class _WithInputAxisSize(Node):
 507    size: Annotated[
 508        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 509        Field(
 510            examples=[
 511                10,
 512                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 513                SizeReference(
 514                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 515                ).model_dump(mode="json"),
 516            ]
 517        ),
 518    ]
 519    """The size/length of this axis can be specified as
 520    - fixed integer
 521    - parameterized series of valid sizes (`ParameterizedSize`)
 522    - reference to another axis with an optional offset (`SizeReference`)
 523    """
 524
 525
 526class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 527    concatenable: bool = False
 528    """If a model has a `concatenable` input axis, it can be processed blockwise,
 529    splitting a longer sample axis into blocks matching its input tensor description.
 530    Output axes are concatenable if they have a `SizeReference` to a concatenable
 531    input axis.
 532    """
 533
 534
 535class IndexOutputAxis(IndexAxisBase):
 536    size: Annotated[
 537        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 538        Field(
 539            examples=[
 540                10,
 541                SizeReference(
 542                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 543                ).model_dump(mode="json"),
 544            ]
 545        ),
 546    ]
 547    """The size/length of this axis can be specified as
 548    - fixed integer
 549    - reference to another axis with an optional offset (`SizeReference`)
 550    - data dependent size using `DataDependentSize` (size is only known after model inference)
 551    """
 552
 553
 554class TimeAxisBase(AxisBase):
 555    type: Literal["time"] = "time"
 556    id: NonBatchAxisId = AxisId("time")
 557    unit: Optional[TimeUnit] = None
 558    scale: Annotated[float, Gt(0)] = 1.0
 559
 560
 561class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 562    concatenable: bool = False
 563    """If a model has a `concatenable` input axis, it can be processed blockwise,
 564    splitting a longer sample axis into blocks matching its input tensor description.
 565    Output axes are concatenable if they have a `SizeReference` to a concatenable
 566    input axis.
 567    """
 568
 569
 570class SpaceAxisBase(AxisBase):
 571    type: Literal["space"] = "space"
 572    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 573    unit: Optional[SpaceUnit] = None
 574    scale: Annotated[float, Gt(0)] = 1.0
 575
 576
 577class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 578    concatenable: bool = False
 579    """If a model has a `concatenable` input axis, it can be processed blockwise,
 580    splitting a longer sample axis into blocks matching its input tensor description.
 581    Output axes are concatenable if they have a `SizeReference` to a concatenable
 582    input axis.
 583    """
 584
 585
 586_InputAxisUnion = Union[
 587    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 588]
 589InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 590
 591
 592class _WithOutputAxisSize(Node):
 593    size: Annotated[
 594        Union[Annotated[int, Gt(0)], SizeReference],
 595        Field(
 596            examples=[
 597                10,
 598                SizeReference(
 599                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 600                ).model_dump(mode="json"),
 601            ]
 602        ),
 603    ]
 604    """The size/length of this axis can be specified as
 605    - fixed integer
 606    - reference to another axis with an optional offset (see `SizeReference`)
 607    """
 608
 609
 610class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 611    pass
 612
 613
 614class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 615    pass
 616
 617
 618def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 619    if isinstance(v, dict):
 620        return "with_halo" if "halo" in v else "wo_halo"
 621    else:
 622        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 623
 624
 625_TimeOutputAxisUnion = Annotated[
 626    Union[
 627        Annotated[TimeOutputAxis, Tag("wo_halo")],
 628        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 629    ],
 630    Discriminator(_get_halo_axis_discriminator_value),
 631]
 632
 633
 634class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 635    pass
 636
 637
 638class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 639    pass
 640
 641
 642_SpaceOutputAxisUnion = Annotated[
 643    Union[
 644        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 645        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 646    ],
 647    Discriminator(_get_halo_axis_discriminator_value),
 648]
 649
 650
 651_OutputAxisUnion = Union[
 652    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 653]
 654OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 655
 656AnyAxis = Union[InputAxis, OutputAxis]
 657
 658TVs = Union[
 659    NotEmpty[List[int]],
 660    NotEmpty[List[float]],
 661    NotEmpty[List[bool]],
 662    NotEmpty[List[str]],
 663]
 664
 665
 666NominalOrOrdinalDType = Literal[
 667    "float32",
 668    "float64",
 669    "uint8",
 670    "int8",
 671    "uint16",
 672    "int16",
 673    "uint32",
 674    "int32",
 675    "uint64",
 676    "int64",
 677    "bool",
 678]
 679
 680
 681class NominalOrOrdinalDataDescr(Node):
 682    values: TVs
 683    """A fixed set of nominal or an ascending sequence of ordinal values.
 684    In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'.
 685    String `values` are interpreted as labels for tensor values 0, ..., N.
 686    Note: as YAML 1.2 does not natively support a "set" datatype,
 687    nominal values should be given as a sequence (aka list/array) as well.
 688    """
 689
 690    type: Annotated[
 691        NominalOrOrdinalDType,
 692        Field(
 693            examples=[
 694                "float32",
 695                "uint8",
 696                "uint16",
 697                "int64",
 698                "bool",
 699            ],
 700        ),
 701    ] = "uint8"
 702
 703    @model_validator(mode="after")
 704    def _validate_values_match_type(
 705        self,
 706    ) -> Self:
 707        incompatible: List[Any] = []
 708        for v in self.values:
 709            if self.type == "bool":
 710                if not isinstance(v, bool):
 711                    incompatible.append(v)
 712            elif self.type in DTYPE_LIMITS:
 713                if (
 714                    isinstance(v, (int, float))
 715                    and (
 716                        v < DTYPE_LIMITS[self.type].min
 717                        or v > DTYPE_LIMITS[self.type].max
 718                    )
 719                    or (isinstance(v, str) and "uint" not in self.type)
 720                    or (isinstance(v, float) and "int" in self.type)
 721                ):
 722                    incompatible.append(v)
 723            else:
 724                incompatible.append(v)
 725
 726            if len(incompatible) == 5:
 727                incompatible.append("...")
 728                break
 729
 730        if incompatible:
 731            raise ValueError(
 732                f"data type '{self.type}' incompatible with values {incompatible}"
 733            )
 734
 735        return self
 736
 737    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 738
 739    @property
 740    def range(self):
 741        if isinstance(self.values[0], str):
 742            return 0, len(self.values) - 1
 743        else:
 744            return min(self.values), max(self.values)
 745
 746
 747IntervalOrRatioDType = Literal[
 748    "float32",
 749    "float64",
 750    "uint8",
 751    "int8",
 752    "uint16",
 753    "int16",
 754    "uint32",
 755    "int32",
 756    "uint64",
 757    "int64",
 758]
 759
 760
 761class IntervalOrRatioDataDescr(Node):
 762    type: Annotated[  # todo: rename to dtype
 763        IntervalOrRatioDType,
 764        Field(
 765            examples=["float32", "float64", "uint8", "uint16"],
 766        ),
 767    ] = "float32"
 768    range: Tuple[Optional[float], Optional[float]] = (
 769        None,
 770        None,
 771    )
 772    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 773    `None` corresponds to min/max of what can be expressed by `data_type`."""
 774    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 775    scale: float = 1.0
 776    """Scale for data on an interval (or ratio) scale."""
 777    offset: Optional[float] = None
 778    """Offset for data on a ratio scale."""
 779
 780
 781TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 782
 783
 784class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 785    """processing base class"""
 786
 787    # id: Literal[PreprocessingId, PostprocessingId]  # make abstract field
 788    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
 789
 790
 791class BinarizeKwargs(ProcessingKwargs):
 792    """key word arguments for `BinarizeDescr`"""
 793
 794    threshold: float
 795    """The fixed threshold"""
 796
 797
 798class BinarizeAlongAxisKwargs(ProcessingKwargs):
 799    """key word arguments for `BinarizeDescr`"""
 800
 801    threshold: NotEmpty[List[float]]
 802    """The fixed threshold values along `axis`"""
 803
 804    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 805    """The `threshold` axis"""
 806
 807
 808class BinarizeDescr(ProcessingDescrBase):
 809    """Binarize the tensor with a fixed threshold.
 810
 811    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 812    will be set to one, values below the threshold to zero.
 813    """
 814
 815    id: Literal["binarize"] = "binarize"
 816    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 817
 818
 819class ClipDescr(ProcessingDescrBase):
 820    """Set tensor values below min to min and above max to max."""
 821
 822    id: Literal["clip"] = "clip"
 823    kwargs: ClipKwargs
 824
 825
 826class EnsureDtypeKwargs(ProcessingKwargs):
 827    """key word arguments for `EnsureDtypeDescr`"""
 828
 829    dtype: Literal[
 830        "float32",
 831        "float64",
 832        "uint8",
 833        "int8",
 834        "uint16",
 835        "int16",
 836        "uint32",
 837        "int32",
 838        "uint64",
 839        "int64",
 840        "bool",
 841    ]
 842
 843
 844class EnsureDtypeDescr(ProcessingDescrBase):
 845    """cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching)"""
 846
 847    id: Literal["ensure_dtype"] = "ensure_dtype"
 848    kwargs: EnsureDtypeKwargs
 849
 850
 851class ScaleLinearKwargs(ProcessingKwargs):
 852    """key word arguments for `ScaleLinearDescr`"""
 853
 854    gain: float = 1.0
 855    """multiplicative factor"""
 856
 857    offset: float = 0.0
 858    """additive term"""
 859
 860    @model_validator(mode="after")
 861    def _validate(self) -> Self:
 862        if self.gain == 1.0 and self.offset == 0.0:
 863            raise ValueError(
 864                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
 865                + " != 0.0."
 866            )
 867
 868        return self
 869
 870
 871class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
 872    """key word arguments for `ScaleLinearDescr`"""
 873
 874    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 875    """The axis of of gains/offsets values."""
 876
 877    gain: Union[float, NotEmpty[List[float]]] = 1.0
 878    """multiplicative factor"""
 879
 880    offset: Union[float, NotEmpty[List[float]]] = 0.0
 881    """additive term"""
 882
 883    @model_validator(mode="after")
 884    def _validate(self) -> Self:
 885
 886        if isinstance(self.gain, list):
 887            if isinstance(self.offset, list):
 888                if len(self.gain) != len(self.offset):
 889                    raise ValueError(
 890                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
 891                    )
 892            else:
 893                self.offset = [float(self.offset)] * len(self.gain)
 894        elif isinstance(self.offset, list):
 895            self.gain = [float(self.gain)] * len(self.offset)
 896        else:
 897            raise ValueError(
 898                "Do not specify an `axis` for scalar gain and offset values."
 899            )
 900
 901        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
 902            raise ValueError(
 903                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
 904                + " != 0.0."
 905            )
 906
 907        return self
 908
 909
 910class ScaleLinearDescr(ProcessingDescrBase):
 911    """Fixed linear scaling."""
 912
 913    id: Literal["scale_linear"] = "scale_linear"
 914    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
 915
 916
 917class SigmoidDescr(ProcessingDescrBase):
 918    """The logistic sigmoid funciton, a.k.a. expit function."""
 919
 920    id: Literal["sigmoid"] = "sigmoid"
 921
 922    @property
 923    def kwargs(self) -> ProcessingKwargs:
 924        """empty kwargs"""
 925        return ProcessingKwargs()
 926
 927
 928class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
 929    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
 930
 931    mean: float
 932    """The mean value to normalize with."""
 933
 934    std: Annotated[float, Ge(1e-6)]
 935    """The standard deviation value to normalize with."""
 936
 937
 938class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
 939    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
 940
 941    mean: NotEmpty[List[float]]
 942    """The mean value(s) to normalize with."""
 943
 944    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
 945    """The standard deviation value(s) to normalize with.
 946    Size must match `mean` values."""
 947
 948    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
 949    """The axis of the mean/std values to normalize each entry along that dimension
 950    separately."""
 951
 952    @model_validator(mode="after")
 953    def _mean_and_std_match(self) -> Self:
 954        if len(self.mean) != len(self.std):
 955            raise ValueError(
 956                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
 957                + " must match."
 958            )
 959
 960        return self
 961
 962
 963class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
 964    """Subtract a given mean and divide by the standard deviation.
 965
 966    Normalize with fixed, precomputed values for
 967    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
 968    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
 969    axes.
 970    """
 971
 972    id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
 973    kwargs: Union[
 974        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
 975    ]
 976
 977
 978class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
 979    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
 980
 981    axes: Annotated[
 982        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
 983    ] = None
 984    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
 985    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
 986    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
 987    To normalize each sample independently leave out the 'batch' axis.
 988    Default: Scale all axes jointly."""
 989
 990    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
 991    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
 992
 993
 994class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
 995    """Subtract mean and divide by variance."""
 996
 997    id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
 998    kwargs: ZeroMeanUnitVarianceKwargs
 999
1000
1001class ScaleRangeKwargs(ProcessingKwargs):
1002    """key word arguments for `ScaleRangeDescr`
1003
1004    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1005    this processing step normalizes data to the [0, 1] intervall.
1006    For other percentiles the normalized values will partially be outside the [0, 1]
1007    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1008    normalized values to a range.
1009    """
1010
1011    axes: Annotated[
1012        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1013    ] = None
1014    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1015    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1016    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1017    To normalize samples indepdencently, leave out the "batch" axis.
1018    Default: Scale all axes jointly."""
1019
1020    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1021    """The lower percentile used to determine the value to align with zero."""
1022
1023    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1024    """The upper percentile used to determine the value to align with one.
1025    Has to be bigger than `min_percentile`.
1026    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1027    accepting percentiles specified in the range 0.0 to 1.0."""
1028
1029    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1030    """Epsilon for numeric stability.
1031    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1032    with `v_lower,v_upper` values at the respective percentiles."""
1033
1034    reference_tensor: Optional[TensorId] = None
1035    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1036    For any tensor in `inputs` only input tensor references are allowed."""
1037
1038    @field_validator("max_percentile", mode="after")
1039    @classmethod
1040    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1041        if (min_p := info.data["min_percentile"]) >= value:
1042            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1043
1044        return value
1045
1046
1047class ScaleRangeDescr(ProcessingDescrBase):
1048    """Scale with percentiles."""
1049
1050    id: Literal["scale_range"] = "scale_range"
1051    kwargs: ScaleRangeKwargs
1052
1053
1054class ScaleMeanVarianceKwargs(ProcessingKwargs):
1055    """key word arguments for `ScaleMeanVarianceKwargs`"""
1056
1057    reference_tensor: TensorId
1058    """Name of tensor to match."""
1059
1060    axes: Annotated[
1061        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1062    ] = None
1063    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1064    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1065    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1066    To normalize samples independently, leave out the 'batch' axis.
1067    Default: Scale all axes jointly."""
1068
1069    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1070    """Epsilon for numeric stability:
1071    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1072
1073
1074class ScaleMeanVarianceDescr(ProcessingDescrBase):
1075    """Scale a tensor's data distribution to match another tensor's mean/std.
1076    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1077    """
1078
1079    id: Literal["scale_mean_variance"] = "scale_mean_variance"
1080    kwargs: ScaleMeanVarianceKwargs
1081
1082
1083PreprocessingDescr = Annotated[
1084    Union[
1085        BinarizeDescr,
1086        ClipDescr,
1087        EnsureDtypeDescr,
1088        ScaleLinearDescr,
1089        SigmoidDescr,
1090        FixedZeroMeanUnitVarianceDescr,
1091        ZeroMeanUnitVarianceDescr,
1092        ScaleRangeDescr,
1093    ],
1094    Discriminator("id"),
1095]
1096PostprocessingDescr = Annotated[
1097    Union[
1098        BinarizeDescr,
1099        ClipDescr,
1100        EnsureDtypeDescr,
1101        ScaleLinearDescr,
1102        SigmoidDescr,
1103        FixedZeroMeanUnitVarianceDescr,
1104        ZeroMeanUnitVarianceDescr,
1105        ScaleRangeDescr,
1106        ScaleMeanVarianceDescr,
1107    ],
1108    Discriminator("id"),
1109]
1110
1111IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1112
1113
1114class TensorDescrBase(Node, Generic[IO_AxisT]):
1115    id: TensorId
1116    """Tensor id. No duplicates are allowed."""
1117
1118    description: Annotated[str, MaxLen(128)] = ""
1119    """free text description"""
1120
1121    axes: NotEmpty[Sequence[IO_AxisT]]
1122    """tensor axes"""
1123
1124    @property
1125    def shape(self):
1126        return tuple(a.size for a in self.axes)
1127
1128    @field_validator("axes", mode="after", check_fields=False)
1129    @classmethod
1130    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1131        batch_axes = [a for a in axes if a.type == "batch"]
1132        if len(batch_axes) > 1:
1133            raise ValueError(
1134                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1135            )
1136
1137        seen_ids: Set[AxisId] = set()
1138        duplicate_axes_ids: Set[AxisId] = set()
1139        for a in axes:
1140            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1141
1142        if duplicate_axes_ids:
1143            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1144
1145        return axes
1146
1147    test_tensor: FileDescr
1148    """An example tensor to use for testing.
1149    Using the model with the test input tensors is expected to yield the test output tensors.
1150    Each test tensor has be a an ndarray in the
1151    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1152    The file extension must be '.npy'."""
1153
1154    sample_tensor: Optional[FileDescr] = None
1155    """A sample tensor to illustrate a possible input/output for the model,
1156    The sample image primarily serves to inform a human user about an example use case
1157    and is typically stored as .hdf5, .png or .tiff.
1158    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1159    (numpy's `.npy` format is not supported).
1160    The image dimensionality has to match the number of axes specified in this tensor description.
1161    """
1162
1163    @model_validator(mode="after")
1164    def _validate_sample_tensor(self) -> Self:
1165        if (
1166            self.sample_tensor is None
1167            or not validation_context_var.get().perform_io_checks
1168        ):
1169            return self
1170
1171        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1172        tensor: NDArray[Any] = imread(
1173            local.path.read_bytes(),
1174            extension=PurePosixPath(local.original_file_name).suffix,
1175        )
1176        n_dims = len(tensor.squeeze().shape)
1177        n_dims_min = n_dims_max = len(self.axes)
1178
1179        for a in self.axes:
1180            if isinstance(a, BatchAxis):
1181                n_dims_min -= 1
1182            elif isinstance(a.size, int):
1183                if a.size == 1:
1184                    n_dims_min -= 1
1185            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1186                if a.size.min == 1:
1187                    n_dims_min -= 1
1188            elif isinstance(a.size, SizeReference):
1189                if a.size.offset < 2:
1190                    # size reference may result in singleton axis
1191                    n_dims_min -= 1
1192            else:
1193                assert_never(a.size)
1194
1195        n_dims_min = max(0, n_dims_min)
1196        if n_dims < n_dims_min or n_dims > n_dims_max:
1197            raise ValueError(
1198                f"Expected sample tensor to have {n_dims_min} to"
1199                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1200            )
1201
1202        return self
1203
1204    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1205        IntervalOrRatioDataDescr()
1206    )
1207    """Description of the tensor's data values, optionally per channel.
1208    If specified per channel, the data `type` needs to match across channels."""
1209
1210    @property
1211    def dtype(
1212        self,
1213    ) -> Literal[
1214        "float32",
1215        "float64",
1216        "uint8",
1217        "int8",
1218        "uint16",
1219        "int16",
1220        "uint32",
1221        "int32",
1222        "uint64",
1223        "int64",
1224        "bool",
1225    ]:
1226        """dtype as specified under `data.type` or `data[i].type`"""
1227        if isinstance(self.data, collections.abc.Sequence):
1228            return self.data[0].type
1229        else:
1230            return self.data.type
1231
1232    @field_validator("data", mode="after")
1233    @classmethod
1234    def _check_data_type_across_channels(
1235        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1236    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1237        if not isinstance(value, list):
1238            return value
1239
1240        dtypes = {t.type for t in value}
1241        if len(dtypes) > 1:
1242            raise ValueError(
1243                "Tensor data descriptions per channel need to agree in their data"
1244                + f" `type`, but found {dtypes}."
1245            )
1246
1247        return value
1248
1249    @model_validator(mode="after")
1250    def _check_data_matches_channelaxis(self) -> Self:
1251        if not isinstance(self.data, (list, tuple)):
1252            return self
1253
1254        for a in self.axes:
1255            if isinstance(a, ChannelAxis):
1256                size = a.size
1257                assert isinstance(size, int)
1258                break
1259        else:
1260            return self
1261
1262        if len(self.data) != size:
1263            raise ValueError(
1264                f"Got tensor data descriptions for {len(self.data)} channels, but"
1265                + f" '{a.id}' axis has size {size}."
1266            )
1267
1268        return self
1269
1270    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1271        if len(array.shape) != len(self.axes):
1272            raise ValueError(
1273                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1274                + f" incompatible with {len(self.axes)} axes."
1275            )
1276        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1277
1278
1279class InputTensorDescr(TensorDescrBase[InputAxis]):
1280    id: TensorId = TensorId("input")
1281    """Input tensor id.
1282    No duplicates are allowed across all inputs and outputs."""
1283
1284    optional: bool = False
1285    """indicates that this tensor may be `None`"""
1286
1287    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1288    """Description of how this input should be preprocessed.
1289
1290    notes:
1291    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1292      to ensure an input tensor's data type matches the input tensor's data description.
1293    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1294      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1295      changing the data type.
1296    """
1297
1298    @model_validator(mode="after")
1299    def _validate_preprocessing_kwargs(self) -> Self:
1300        axes_ids = [a.id for a in self.axes]
1301        for p in self.preprocessing:
1302            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1303            if kwargs_axes is None:
1304                continue
1305
1306            if not isinstance(kwargs_axes, collections.abc.Sequence):
1307                raise ValueError(
1308                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1309                )
1310
1311            if any(a not in axes_ids for a in kwargs_axes):
1312                raise ValueError(
1313                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1314                )
1315
1316        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1317            dtype = self.data.type
1318        else:
1319            dtype = self.data[0].type
1320
1321        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1322        if not self.preprocessing or not isinstance(
1323            self.preprocessing[0], EnsureDtypeDescr
1324        ):
1325            self.preprocessing.insert(
1326                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1327            )
1328
1329        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1330        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1331            self.preprocessing.append(
1332                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1333            )
1334
1335        return self
1336
1337
1338def convert_axes(
1339    axes: str,
1340    *,
1341    shape: Union[
1342        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1343    ],
1344    tensor_type: Literal["input", "output"],
1345    halo: Optional[Sequence[int]],
1346    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1347):
1348    ret: List[AnyAxis] = []
1349    for i, a in enumerate(axes):
1350        axis_type = _AXIS_TYPE_MAP.get(a, a)
1351        if axis_type == "batch":
1352            ret.append(BatchAxis())
1353            continue
1354
1355        scale = 1.0
1356        if isinstance(shape, _ParameterizedInputShape_v0_4):
1357            if shape.step[i] == 0:
1358                size = shape.min[i]
1359            else:
1360                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1361        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1362            ref_t = str(shape.reference_tensor)
1363            if ref_t.count(".") == 1:
1364                t_id, orig_a_id = ref_t.split(".")
1365            else:
1366                t_id = ref_t
1367                orig_a_id = a
1368
1369            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1370            if not (orig_scale := shape.scale[i]):
1371                # old way to insert a new axis dimension
1372                size = int(2 * shape.offset[i])
1373            else:
1374                scale = 1 / orig_scale
1375                if axis_type in ("channel", "index"):
1376                    # these axes no longer have a scale
1377                    offset_from_scale = orig_scale * size_refs.get(
1378                        _TensorName_v0_4(t_id), {}
1379                    ).get(orig_a_id, 0)
1380                else:
1381                    offset_from_scale = 0
1382                size = SizeReference(
1383                    tensor_id=TensorId(t_id),
1384                    axis_id=AxisId(a_id),
1385                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1386                )
1387        else:
1388            size = shape[i]
1389
1390        if axis_type == "time":
1391            if tensor_type == "input":
1392                ret.append(TimeInputAxis(size=size, scale=scale))
1393            else:
1394                assert not isinstance(size, ParameterizedSize)
1395                if halo is None:
1396                    ret.append(TimeOutputAxis(size=size, scale=scale))
1397                else:
1398                    assert not isinstance(size, int)
1399                    ret.append(
1400                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1401                    )
1402
1403        elif axis_type == "index":
1404            if tensor_type == "input":
1405                ret.append(IndexInputAxis(size=size))
1406            else:
1407                if isinstance(size, ParameterizedSize):
1408                    size = DataDependentSize(min=size.min)
1409
1410                ret.append(IndexOutputAxis(size=size))
1411        elif axis_type == "channel":
1412            assert not isinstance(size, ParameterizedSize)
1413            if isinstance(size, SizeReference):
1414                warnings.warn(
1415                    "Conversion of channel size from an implicit output shape may be"
1416                    + " wrong"
1417                )
1418                ret.append(
1419                    ChannelAxis(
1420                        channel_names=[
1421                            Identifier(f"channel{i}") for i in range(size.offset)
1422                        ]
1423                    )
1424                )
1425            else:
1426                ret.append(
1427                    ChannelAxis(
1428                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1429                    )
1430                )
1431        elif axis_type == "space":
1432            if tensor_type == "input":
1433                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1434            else:
1435                assert not isinstance(size, ParameterizedSize)
1436                if halo is None or halo[i] == 0:
1437                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1438                elif isinstance(size, int):
1439                    raise NotImplementedError(
1440                        f"output axis with halo and fixed size (here {size}) not allowed"
1441                    )
1442                else:
1443                    ret.append(
1444                        SpaceOutputAxisWithHalo(
1445                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1446                        )
1447                    )
1448
1449    return ret
1450
1451
1452_AXIS_TYPE_MAP = {
1453    "b": "batch",
1454    "t": "time",
1455    "i": "index",
1456    "c": "channel",
1457    "x": "space",
1458    "y": "space",
1459    "z": "space",
1460}
1461
1462_AXIS_ID_MAP = {
1463    "b": "batch",
1464    "t": "time",
1465    "i": "index",
1466    "c": "channel",
1467}
1468
1469
1470def _axes_letters_to_ids(
1471    axes: Optional[str],
1472) -> Optional[List[AxisId]]:
1473    if axes is None:
1474        return None
1475    return [AxisId(_AXIS_ID_MAP.get(a, a)) for a in map(str, axes)]
1476
1477
1478def _get_complement_v04_axis(
1479    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1480) -> Optional[AxisId]:
1481    if axes is None:
1482        return None
1483
1484    non_complement_axes = set(axes) | {"b"}
1485    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1486    if len(complement_axes) > 1:
1487        raise ValueError(
1488            f"Expected none or a single complement axis, but axes '{axes}' "
1489            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1490        )
1491
1492    return None if not complement_axes else AxisId(complement_axes[0])
1493
1494
1495def _convert_proc(
1496    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1497    tensor_axes: Sequence[str],
1498) -> Union[PreprocessingDescr, PostprocessingDescr]:
1499    if isinstance(p, _BinarizeDescr_v0_4):
1500        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1501    elif isinstance(p, _ClipDescr_v0_4):
1502        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1503    elif isinstance(p, _SigmoidDescr_v0_4):
1504        return SigmoidDescr()
1505    elif isinstance(p, _ScaleLinearDescr_v0_4):
1506        axes = _axes_letters_to_ids(p.kwargs.axes)
1507        if p.kwargs.axes is None:
1508            axis = None
1509        else:
1510            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1511
1512        if axis is None:
1513            assert not isinstance(p.kwargs.gain, list)
1514            assert not isinstance(p.kwargs.offset, list)
1515            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1516        else:
1517            kwargs = ScaleLinearAlongAxisKwargs(
1518                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1519            )
1520        return ScaleLinearDescr(kwargs=kwargs)
1521    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1522        return ScaleMeanVarianceDescr(
1523            kwargs=ScaleMeanVarianceKwargs(
1524                axes=_axes_letters_to_ids(p.kwargs.axes),
1525                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1526                eps=p.kwargs.eps,
1527            )
1528        )
1529    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1530        if p.kwargs.mode == "fixed":
1531            mean = p.kwargs.mean
1532            std = p.kwargs.std
1533            assert mean is not None
1534            assert std is not None
1535
1536            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1537
1538            if axis is None:
1539                return FixedZeroMeanUnitVarianceDescr(
1540                    kwargs=FixedZeroMeanUnitVarianceKwargs(
1541                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
1542                    )
1543                )
1544            else:
1545                if not isinstance(mean, list):
1546                    mean = [float(mean)]
1547                if not isinstance(std, list):
1548                    std = [float(std)]
1549
1550                return FixedZeroMeanUnitVarianceDescr(
1551                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1552                        axis=axis, mean=mean, std=std
1553                    )
1554                )
1555
1556        else:
1557            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1558            if p.kwargs.mode == "per_dataset":
1559                axes = [AxisId("batch")] + axes
1560            if not axes:
1561                axes = None
1562            return ZeroMeanUnitVarianceDescr(
1563                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1564            )
1565
1566    elif isinstance(p, _ScaleRangeDescr_v0_4):
1567        return ScaleRangeDescr(
1568            kwargs=ScaleRangeKwargs(
1569                axes=_axes_letters_to_ids(p.kwargs.axes),
1570                min_percentile=p.kwargs.min_percentile,
1571                max_percentile=p.kwargs.max_percentile,
1572                eps=p.kwargs.eps,
1573            )
1574        )
1575    else:
1576        assert_never(p)
1577
1578
1579class _InputTensorConv(
1580    Converter[
1581        _InputTensorDescr_v0_4,
1582        InputTensorDescr,
1583        ImportantFileSource,
1584        Optional[ImportantFileSource],
1585        Mapping[_TensorName_v0_4, Mapping[str, int]],
1586    ]
1587):
1588    def _convert(
1589        self,
1590        src: _InputTensorDescr_v0_4,
1591        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1592        test_tensor: ImportantFileSource,
1593        sample_tensor: Optional[ImportantFileSource],
1594        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1595    ) -> "InputTensorDescr | dict[str, Any]":
1596        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1597            src.axes,
1598            shape=src.shape,
1599            tensor_type="input",
1600            halo=None,
1601            size_refs=size_refs,
1602        )
1603        prep: List[PreprocessingDescr] = []
1604        for p in src.preprocessing:
1605            cp = _convert_proc(p, src.axes)
1606            assert not isinstance(cp, ScaleMeanVarianceDescr)
1607            prep.append(cp)
1608
1609        return tgt(
1610            axes=axes,
1611            id=TensorId(str(src.name)),
1612            test_tensor=FileDescr(source=test_tensor),
1613            sample_tensor=(
1614                None if sample_tensor is None else FileDescr(source=sample_tensor)
1615            ),
1616            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
1617            preprocessing=prep,
1618        )
1619
1620
1621_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1622
1623
1624class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1625    id: TensorId = TensorId("output")
1626    """Output tensor id.
1627    No duplicates are allowed across all inputs and outputs."""
1628
1629    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1630    """Description of how this output should be postprocessed.
1631
1632    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1633          If not given this is added to cast to this tensor's `data.type`.
1634    """
1635
1636    @model_validator(mode="after")
1637    def _validate_postprocessing_kwargs(self) -> Self:
1638        axes_ids = [a.id for a in self.axes]
1639        for p in self.postprocessing:
1640            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1641            if kwargs_axes is None:
1642                continue
1643
1644            if not isinstance(kwargs_axes, collections.abc.Sequence):
1645                raise ValueError(
1646                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1647                )
1648
1649            if any(a not in axes_ids for a in kwargs_axes):
1650                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1651
1652        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1653            dtype = self.data.type
1654        else:
1655            dtype = self.data[0].type
1656
1657        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1658        if not self.postprocessing or not isinstance(
1659            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1660        ):
1661            self.postprocessing.append(
1662                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1663            )
1664        return self
1665
1666
1667class _OutputTensorConv(
1668    Converter[
1669        _OutputTensorDescr_v0_4,
1670        OutputTensorDescr,
1671        ImportantFileSource,
1672        Optional[ImportantFileSource],
1673        Mapping[_TensorName_v0_4, Mapping[str, int]],
1674    ]
1675):
1676    def _convert(
1677        self,
1678        src: _OutputTensorDescr_v0_4,
1679        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
1680        test_tensor: ImportantFileSource,
1681        sample_tensor: Optional[ImportantFileSource],
1682        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1683    ) -> "OutputTensorDescr | dict[str, Any]":
1684        # TODO: split convert_axes into convert_output_axes and convert_input_axes
1685        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1686            src.axes,
1687            shape=src.shape,
1688            tensor_type="output",
1689            halo=src.halo,
1690            size_refs=size_refs,
1691        )
1692        data_descr: Dict[str, Any] = dict(type=src.data_type)
1693        if data_descr["type"] == "bool":
1694            data_descr["values"] = [False, True]
1695
1696        return tgt(
1697            axes=axes,
1698            id=TensorId(str(src.name)),
1699            test_tensor=FileDescr(source=test_tensor),
1700            sample_tensor=(
1701                None if sample_tensor is None else FileDescr(source=sample_tensor)
1702            ),
1703            data=data_descr,  # pyright: ignore[reportArgumentType]
1704            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
1705        )
1706
1707
1708_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
1709
1710
1711TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
1712
1713
1714def validate_tensors(
1715    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
1716    tensor_origin: str,  # for more precise error messages, e.g. 'test_tensor'
1717):
1718    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
1719
1720    def e_msg(d: TensorDescr):
1721        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
1722
1723    for descr, array in tensors.values():
1724        try:
1725            axis_sizes = descr.get_axis_sizes_for_array(array)
1726        except ValueError as e:
1727            raise ValueError(f"{e_msg(descr)} {e}")
1728        else:
1729            all_tensor_axes[descr.id] = {
1730                a.id: (a, axis_sizes[a.id]) for a in descr.axes
1731            }
1732
1733    for descr, array in tensors.values():
1734        if array.dtype.name != descr.dtype:
1735            raise ValueError(
1736                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
1737                + f" match described dtype '{descr.dtype}'"
1738            )
1739
1740        for a in descr.axes:
1741            actual_size = all_tensor_axes[descr.id][a.id][1]
1742            if a.size is None:
1743                continue
1744
1745            if isinstance(a.size, int):
1746                if actual_size != a.size:
1747                    raise ValueError(
1748                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
1749                        + f"has incompatible size {actual_size}, expected {a.size}"
1750                    )
1751            elif isinstance(a.size, ParameterizedSize):
1752                _ = a.size.validate_size(actual_size)
1753            elif isinstance(a.size, DataDependentSize):
1754                _ = a.size.validate_size(actual_size)
1755            elif isinstance(a.size, SizeReference):
1756                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
1757                if ref_tensor_axes is None:
1758                    raise ValueError(
1759                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
1760                        + f" reference '{a.size.tensor_id}'"
1761                    )
1762
1763                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
1764                if ref_axis is None or ref_size is None:
1765                    raise ValueError(
1766                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
1767                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
1768                    )
1769
1770                if a.unit != ref_axis.unit:
1771                    raise ValueError(
1772                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
1773                        + " axis and reference axis to have the same `unit`, but"
1774                        + f" {a.unit}!={ref_axis.unit}"
1775                    )
1776
1777                if actual_size != (
1778                    expected_size := (
1779                        ref_size * ref_axis.scale / a.scale + a.size.offset
1780                    )
1781                ):
1782                    raise ValueError(
1783                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
1784                        + f" {actual_size} invalid for referenced size {ref_size};"
1785                        + f" expected {expected_size}"
1786                    )
1787            else:
1788                assert_never(a.size)
1789
1790
1791class EnvironmentFileDescr(FileDescr):
1792    source: Annotated[
1793        ImportantFileSource,
1794        WithSuffix((".yaml", ".yml"), case_sensitive=True),
1795        Field(
1796            examples=["environment.yaml"],
1797        ),
1798    ]
1799    """∈📦 Conda environment file.
1800    Allows to specify custom dependencies, see conda docs:
1801    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
1802    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
1803    """
1804
1805
1806class _ArchitectureCallableDescr(Node):
1807    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
1808    """Identifier of the callable that returns a torch.nn.Module instance."""
1809
1810    kwargs: Dict[str, YamlValue] = Field(default_factory=dict)
1811    """key word arguments for the `callable`"""
1812
1813
1814class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
1815    pass
1816
1817
1818class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
1819    import_from: str
1820    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
1821
1822
1823ArchitectureDescr = Annotated[
1824    Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr],
1825    Field(union_mode="left_to_right"),
1826]
1827
1828
1829class _ArchFileConv(
1830    Converter[
1831        _CallableFromFile_v0_4,
1832        ArchitectureFromFileDescr,
1833        Optional[Sha256],
1834        Dict[str, Any],
1835    ]
1836):
1837    def _convert(
1838        self,
1839        src: _CallableFromFile_v0_4,
1840        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
1841        sha256: Optional[Sha256],
1842        kwargs: Dict[str, Any],
1843    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
1844        if src.startswith("http") and src.count(":") == 2:
1845            http, source, callable_ = src.split(":")
1846            source = ":".join((http, source))
1847        elif not src.startswith("http") and src.count(":") == 1:
1848            source, callable_ = src.split(":")
1849        else:
1850            source = str(src)
1851            callable_ = str(src)
1852        return tgt(
1853            callable=Identifier(callable_),
1854            source=cast(ImportantFileSource, source),
1855            sha256=sha256,
1856            kwargs=kwargs,
1857        )
1858
1859
1860_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
1861
1862
1863class _ArchLibConv(
1864    Converter[
1865        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
1866    ]
1867):
1868    def _convert(
1869        self,
1870        src: _CallableFromDepencency_v0_4,
1871        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
1872        kwargs: Dict[str, Any],
1873    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
1874        *mods, callable_ = src.split(".")
1875        import_from = ".".join(mods)
1876        return tgt(
1877            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
1878        )
1879
1880
1881_arch_lib_conv = _ArchLibConv(
1882    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
1883)
1884
1885
1886class WeightsEntryDescrBase(FileDescr):
1887    type: ClassVar[WeightsFormat]
1888    weights_format_name: ClassVar[str]  # human readable
1889
1890    source: ImportantFileSource
1891    """∈📦 The weights file."""
1892
1893    authors: Optional[List[Author]] = None
1894    """Authors
1895    Either the person(s) that have trained this model resulting in the original weights file.
1896        (If this is the initial weights entry, i.e. it does not have a `parent`)
1897    Or the person(s) who have converted the weights to this weights format.
1898        (If this is a child weight, i.e. it has a `parent` field)
1899    """
1900
1901    parent: Annotated[
1902        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
1903    ] = None
1904    """The source weights these weights were converted from.
1905    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
1906    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
1907    All weight entries except one (the initial set of weights resulting from training the model),
1908    need to have this field."""
1909
1910    @model_validator(mode="after")
1911    def check_parent_is_not_self(self) -> Self:
1912        if self.type == self.parent:
1913            raise ValueError("Weights entry can't be it's own parent.")
1914
1915        return self
1916
1917
1918class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
1919    type = "keras_hdf5"
1920    weights_format_name: ClassVar[str] = "Keras HDF5"
1921    tensorflow_version: Version
1922    """TensorFlow version used to create these weights."""
1923
1924
1925class OnnxWeightsDescr(WeightsEntryDescrBase):
1926    type = "onnx"
1927    weights_format_name: ClassVar[str] = "ONNX"
1928    opset_version: Annotated[int, Ge(7)]
1929    """ONNX opset version"""
1930
1931
1932class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
1933    type = "pytorch_state_dict"
1934    weights_format_name: ClassVar[str] = "Pytorch State Dict"
1935    architecture: ArchitectureDescr
1936    pytorch_version: Version
1937    """Version of the PyTorch library used.
1938    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
1939    """
1940    dependencies: Optional[EnvironmentFileDescr] = None
1941    """Custom depencies beyond pytorch.
1942    The conda environment file should include pytorch and any version pinning has to be compatible with
1943    `pytorch_version`.
1944    """
1945
1946
1947class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
1948    type = "tensorflow_js"
1949    weights_format_name: ClassVar[str] = "Tensorflow.js"
1950    tensorflow_version: Version
1951    """Version of the TensorFlow library used."""
1952
1953    source: ImportantFileSource
1954    """∈📦 The multi-file weights.
1955    All required files/folders should be a zip archive."""
1956
1957
1958class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
1959    type = "tensorflow_saved_model_bundle"
1960    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
1961    tensorflow_version: Version
1962    """Version of the TensorFlow library used."""
1963
1964    dependencies: Optional[EnvironmentFileDescr] = None
1965    """Custom dependencies beyond tensorflow.
1966    Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""
1967
1968    source: ImportantFileSource
1969    """∈📦 The multi-file weights.
1970    All required files/folders should be a zip archive."""
1971
1972
1973class TorchscriptWeightsDescr(WeightsEntryDescrBase):
1974    type = "torchscript"
1975    weights_format_name: ClassVar[str] = "TorchScript"
1976    pytorch_version: Version
1977    """Version of the PyTorch library used."""
1978
1979
1980class WeightsDescr(Node):
1981    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
1982    onnx: Optional[OnnxWeightsDescr] = None
1983    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
1984    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
1985    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
1986        None
1987    )
1988    torchscript: Optional[TorchscriptWeightsDescr] = None
1989
1990    @model_validator(mode="after")
1991    def check_entries(self) -> Self:
1992        entries = {wtype for wtype, entry in self if entry is not None}
1993
1994        if not entries:
1995            raise ValueError("Missing weights entry")
1996
1997        entries_wo_parent = {
1998            wtype
1999            for wtype, entry in self
2000            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2001        }
2002        if len(entries_wo_parent) != 1:
2003            issue_warning(
2004                "Exactly one weights entry may not specify the `parent` field (got"
2005                + " {value}). That entry is considered the original set of model weights."
2006                + " Other weight formats are created through conversion of the orignal or"
2007                + " already converted weights. They have to reference the weights format"
2008                + " they were converted from as their `parent`.",
2009                value=len(entries_wo_parent),
2010                field="weights",
2011            )
2012
2013        for wtype, entry in self:
2014            if entry is None:
2015                continue
2016
2017            assert hasattr(entry, "type")
2018            assert hasattr(entry, "parent")
2019            assert wtype == entry.type
2020            if (
2021                entry.parent is not None and entry.parent not in entries
2022            ):  # self reference checked for `parent` field
2023                raise ValueError(
2024                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2025                    + f" formats: {entries}"
2026                )
2027
2028        return self
2029
2030
2031class ModelId(ResourceId):
2032    pass
2033
2034
2035class LinkedModel(LinkedResourceNode):
2036    """Reference to a bioimage.io model."""
2037
2038    id: ModelId
2039    """A valid model `id` from the bioimage.io collection."""
2040
2041
2042class _DataDepSize(NamedTuple):
2043    min: int
2044    max: Optional[int]
2045
2046
2047class _AxisSizes(NamedTuple):
2048    """the lenghts of all axes of model inputs and outputs"""
2049
2050    inputs: Dict[Tuple[TensorId, AxisId], int]
2051    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2052
2053
2054class _TensorSizes(NamedTuple):
2055    """_AxisSizes as nested dicts"""
2056
2057    inputs: Dict[TensorId, Dict[AxisId, int]]
2058    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2059
2060
2061class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"):
2062    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2063    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2064    """
2065
2066    format_version: Literal["0.5.3"] = "0.5.3"
2067    """Version of the bioimage.io model description specification used.
2068    When creating a new model always use the latest micro/patch version described here.
2069    The `format_version` is important for any consumer software to understand how to parse the fields.
2070    """
2071
2072    type: Literal["model"] = "model"
2073    """Specialized resource type 'model'"""
2074
2075    id: Optional[ModelId] = None
2076    """bioimage.io-wide unique resource identifier
2077    assigned by bioimage.io; version **un**specific."""
2078
2079    authors: NotEmpty[List[Author]]
2080    """The authors are the creators of the model RDF and the primary points of contact."""
2081
2082    documentation: Annotated[
2083        DocumentationSource,
2084        Field(
2085            examples=[
2086                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2087                "README.md",
2088            ],
2089        ),
2090    ]
2091    """∈📦 URL or relative path to a markdown file with additional documentation.
2092    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2093    The documentation should include a '#[#] Validation' (sub)section
2094    with details on how to quantitatively validate the model on unseen data."""
2095
2096    @field_validator("documentation", mode="after")
2097    @classmethod
2098    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2099        if not validation_context_var.get().perform_io_checks:
2100            return value
2101
2102        doc_path = download(value).path
2103        doc_content = doc_path.read_text(encoding="utf-8")
2104        assert isinstance(doc_content, str)
2105        if not re.match("#.*[vV]alidation", doc_content):
2106            issue_warning(
2107                "No '# Validation' (sub)section found in {value}.",
2108                value=value,
2109                field="documentation",
2110            )
2111
2112        return value
2113
2114    inputs: NotEmpty[Sequence[InputTensorDescr]]
2115    """Describes the input tensors expected by this model."""
2116
2117    @field_validator("inputs", mode="after")
2118    @classmethod
2119    def _validate_input_axes(
2120        cls, inputs: Sequence[InputTensorDescr]
2121    ) -> Sequence[InputTensorDescr]:
2122        input_size_refs = cls._get_axes_with_independent_size(inputs)
2123
2124        for i, ipt in enumerate(inputs):
2125            valid_independent_refs: Dict[
2126                Tuple[TensorId, AxisId],
2127                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2128            ] = {
2129                **{
2130                    (ipt.id, a.id): (ipt, a, a.size)
2131                    for a in ipt.axes
2132                    if not isinstance(a, BatchAxis)
2133                    and isinstance(a.size, (int, ParameterizedSize))
2134                },
2135                **input_size_refs,
2136            }
2137            for a, ax in enumerate(ipt.axes):
2138                cls._validate_axis(
2139                    "inputs",
2140                    i=i,
2141                    tensor_id=ipt.id,
2142                    a=a,
2143                    axis=ax,
2144                    valid_independent_refs=valid_independent_refs,
2145                )
2146        return inputs
2147
2148    @staticmethod
2149    def _validate_axis(
2150        field_name: str,
2151        i: int,
2152        tensor_id: TensorId,
2153        a: int,
2154        axis: AnyAxis,
2155        valid_independent_refs: Dict[
2156            Tuple[TensorId, AxisId],
2157            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2158        ],
2159    ):
2160        if isinstance(axis, BatchAxis) or isinstance(
2161            axis.size, (int, ParameterizedSize, DataDependentSize)
2162        ):
2163            return
2164        elif not isinstance(axis.size, SizeReference):
2165            assert_never(axis.size)
2166
2167        # validate axis.size SizeReference
2168        ref = (axis.size.tensor_id, axis.size.axis_id)
2169        if ref not in valid_independent_refs:
2170            raise ValueError(
2171                "Invalid tensor axis reference at"
2172                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2173            )
2174        if ref == (tensor_id, axis.id):
2175            raise ValueError(
2176                "Self-referencing not allowed for"
2177                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2178            )
2179        if axis.type == "channel":
2180            if valid_independent_refs[ref][1].type != "channel":
2181                raise ValueError(
2182                    "A channel axis' size may only reference another fixed size"
2183                    + " channel axis."
2184                )
2185            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2186                ref_size = valid_independent_refs[ref][2]
2187                assert isinstance(ref_size, int), (
2188                    "channel axis ref (another channel axis) has to specify fixed"
2189                    + " size"
2190                )
2191                generated_channel_names = [
2192                    Identifier(axis.channel_names.format(i=i))
2193                    for i in range(1, ref_size + 1)
2194                ]
2195                axis.channel_names = generated_channel_names
2196
2197        if (ax_unit := getattr(axis, "unit", None)) != (
2198            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2199        ):
2200            raise ValueError(
2201                "The units of an axis and its reference axis need to match, but"
2202                + f" '{ax_unit}' != '{ref_unit}'."
2203            )
2204        ref_axis = valid_independent_refs[ref][1]
2205        if isinstance(ref_axis, BatchAxis):
2206            raise ValueError(
2207                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2208                + " (a batch axis is not allowed as reference)."
2209            )
2210
2211        if isinstance(axis, WithHalo):
2212            min_size = axis.size.get_size(axis, ref_axis, n=0)
2213            if (min_size - 2 * axis.halo) < 1:
2214                raise ValueError(
2215                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2216                    + f" {axis.halo}."
2217                )
2218
2219            input_halo = axis.halo * axis.scale / ref_axis.scale
2220            if input_halo != int(input_halo) or input_halo % 2 == 1:
2221                raise ValueError(
2222                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2223                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2224                    + f" is not an even integer for {tensor_id}.{axis.id}."
2225                )
2226
2227    @model_validator(mode="after")
2228    def _validate_test_tensors(self) -> Self:
2229        if not validation_context_var.get().perform_io_checks:
2230            return self
2231
2232        test_arrays = [
2233            load_array(descr.test_tensor.download().path)
2234            for descr in chain(self.inputs, self.outputs)
2235        ]
2236        tensors = {
2237            descr.id: (descr, array)
2238            for descr, array in zip(chain(self.inputs, self.outputs), test_arrays)
2239        }
2240        validate_tensors(tensors, tensor_origin="test_tensor")
2241        return self
2242
2243    @model_validator(mode="after")
2244    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2245        ipt_refs = {t.id for t in self.inputs}
2246        out_refs = {t.id for t in self.outputs}
2247        for ipt in self.inputs:
2248            for p in ipt.preprocessing:
2249                ref = p.kwargs.get("reference_tensor")
2250                if ref is None:
2251                    continue
2252                if ref not in ipt_refs:
2253                    raise ValueError(
2254                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2255                        + f" references are: {ipt_refs}."
2256                    )
2257
2258        for out in self.outputs:
2259            for p in out.postprocessing:
2260                ref = p.kwargs.get("reference_tensor")
2261                if ref is None:
2262                    continue
2263
2264                if ref not in ipt_refs and ref not in out_refs:
2265                    raise ValueError(
2266                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2267                        + f" are: {ipt_refs | out_refs}."
2268                    )
2269
2270        return self
2271
2272    # TODO: use validate funcs in validate_test_tensors
2273    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2274
2275    name: Annotated[
2276        Annotated[
2277            str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()")
2278        ],
2279        MinLen(5),
2280        MaxLen(128),
2281        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2282    ]
2283    """A human-readable name of this model.
2284    It should be no longer than 64 characters
2285    and may only contain letter, number, underscore, minus, parentheses and spaces.
2286    We recommend to chose a name that refers to the model's task and image modality.
2287    """
2288
2289    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2290    """Describes the output tensors."""
2291
2292    @field_validator("outputs", mode="after")
2293    @classmethod
2294    def _validate_tensor_ids(
2295        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2296    ) -> Sequence[OutputTensorDescr]:
2297        tensor_ids = [
2298            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2299        ]
2300        duplicate_tensor_ids: List[str] = []
2301        seen: Set[str] = set()
2302        for t in tensor_ids:
2303            if t in seen:
2304                duplicate_tensor_ids.append(t)
2305
2306            seen.add(t)
2307
2308        if duplicate_tensor_ids:
2309            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2310
2311        return outputs
2312
2313    @staticmethod
2314    def _get_axes_with_parameterized_size(
2315        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2316    ):
2317        return {
2318            f"{t.id}.{a.id}": (t, a, a.size)
2319            for t in io
2320            for a in t.axes
2321            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2322        }
2323
2324    @staticmethod
2325    def _get_axes_with_independent_size(
2326        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2327    ):
2328        return {
2329            (t.id, a.id): (t, a, a.size)
2330            for t in io
2331            for a in t.axes
2332            if not isinstance(a, BatchAxis)
2333            and isinstance(a.size, (int, ParameterizedSize))
2334        }
2335
2336    @field_validator("outputs", mode="after")
2337    @classmethod
2338    def _validate_output_axes(
2339        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2340    ) -> List[OutputTensorDescr]:
2341        input_size_refs = cls._get_axes_with_independent_size(
2342            info.data.get("inputs", [])
2343        )
2344        output_size_refs = cls._get_axes_with_independent_size(outputs)
2345
2346        for i, out in enumerate(outputs):
2347            valid_independent_refs: Dict[
2348                Tuple[TensorId, AxisId],
2349                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2350            ] = {
2351                **{
2352                    (out.id, a.id): (out, a, a.size)
2353                    for a in out.axes
2354                    if not isinstance(a, BatchAxis)
2355                    and isinstance(a.size, (int, ParameterizedSize))
2356                },
2357                **input_size_refs,
2358                **output_size_refs,
2359            }
2360            for a, ax in enumerate(out.axes):
2361                cls._validate_axis(
2362                    "outputs",
2363                    i,
2364                    out.id,
2365                    a,
2366                    ax,
2367                    valid_independent_refs=valid_independent_refs,
2368                )
2369
2370        return outputs
2371
2372    packaged_by: List[Author] = Field(default_factory=list)
2373    """The persons that have packaged and uploaded this model.
2374    Only required if those persons differ from the `authors`."""
2375
2376    parent: Optional[LinkedModel] = None
2377    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2378
2379    # todo: add parent self check once we have `id`
2380    # @model_validator(mode="after")
2381    # def validate_parent_is_not_self(self) -> Self:
2382    #     if self.parent is not None and self.parent == self.id:
2383    #         raise ValueError("The model may not reference itself as parent model")
2384
2385    #     return self
2386
2387    run_mode: Annotated[
2388        Optional[RunMode],
2389        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2390    ] = None
2391    """Custom run mode for this model: for more complex prediction procedures like test time
2392    data augmentation that currently cannot be expressed in the specification.
2393    No standard run modes are defined yet."""
2394
2395    timestamp: Datetime = Datetime(datetime.now())
2396    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2397    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2398    (In Python a datetime object is valid, too)."""
2399
2400    training_data: Annotated[
2401        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2402        Field(union_mode="left_to_right"),
2403    ] = None
2404    """The dataset used to train this model"""
2405
2406    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2407    """The weights for this model.
2408    Weights can be given for different formats, but should otherwise be equivalent.
2409    The available weight formats determine which consumers can use this model."""
2410
2411    @model_validator(mode="after")
2412    def _add_default_cover(self) -> Self:
2413        if not validation_context_var.get().perform_io_checks or self.covers:
2414            return self
2415
2416        try:
2417            generated_covers = generate_covers(
2418                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2419                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2420            )
2421        except Exception as e:
2422            issue_warning(
2423                "Failed to generate cover image(s): {e}",
2424                value=self.covers,
2425                msg_context=dict(e=e),
2426                field="covers",
2427            )
2428        else:
2429            self.covers.extend(generated_covers)
2430
2431        return self
2432
2433    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2434        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2435        assert all(isinstance(d, np.ndarray) for d in data)
2436        return data
2437
2438    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2439        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2440        assert all(isinstance(d, np.ndarray) for d in data)
2441        return data
2442
2443    @staticmethod
2444    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2445        batch_size = 1
2446        tensor_with_batchsize: Optional[TensorId] = None
2447        for tid in tensor_sizes:
2448            for aid, s in tensor_sizes[tid].items():
2449                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2450                    continue
2451
2452                if batch_size != 1:
2453                    assert tensor_with_batchsize is not None
2454                    raise ValueError(
2455                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2456                    )
2457
2458                batch_size = s
2459                tensor_with_batchsize = tid
2460
2461        return batch_size
2462
2463    def get_output_tensor_sizes(
2464        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2465    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2466        """Returns the tensor output sizes for given **input_sizes**.
2467        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2468        Otherwise it might be larger than the actual (valid) output"""
2469        batch_size = self.get_batch_size(input_sizes)
2470        ns = self.get_ns(input_sizes)
2471
2472        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2473        return tensor_sizes.outputs
2474
2475    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2476        """get parameter `n` for each parameterized axis
2477        such that the valid input size is >= the given input size"""
2478        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2479        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2480        for tid in input_sizes:
2481            for aid, s in input_sizes[tid].items():
2482                size_descr = axes[tid][aid].size
2483                if isinstance(size_descr, ParameterizedSize):
2484                    ret[(tid, aid)] = size_descr.get_n(s)
2485                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2486                    pass
2487                else:
2488                    assert_never(size_descr)
2489
2490        return ret
2491
2492    def get_tensor_sizes(
2493        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2494    ) -> _TensorSizes:
2495        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2496        return _TensorSizes(
2497            {
2498                t: {
2499                    aa: axis_sizes.inputs[(tt, aa)]
2500                    for tt, aa in axis_sizes.inputs
2501                    if tt == t
2502                }
2503                for t in {tt for tt, _ in axis_sizes.inputs}
2504            },
2505            {
2506                t: {
2507                    aa: axis_sizes.outputs[(tt, aa)]
2508                    for tt, aa in axis_sizes.outputs
2509                    if tt == t
2510                }
2511                for t in {tt for tt, _ in axis_sizes.outputs}
2512            },
2513        )
2514
2515    def get_axis_sizes(
2516        self,
2517        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2518        batch_size: Optional[int] = None,
2519        *,
2520        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2521    ) -> _AxisSizes:
2522        """Determine input and output block shape for scale factors **ns**
2523        of parameterized input sizes.
2524
2525        Args:
2526            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2527                that is parameterized as `size = min + n * step`.
2528            batch_size: The desired size of the batch dimension.
2529                If given **batch_size** overwrites any batch size present in
2530                **max_input_shape**. Default 1.
2531            max_input_shape: Limits the derived block shapes.
2532                Each axis for which the input size, parameterized by `n`, is larger
2533                than **max_input_shape** is set to the minimal value `n_min` for which
2534                this is still true.
2535                Use this for small input samples or large values of **ns**.
2536                Or simply whenever you know the full input shape.
2537
2538        Returns:
2539            Resolved axis sizes for model inputs and outputs.
2540        """
2541        max_input_shape = max_input_shape or {}
2542        if batch_size is None:
2543            for (_t_id, a_id), s in max_input_shape.items():
2544                if a_id == BATCH_AXIS_ID:
2545                    batch_size = s
2546                    break
2547            else:
2548                batch_size = 1
2549
2550        all_axes = {
2551            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2552        }
2553
2554        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2555        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2556
2557        def get_axis_size(a: Union[InputAxis, OutputAxis]):
2558            if isinstance(a, BatchAxis):
2559                if (t_descr.id, a.id) in ns:
2560                    logger.warning(
2561                        "Ignoring unexpected size increment factor (n) for batch axis"
2562                        + " of tensor '{}'.",
2563                        t_descr.id,
2564                    )
2565                return batch_size
2566            elif isinstance(a.size, int):
2567                if (t_descr.id, a.id) in ns:
2568                    logger.warning(
2569                        "Ignoring unexpected size increment factor (n) for fixed size"
2570                        + " axis '{}' of tensor '{}'.",
2571                        a.id,
2572                        t_descr.id,
2573                    )
2574                return a.size
2575            elif isinstance(a.size, ParameterizedSize):
2576                if (t_descr.id, a.id) not in ns:
2577                    raise ValueError(
2578                        "Size increment factor (n) missing for parametrized axis"
2579                        + f" '{a.id}' of tensor '{t_descr.id}'."
2580                    )
2581                n = ns[(t_descr.id, a.id)]
2582                s_max = max_input_shape.get((t_descr.id, a.id))
2583                if s_max is not None:
2584                    n = min(n, a.size.get_n(s_max))
2585
2586                return a.size.get_size(n)
2587
2588            elif isinstance(a.size, SizeReference):
2589                if (t_descr.id, a.id) in ns:
2590                    logger.warning(
2591                        "Ignoring unexpected size increment factor (n) for axis '{}'"
2592                        + " of tensor '{}' with size reference.",
2593                        a.id,
2594                        t_descr.id,
2595                    )
2596                assert not isinstance(a, BatchAxis)
2597                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2598                assert not isinstance(ref_axis, BatchAxis)
2599                ref_key = (a.size.tensor_id, a.size.axis_id)
2600                ref_size = inputs.get(ref_key, outputs.get(ref_key))
2601                assert ref_size is not None, ref_key
2602                assert not isinstance(ref_size, _DataDepSize), ref_key
2603                return a.size.get_size(
2604                    axis=a,
2605                    ref_axis=ref_axis,
2606                    ref_size=ref_size,
2607                )
2608            elif isinstance(a.size, DataDependentSize):
2609                if (t_descr.id, a.id) in ns:
2610                    logger.warning(
2611                        "Ignoring unexpected increment factor (n) for data dependent"
2612                        + " size axis '{}' of tensor '{}'.",
2613                        a.id,
2614                        t_descr.id,
2615                    )
2616                return _DataDepSize(a.size.min, a.size.max)
2617            else:
2618                assert_never(a.size)
2619
2620        # first resolve all , but the `SizeReference` input sizes
2621        for t_descr in self.inputs:
2622            for a in t_descr.axes:
2623                if not isinstance(a.size, SizeReference):
2624                    s = get_axis_size(a)
2625                    assert not isinstance(s, _DataDepSize)
2626                    inputs[t_descr.id, a.id] = s
2627
2628        # resolve all other input axis sizes
2629        for t_descr in self.inputs:
2630            for a in t_descr.axes:
2631                if isinstance(a.size, SizeReference):
2632                    s = get_axis_size(a)
2633                    assert not isinstance(s, _DataDepSize)
2634                    inputs[t_descr.id, a.id] = s
2635
2636        # resolve all output axis sizes
2637        for t_descr in self.outputs:
2638            for a in t_descr.axes:
2639                assert not isinstance(a.size, ParameterizedSize)
2640                s = get_axis_size(a)
2641                outputs[t_descr.id, a.id] = s
2642
2643        return _AxisSizes(inputs=inputs, outputs=outputs)
2644
2645    @model_validator(mode="before")
2646    @classmethod
2647    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
2648        if (
2649            data.get("type") == "model"
2650            and isinstance(fv := data.get("format_version"), str)
2651            and fv.count(".") == 2
2652        ):
2653            fv_parts = fv.split(".")
2654            if any(not p.isdigit() for p in fv_parts):
2655                return data
2656
2657            fv_tuple = tuple(map(int, fv_parts))
2658
2659            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
2660            if fv_tuple[:2] in ((0, 3), (0, 4)):
2661                m04 = _ModelDescr_v0_4.load(data)
2662                if not isinstance(m04, InvalidDescr):
2663                    return _model_conv.convert_as_dict(m04)
2664            elif fv_tuple[:2] == (0, 5):
2665                # bump patch version
2666                data["format_version"] = cls.implemented_format_version
2667
2668        return data
2669
2670
2671class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
2672    def _convert(
2673        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
2674    ) -> "ModelDescr | dict[str, Any]":
2675        name = "".join(
2676            c if c in string.ascii_letters + string.digits + "_- ()" else " "
2677            for c in src.name
2678        )
2679
2680        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
2681            conv = (
2682                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
2683            )
2684            return None if auths is None else [conv(a) for a in auths]
2685
2686        if TYPE_CHECKING:
2687            arch_file_conv = _arch_file_conv.convert
2688            arch_lib_conv = _arch_lib_conv.convert
2689        else:
2690            arch_file_conv = _arch_file_conv.convert_as_dict
2691            arch_lib_conv = _arch_lib_conv.convert_as_dict
2692
2693        input_size_refs = {
2694            ipt.name: {
2695                a: s
2696                for a, s in zip(
2697                    ipt.axes,
2698                    (
2699                        ipt.shape.min
2700                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
2701                        else ipt.shape
2702                    ),
2703                )
2704            }
2705            for ipt in src.inputs
2706            if ipt.shape
2707        }
2708        output_size_refs = {
2709            **{
2710                out.name: {a: s for a, s in zip(out.axes, out.shape)}
2711                for out in src.outputs
2712                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
2713            },
2714            **input_size_refs,
2715        }
2716
2717        return tgt(
2718            attachments=(
2719                []
2720                if src.attachments is None
2721                else [FileDescr(source=f) for f in src.attachments.files]
2722            ),
2723            authors=[
2724                _author_conv.convert_as_dict(a) for a in src.authors
2725            ],  # pyright: ignore[reportArgumentType]
2726            cite=[
2727                {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
2728            ],  # pyright: ignore[reportArgumentType]
2729            config=src.config,
2730            covers=src.covers,
2731            description=src.description,
2732            documentation=src.documentation,
2733            format_version="0.5.3",
2734            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
2735            icon=src.icon,
2736            id=None if src.id is None else ModelId(src.id),
2737            id_emoji=src.id_emoji,
2738            license=src.license,  # type: ignore
2739            links=src.links,
2740            maintainers=[
2741                _maintainer_conv.convert_as_dict(m) for m in src.maintainers
2742            ],  # pyright: ignore[reportArgumentType]
2743            name=name,
2744            tags=src.tags,
2745            type=src.type,
2746            uploader=src.uploader,
2747            version=src.version,
2748            inputs=[  # pyright: ignore[reportArgumentType]
2749                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
2750                for ipt, tt, st, in zip(
2751                    src.inputs,
2752                    src.test_inputs,
2753                    src.sample_inputs or [None] * len(src.test_inputs),
2754                )
2755            ],
2756            outputs=[  # pyright: ignore[reportArgumentType]
2757                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
2758                for out, tt, st, in zip(
2759                    src.outputs,
2760                    src.test_outputs,
2761                    src.sample_outputs or [None] * len(src.test_outputs),
2762                )
2763            ],
2764            parent=(
2765                None
2766                if src.parent is None
2767                else LinkedModel(
2768                    id=ModelId(
2769                        str(src.parent.id)
2770                        + (
2771                            ""
2772                            if src.parent.version_number is None
2773                            else f"/{src.parent.version_number}"
2774                        )
2775                    )
2776                )
2777            ),
2778            training_data=(
2779                None
2780                if src.training_data is None
2781                else (
2782                    LinkedDataset(
2783                        id=DatasetId(
2784                            str(src.training_data.id)
2785                            + (
2786                                ""
2787                                if src.training_data.version_number is None
2788                                else f"/{src.training_data.version_number}"
2789                            )
2790                        )
2791                    )
2792                    if isinstance(src.training_data, LinkedDataset02)
2793                    else src.training_data
2794                )
2795            ),
2796            packaged_by=[
2797                _author_conv.convert_as_dict(a) for a in src.packaged_by
2798            ],  # pyright: ignore[reportArgumentType]
2799            run_mode=src.run_mode,
2800            timestamp=src.timestamp,
2801            weights=(WeightsDescr if TYPE_CHECKING else dict)(
2802                keras_hdf5=(w := src.weights.keras_hdf5)
2803                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
2804                    authors=conv_authors(w.authors),
2805                    source=w.source,
2806                    tensorflow_version=w.tensorflow_version or Version("1.15"),
2807                    parent=w.parent,
2808                ),
2809                onnx=(w := src.weights.onnx)
2810                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
2811                    source=w.source,
2812                    authors=conv_authors(w.authors),
2813                    parent=w.parent,
2814                    opset_version=w.opset_version or 15,
2815                ),
2816                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
2817                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
2818                    source=w.source,
2819                    authors=conv_authors(w.authors),
2820                    parent=w.parent,
2821                    architecture=(
2822                        arch_file_conv(
2823                            w.architecture,
2824                            w.architecture_sha256,
2825                            w.kwargs,
2826                        )
2827                        if isinstance(w.architecture, _CallableFromFile_v0_4)
2828                        else arch_lib_conv(w.architecture, w.kwargs)
2829                    ),
2830                    pytorch_version=w.pytorch_version or Version("1.10"),
2831                    dependencies=(
2832                        None
2833                        if w.dependencies is None
2834                        else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
2835                            source=cast(
2836                                ImportantFileSource,
2837                                str(deps := w.dependencies)[
2838                                    (
2839                                        len("conda:")
2840                                        if str(deps).startswith("conda:")
2841                                        else 0
2842                                    ) :
2843                                ],
2844                            )
2845                        )
2846                    ),
2847                ),
2848                tensorflow_js=(w := src.weights.tensorflow_js)
2849                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
2850                    source=w.source,
2851                    authors=conv_authors(w.authors),
2852                    parent=w.parent,
2853                    tensorflow_version=w.tensorflow_version or Version("1.15"),
2854                ),
2855                tensorflow_saved_model_bundle=(
2856                    w := src.weights.tensorflow_saved_model_bundle
2857                )
2858                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
2859                    authors=conv_authors(w.authors),
2860                    parent=w.parent,
2861                    source=w.source,
2862                    tensorflow_version=w.tensorflow_version or Version("1.15"),
2863                    dependencies=(
2864                        None
2865                        if w.dependencies is None
2866                        else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
2867                            source=cast(
2868                                ImportantFileSource,
2869                                (
2870                                    str(w.dependencies)[len("conda:") :]
2871                                    if str(w.dependencies).startswith("conda:")
2872                                    else str(w.dependencies)
2873                                ),
2874                            )
2875                        )
2876                    ),
2877                ),
2878                torchscript=(w := src.weights.torchscript)
2879                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
2880                    source=w.source,
2881                    authors=conv_authors(w.authors),
2882                    parent=w.parent,
2883                    pytorch_version=w.pytorch_version or Version("1.10"),
2884                ),
2885            ),
2886        )
2887
2888
2889_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
2890
2891
2892# create better cover images for 3d data and non-image outputs
2893def generate_covers(
2894    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
2895    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
2896) -> List[Path]:
2897    def squeeze(
2898        data: NDArray[Any], axes: Sequence[AnyAxis]
2899    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
2900        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
2901        if data.ndim != len(axes):
2902            raise ValueError(
2903                f"tensor shape {data.shape} does not match described axes"
2904                + f" {[a.id for a in axes]}"
2905            )
2906
2907        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
2908        return data.squeeze(), axes
2909
2910    def normalize(
2911        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
2912    ) -> NDArray[np.float32]:
2913        data = data.astype("float32")
2914        data -= data.min(axis=axis, keepdims=True)
2915        data /= data.max(axis=axis, keepdims=True) + eps
2916        return data
2917
2918    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
2919        original_shape = data.shape
2920        data, axes = squeeze(data, axes)
2921
2922        # take slice fom any batch or index axis if needed
2923        # and convert the first channel axis and take a slice from any additional channel axes
2924        slices: Tuple[slice, ...] = ()
2925        ndim = data.ndim
2926        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
2927        has_c_axis = False
2928        for i, a in enumerate(axes):
2929            s = data.shape[i]
2930            assert s > 1
2931            if (
2932                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
2933                and ndim > ndim_need
2934            ):
2935                data = data[slices + (slice(s // 2 - 1, s // 2),)]
2936                ndim -= 1
2937            elif isinstance(a, ChannelAxis):
2938                if has_c_axis:
2939                    # second channel axis
2940                    data = data[slices + (slice(0, 1),)]
2941                    ndim -= 1
2942                else:
2943                    has_c_axis = True
2944                    if s == 2:
2945                        # visualize two channels with cyan and magenta
2946                        data = np.concatenate(
2947                            [
2948                                data[slices + (slice(1, 2),)],
2949                                data[slices + (slice(0, 1),)],
2950                                (
2951                                    data[slices + (slice(0, 1),)]
2952                                    + data[slices + (slice(1, 2),)]
2953                                )
2954                                / 2,  # TODO: take maximum instead?
2955                            ],
2956                            axis=i,
2957                        )
2958                    elif data.shape[i] == 3:
2959                        pass  # visualize 3 channels as RGB
2960                    else:
2961                        # visualize first 3 channels as RGB
2962                        data = data[slices + (slice(3),)]
2963
2964                    assert data.shape[i] == 3
2965
2966            slices += (slice(None),)
2967
2968        data, axes = squeeze(data, axes)
2969        assert len(axes) == ndim
2970        # take slice from z axis if needed
2971        slices = ()
2972        if ndim > ndim_need:
2973            for i, a in enumerate(axes):
2974                s = data.shape[i]
2975                if a.id == AxisId("z"):
2976                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
2977                    data, axes = squeeze(data, axes)
2978                    ndim -= 1
2979                    break
2980
2981            slices += (slice(None),)
2982
2983        # take slice from any space or time axis
2984        slices = ()
2985
2986        for i, a in enumerate(axes):
2987            if ndim <= ndim_need:
2988                break
2989
2990            s = data.shape[i]
2991            assert s > 1
2992            if isinstance(
2993                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
2994            ):
2995                data = data[slices + (slice(s // 2 - 1, s // 2),)]
2996                ndim -= 1
2997
2998            slices += (slice(None),)
2999
3000        del slices
3001        data, axes = squeeze(data, axes)
3002        assert len(axes) == ndim
3003
3004        if (has_c_axis and ndim != 3) or ndim != 2:
3005            raise ValueError(
3006                f"Failed to construct cover image from shape {original_shape}"
3007            )
3008
3009        if not has_c_axis:
3010            assert ndim == 2
3011            data = np.repeat(data[:, :, None], 3, axis=2)
3012            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3013            ndim += 1
3014
3015        assert ndim == 3
3016
3017        # transpose axis order such that longest axis comes first...
3018        axis_order = list(np.argsort(list(data.shape)))
3019        axis_order.reverse()
3020        # ... and channel axis is last
3021        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3022        axis_order.append(axis_order.pop(c))
3023        axes = [axes[ao] for ao in axis_order]
3024        data = data.transpose(axis_order)
3025
3026        # h, w = data.shape[:2]
3027        # if h / w  in (1.0 or 2.0):
3028        #     pass
3029        # elif h / w < 2:
3030        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3031
3032        norm_along = (
3033            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3034        )
3035        # normalize the data and map to 8 bit
3036        data = normalize(data, norm_along)
3037        data = (data * 255).astype("uint8")
3038
3039        return data
3040
3041    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3042        assert im0.dtype == im1.dtype == np.uint8
3043        assert im0.shape == im1.shape
3044        assert im0.ndim == 3
3045        N, M, C = im0.shape
3046        assert C == 3
3047        out = np.ones((N, M, C), dtype="uint8")
3048        for c in range(C):
3049            outc = np.tril(im0[..., c])
3050            mask = outc == 0
3051            outc[mask] = np.triu(im1[..., c])[mask]
3052            out[..., c] = outc
3053
3054        return out
3055
3056    ipt_descr, ipt = inputs[0]
3057    out_descr, out = outputs[0]
3058
3059    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3060    out_img = to_2d_image(out, out_descr.axes)
3061
3062    cover_folder = Path(mkdtemp())
3063    if ipt_img.shape == out_img.shape:
3064        covers = [cover_folder / "cover.png"]
3065        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3066    else:
3067        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3068        imwrite(covers[0], ipt_img)
3069        imwrite(covers[1], out_img)
3070
3071    return covers
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']
TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']
AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
class TensorId(bioimageio.spec._internal.types.LowerCaseIdentifier):
193class TensorId(LowerCaseIdentifier):
194    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
195        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
196    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

class AxisId(bioimageio.spec._internal.types.LowerCaseIdentifier):
199class AxisId(LowerCaseIdentifier):
200    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
201        Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
202    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'zero_mean_unit_variance']
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'scale_linear', 'sigmoid', 'zero_mean_unit_variance', 'scale_range']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>
class ParameterizedSize(bioimageio.spec._internal.node.Node):
243class ParameterizedSize(Node):
244    """Describes a range of valid tensor axis sizes as `size = min + n*step`."""
245
246    N: ClassVar[Type[int]] = ParameterizedSize_N
247    """integer to parameterize this axis"""
248
249    min: Annotated[int, Gt(0)]
250    step: Annotated[int, Gt(0)]
251
252    def validate_size(self, size: int) -> int:
253        if size < self.min:
254            raise ValueError(f"size {size} < {self.min}")
255        if (size - self.min) % self.step != 0:
256            raise ValueError(
257                f"axis of size {size} is not parameterized by `min + n*step` ="
258                + f" `{self.min} + n*{self.step}`"
259            )
260
261        return size
262
263    def get_size(self, n: ParameterizedSize_N) -> int:
264        return self.min + self.step * n
265
266    def get_n(self, s: int) -> ParameterizedSize_N:
267        """return smallest n parameterizing a size greater or equal than `s`"""
268        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

N: ClassVar[Type[int]] = <class 'int'>

integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
252    def validate_size(self, size: int) -> int:
253        if size < self.min:
254            raise ValueError(f"size {size} < {self.min}")
255        if (size - self.min) % self.step != 0:
256            raise ValueError(
257                f"axis of size {size} is not parameterized by `min + n*step` ="
258                + f" `{self.min} + n*{self.step}`"
259            )
260
261        return size
def get_size(self, n: int) -> int:
263    def get_size(self, n: ParameterizedSize_N) -> int:
264        return self.min + self.step * n
def get_n(self, s: int) -> int:
266    def get_n(self, s: int) -> ParameterizedSize_N:
267        """return smallest n parameterizing a size greater or equal than `s`"""
268        return ceil((s - self.min) / self.step)

return smallest n parameterizing a size greater or equal than s

class DataDependentSize(bioimageio.spec._internal.node.Node):
271class DataDependentSize(Node):
272    min: Annotated[int, Gt(0)] = 1
273    max: Annotated[Optional[int], Gt(1)] = None
274
275    @model_validator(mode="after")
276    def _validate_max_gt_min(self):
277        if self.max is not None and self.min >= self.max:
278            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
279
280        return self
281
282    def validate_size(self, size: int) -> int:
283        if size < self.min:
284            raise ValueError(f"size {size} < {self.min}")
285
286        if self.max is not None and size > self.max:
287            raise ValueError(f"size {size} > {self.max}")
288
289        return size

Subpart of a resource description

min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
282    def validate_size(self, size: int) -> int:
283        if size < self.min:
284            raise ValueError(f"size {size} < {self.min}")
285
286        if self.max is not None and size > self.max:
287            raise ValueError(f"size {size} > {self.max}")
288
289        return size
class SizeReference(bioimageio.spec._internal.node.Node):
292class SizeReference(Node):
293    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
294
295    `axis.size = reference.size * reference.scale / axis.scale + offset`
296
297    note:
298    1. The axis and the referenced axis need to have the same unit (or no unit).
299    2. Batch axes may not be referenced.
300    3. Fractions are rounded down.
301    4. If the reference axis is `concatenable` the referencing axis is assumed to be
302        `concatenable` as well with the same block order.
303
304    example:
305    An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
306    Let's assume that we want to express the image height h in relation to its width w
307    instead of only accepting input images of exactly 100*49 pixels
308    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
309
310    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
311    >>> h = SpaceInputAxis(
312    ...     id=AxisId("h"),
313    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
314    ...     unit="millimeter",
315    ...     scale=4,
316    ... )
317    >>> print(h.size.compute(h, w))
318    49
319
320    -> h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
321    """
322
323    tensor_id: TensorId
324    """tensor id of the reference axis"""
325
326    axis_id: AxisId
327    """axis id of the reference axis"""
328
329    offset: int = 0
330
331    def get_size(
332        self,
333        axis: Union[
334            ChannelAxis,
335            IndexInputAxis,
336            IndexOutputAxis,
337            TimeInputAxis,
338            SpaceInputAxis,
339            TimeOutputAxis,
340            TimeOutputAxisWithHalo,
341            SpaceOutputAxis,
342            SpaceOutputAxisWithHalo,
343        ],
344        ref_axis: Union[
345            ChannelAxis,
346            IndexInputAxis,
347            IndexOutputAxis,
348            TimeInputAxis,
349            SpaceInputAxis,
350            TimeOutputAxis,
351            TimeOutputAxisWithHalo,
352            SpaceOutputAxis,
353            SpaceOutputAxisWithHalo,
354        ],
355        n: ParameterizedSize_N = 0,
356        ref_size: Optional[int] = None,
357    ):
358        """Compute the concrete size for a given axis and its reference axis.
359
360        Args:
361            axis: The axis this `SizeReference` is the size of.
362            ref_axis: The reference axis to compute the size from.
363            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
364                and no fixed **ref_size** is given,
365                **n** is used to compute the size of the parameterized **ref_axis**.
366            ref_size: Overwrite the reference size instead of deriving it from
367                **ref_axis**
368                (**ref_axis.scale** is still used; any given **n** is ignored).
369        """
370        assert (
371            axis.size == self
372        ), "Given `axis.size` is not defined by this `SizeReference`"
373
374        assert (
375            ref_axis.id == self.axis_id
376        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
377
378        assert axis.unit == ref_axis.unit, (
379            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
380            f" but {axis.unit}!={ref_axis.unit}"
381        )
382        if ref_size is None:
383            if isinstance(ref_axis.size, (int, float)):
384                ref_size = ref_axis.size
385            elif isinstance(ref_axis.size, ParameterizedSize):
386                ref_size = ref_axis.size.get_size(n)
387            elif isinstance(ref_axis.size, DataDependentSize):
388                raise ValueError(
389                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
390                )
391            elif isinstance(ref_axis.size, SizeReference):
392                raise ValueError(
393                    "Reference axis referenced in `SizeReference` may not be sized by a"
394                    + " `SizeReference` itself."
395                )
396            else:
397                assert_never(ref_axis.size)
398
399        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
400
401    @staticmethod
402    def _get_unit(
403        axis: Union[
404            ChannelAxis,
405            IndexInputAxis,
406            IndexOutputAxis,
407            TimeInputAxis,
408            SpaceInputAxis,
409            TimeOutputAxis,
410            TimeOutputAxisWithHalo,
411            SpaceOutputAxis,
412            SpaceOutputAxisWithHalo,
413        ],
414    ):
415        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

example: An unisotropic input image of wh=10049 pixels depicts a phsical space of 200196mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 10049 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.compute(h, w))
49

-> h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: int
331    def get_size(
332        self,
333        axis: Union[
334            ChannelAxis,
335            IndexInputAxis,
336            IndexOutputAxis,
337            TimeInputAxis,
338            SpaceInputAxis,
339            TimeOutputAxis,
340            TimeOutputAxisWithHalo,
341            SpaceOutputAxis,
342            SpaceOutputAxisWithHalo,
343        ],
344        ref_axis: Union[
345            ChannelAxis,
346            IndexInputAxis,
347            IndexOutputAxis,
348            TimeInputAxis,
349            SpaceInputAxis,
350            TimeOutputAxis,
351            TimeOutputAxisWithHalo,
352            SpaceOutputAxis,
353            SpaceOutputAxisWithHalo,
354        ],
355        n: ParameterizedSize_N = 0,
356        ref_size: Optional[int] = None,
357    ):
358        """Compute the concrete size for a given axis and its reference axis.
359
360        Args:
361            axis: The axis this `SizeReference` is the size of.
362            ref_axis: The reference axis to compute the size from.
363            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
364                and no fixed **ref_size** is given,
365                **n** is used to compute the size of the parameterized **ref_axis**.
366            ref_size: Overwrite the reference size instead of deriving it from
367                **ref_axis**
368                (**ref_axis.scale** is still used; any given **n** is ignored).
369        """
370        assert (
371            axis.size == self
372        ), "Given `axis.size` is not defined by this `SizeReference`"
373
374        assert (
375            ref_axis.id == self.axis_id
376        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
377
378        assert axis.unit == ref_axis.unit, (
379            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
380            f" but {axis.unit}!={ref_axis.unit}"
381        )
382        if ref_size is None:
383            if isinstance(ref_axis.size, (int, float)):
384                ref_size = ref_axis.size
385            elif isinstance(ref_axis.size, ParameterizedSize):
386                ref_size = ref_axis.size.get_size(n)
387            elif isinstance(ref_axis.size, DataDependentSize):
388                raise ValueError(
389                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
390                )
391            elif isinstance(ref_axis.size, SizeReference):
392                raise ValueError(
393                    "Reference axis referenced in `SizeReference` may not be sized by a"
394                    + " `SizeReference` itself."
395                )
396            else:
397                assert_never(ref_axis.size)
398
399        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
class AxisBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields):
420class AxisBase(NodeWithExplicitlySetFields):
421    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"})
422
423    id: AxisId
424    """An axis id unique across all axes of one tensor."""
425
426    description: Annotated[str, MaxLen(128)] = ""

Subpart of a resource description

fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({'type'})

set set these fields explicitly with their default value if they are not set, such that they are always included even when dumping with 'exlude_unset'

id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]
class WithHalo(bioimageio.spec._internal.node.Node):
429class WithHalo(Node):
430    halo: Annotated[int, Ge(1)]
431    """The halo should be cropped from the output tensor to avoid boundary effects.
432    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
433    To document a halo that is already cropped by the model use `size.offset` instead."""
434
435    size: Annotated[
436        SizeReference,
437        Field(
438            examples=[
439                10,
440                SizeReference(
441                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
442                ).model_dump(mode="json"),
443            ]
444        ),
445    ]
446    """reference to another axis with an optional offset (see `SizeReference`)"""

Subpart of a resource description

halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
452class BatchAxis(AxisBase):
453    type: Literal["batch"] = "batch"
454    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
455    size: Optional[Literal[1]] = None
456    """The batch size may be fixed to 1,
457    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
458
459    @property
460    def scale(self):
461        return 1.0
462
463    @property
464    def concatenable(self):
465        return True
466
467    @property
468    def unit(self):
469        return None

Subpart of a resource description

type: Literal['batch']
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
459    @property
460    def scale(self):
461        return 1.0
concatenable
463    @property
464    def concatenable(self):
465        return True
unit
467    @property
468    def unit(self):
469        return None
class ChannelAxis(AxisBase):
472class ChannelAxis(AxisBase):
473    type: Literal["channel"] = "channel"
474    id: NonBatchAxisId = AxisId("channel")
475    channel_names: NotEmpty[List[Identifier]]
476
477    @property
478    def size(self) -> int:
479        return len(self.channel_names)
480
481    @property
482    def concatenable(self):
483        return False
484
485    @property
486    def scale(self) -> float:
487        return 1.0
488
489    @property
490    def unit(self):
491        return None

Subpart of a resource description

type: Literal['channel']
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
477    @property
478    def size(self) -> int:
479        return len(self.channel_names)
concatenable
481    @property
482    def concatenable(self):
483        return False
scale: float
485    @property
486    def scale(self) -> float:
487        return 1.0
unit
489    @property
490    def unit(self):
491        return None
class IndexAxisBase(AxisBase):
494class IndexAxisBase(AxisBase):
495    type: Literal["index"] = "index"
496    id: NonBatchAxisId = AxisId("index")
497
498    @property
499    def scale(self) -> float:
500        return 1.0
501
502    @property
503    def unit(self):
504        return None

Subpart of a resource description

type: Literal['index']
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
498    @property
499    def scale(self) -> float:
500        return 1.0
unit
502    @property
503    def unit(self):
504        return None
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
527class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
528    concatenable: bool = False
529    """If a model has a `concatenable` input axis, it can be processed blockwise,
530    splitting a longer sample axis into blocks matching its input tensor description.
531    Output axes are concatenable if they have a `SizeReference` to a concatenable
532    input axis.
533    """

Subpart of a resource description

concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

class IndexOutputAxis(IndexAxisBase):
536class IndexOutputAxis(IndexAxisBase):
537    size: Annotated[
538        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
539        Field(
540            examples=[
541                10,
542                SizeReference(
543                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
544                ).model_dump(mode="json"),
545            ]
546        ),
547    ]
548    """The size/length of this axis can be specified as
549    - fixed integer
550    - reference to another axis with an optional offset (`SizeReference`)
551    - data dependent size using `DataDependentSize` (size is only known after model inference)
552    """

Subpart of a resource description

size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
class TimeAxisBase(AxisBase):
555class TimeAxisBase(AxisBase):
556    type: Literal["time"] = "time"
557    id: NonBatchAxisId = AxisId("time")
558    unit: Optional[TimeUnit] = None
559    scale: Annotated[float, Gt(0)] = 1.0

Subpart of a resource description

type: Literal['time']
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
562class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
563    concatenable: bool = False
564    """If a model has a `concatenable` input axis, it can be processed blockwise,
565    splitting a longer sample axis into blocks matching its input tensor description.
566    Output axes are concatenable if they have a `SizeReference` to a concatenable
567    input axis.
568    """

Subpart of a resource description

concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

class SpaceAxisBase(AxisBase):
571class SpaceAxisBase(AxisBase):
572    type: Literal["space"] = "space"
573    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
574    unit: Optional[SpaceUnit] = None
575    scale: Annotated[float, Gt(0)] = 1.0

Subpart of a resource description

type: Literal['space']
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
578class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
579    concatenable: bool = False
580    """If a model has a `concatenable` input axis, it can be processed blockwise,
581    splitting a longer sample axis into blocks matching its input tensor description.
582    Output axes are concatenable if they have a `SizeReference` to a concatenable
583    input axis.
584    """

Subpart of a resource description

concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
611class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
612    pass

Subpart of a resource description

class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
615class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
616    pass

Subpart of a resource description

class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
635class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
636    pass

Subpart of a resource description

class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
639class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
640    pass

Subpart of a resource description

OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
682class NominalOrOrdinalDataDescr(Node):
683    values: TVs
684    """A fixed set of nominal or an ascending sequence of ordinal values.
685    In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'.
686    String `values` are interpreted as labels for tensor values 0, ..., N.
687    Note: as YAML 1.2 does not natively support a "set" datatype,
688    nominal values should be given as a sequence (aka list/array) as well.
689    """
690
691    type: Annotated[
692        NominalOrOrdinalDType,
693        Field(
694            examples=[
695                "float32",
696                "uint8",
697                "uint16",
698                "int64",
699                "bool",
700            ],
701        ),
702    ] = "uint8"
703
704    @model_validator(mode="after")
705    def _validate_values_match_type(
706        self,
707    ) -> Self:
708        incompatible: List[Any] = []
709        for v in self.values:
710            if self.type == "bool":
711                if not isinstance(v, bool):
712                    incompatible.append(v)
713            elif self.type in DTYPE_LIMITS:
714                if (
715                    isinstance(v, (int, float))
716                    and (
717                        v < DTYPE_LIMITS[self.type].min
718                        or v > DTYPE_LIMITS[self.type].max
719                    )
720                    or (isinstance(v, str) and "uint" not in self.type)
721                    or (isinstance(v, float) and "int" in self.type)
722                ):
723                    incompatible.append(v)
724            else:
725                incompatible.append(v)
726
727            if len(incompatible) == 5:
728                incompatible.append("...")
729                break
730
731        if incompatible:
732            raise ValueError(
733                f"data type '{self.type}' incompatible with values {incompatible}"
734            )
735
736        return self
737
738    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
739
740    @property
741    def range(self):
742        if isinstance(self.values[0], str):
743            return 0, len(self.values) - 1
744        else:
745            return min(self.values), max(self.values)

Subpart of a resource description

values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data_type is required to be an unsigend integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
740    @property
741    def range(self):
742        if isinstance(self.values[0], str):
743            return 0, len(self.values) - 1
744        else:
745            return min(self.values), max(self.values)
IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
762class IntervalOrRatioDataDescr(Node):
763    type: Annotated[  # todo: rename to dtype
764        IntervalOrRatioDType,
765        Field(
766            examples=["float32", "float64", "uint8", "uint16"],
767        ),
768    ] = "float32"
769    range: Tuple[Optional[float], Optional[float]] = (
770        None,
771        None,
772    )
773    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
774    `None` corresponds to min/max of what can be expressed by `data_type`."""
775    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
776    scale: float = 1.0
777    """Scale for data on an interval (or ratio) scale."""
778    offset: Optional[float] = None
779    """Offset for data on a ratio scale."""

Subpart of a resource description

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by data_type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
785class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
786    """processing base class"""
787
788    # id: Literal[PreprocessingId, PostprocessingId]  # make abstract field
789    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})

processing base class

fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({'id'})

set set these fields explicitly with their default value if they are not set, such that they are always included even when dumping with 'exlude_unset'

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
792class BinarizeKwargs(ProcessingKwargs):
793    """key word arguments for `BinarizeDescr`"""
794
795    threshold: float
796    """The fixed threshold"""

key word arguments for BinarizeDescr

threshold: float

The fixed threshold

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
799class BinarizeAlongAxisKwargs(ProcessingKwargs):
800    """key word arguments for `BinarizeDescr`"""
801
802    threshold: NotEmpty[List[float]]
803    """The fixed threshold values along `axis`"""
804
805    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
806    """The `threshold` axis"""

key word arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

class BinarizeDescr(ProcessingDescrBase):
809class BinarizeDescr(ProcessingDescrBase):
810    """Binarize the tensor with a fixed threshold.
811
812    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
813    will be set to one, values below the threshold to zero.
814    """
815
816    id: Literal["binarize"] = "binarize"
817    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

id: Literal['binarize']
class ClipDescr(ProcessingDescrBase):
820class ClipDescr(ProcessingDescrBase):
821    """Set tensor values below min to min and above max to max."""
822
823    id: Literal["clip"] = "clip"
824    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

id: Literal['clip']
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
827class EnsureDtypeKwargs(ProcessingKwargs):
828    """key word arguments for `EnsureDtypeDescr`"""
829
830    dtype: Literal[
831        "float32",
832        "float64",
833        "uint8",
834        "int8",
835        "uint16",
836        "int16",
837        "uint32",
838        "int32",
839        "uint64",
840        "int64",
841        "bool",
842    ]

key word arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class EnsureDtypeDescr(ProcessingDescrBase):
845class EnsureDtypeDescr(ProcessingDescrBase):
846    """cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching)"""
847
848    id: Literal["ensure_dtype"] = "ensure_dtype"
849    kwargs: EnsureDtypeKwargs

cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching)

id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
852class ScaleLinearKwargs(ProcessingKwargs):
853    """key word arguments for `ScaleLinearDescr`"""
854
855    gain: float = 1.0
856    """multiplicative factor"""
857
858    offset: float = 0.0
859    """additive term"""
860
861    @model_validator(mode="after")
862    def _validate(self) -> Self:
863        if self.gain == 1.0 and self.offset == 0.0:
864            raise ValueError(
865                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
866                + " != 0.0."
867            )
868
869        return self

key word arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
872class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
873    """key word arguments for `ScaleLinearDescr`"""
874
875    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
876    """The axis of of gains/offsets values."""
877
878    gain: Union[float, NotEmpty[List[float]]] = 1.0
879    """multiplicative factor"""
880
881    offset: Union[float, NotEmpty[List[float]]] = 0.0
882    """additive term"""
883
884    @model_validator(mode="after")
885    def _validate(self) -> Self:
886
887        if isinstance(self.gain, list):
888            if isinstance(self.offset, list):
889                if len(self.gain) != len(self.offset):
890                    raise ValueError(
891                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
892                    )
893            else:
894                self.offset = [float(self.offset)] * len(self.gain)
895        elif isinstance(self.offset, list):
896            self.gain = [float(self.gain)] * len(self.offset)
897        else:
898            raise ValueError(
899                "Do not specify an `axis` for scalar gain and offset values."
900            )
901
902        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
903            raise ValueError(
904                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
905                + " != 0.0."
906            )
907
908        return self

key word arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of of gains/offsets values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

class ScaleLinearDescr(ProcessingDescrBase):
911class ScaleLinearDescr(ProcessingDescrBase):
912    """Fixed linear scaling."""
913
914    id: Literal["scale_linear"] = "scale_linear"
915    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
918class SigmoidDescr(ProcessingDescrBase):
919    """The logistic sigmoid funciton, a.k.a. expit function."""
920
921    id: Literal["sigmoid"] = "sigmoid"
922
923    @property
924    def kwargs(self) -> ProcessingKwargs:
925        """empty kwargs"""
926        return ProcessingKwargs()

The logistic sigmoid funciton, a.k.a. expit function.

id: Literal['sigmoid']
923    @property
924    def kwargs(self) -> ProcessingKwargs:
925        """empty kwargs"""
926        return ProcessingKwargs()

empty kwargs

class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
929class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
930    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
931
932    mean: float
933    """The mean value to normalize with."""
934
935    std: Annotated[float, Ge(1e-6)]
936    """The standard deviation value to normalize with."""

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
939class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
940    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
941
942    mean: NotEmpty[List[float]]
943    """The mean value(s) to normalize with."""
944
945    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
946    """The standard deviation value(s) to normalize with.
947    Size must match `mean` values."""
948
949    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
950    """The axis of the mean/std values to normalize each entry along that dimension
951    separately."""
952
953    @model_validator(mode="after")
954    def _mean_and_std_match(self) -> Self:
955        if len(self.mean) != len(self.std):
956            raise ValueError(
957                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
958                + " must match."
959            )
960
961        return self

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
964class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
965    """Subtract a given mean and divide by the standard deviation.
966
967    Normalize with fixed, precomputed values for
968    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
969    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
970    axes.
971    """
972
973    id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
974    kwargs: Union[
975        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
976    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
979class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
980    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
981
982    axes: Annotated[
983        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
984    ] = None
985    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
986    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
987    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
988    To normalize each sample independently leave out the 'batch' axis.
989    Default: Scale all axes jointly."""
990
991    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
992    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

key word arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
995class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
996    """Subtract mean and divide by variance."""
997
998    id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
999    kwargs: ZeroMeanUnitVarianceKwargs

Subtract mean and divide by variance.

id: Literal['zero_mean_unit_variance']
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1002class ScaleRangeKwargs(ProcessingKwargs):
1003    """key word arguments for `ScaleRangeDescr`
1004
1005    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1006    this processing step normalizes data to the [0, 1] intervall.
1007    For other percentiles the normalized values will partially be outside the [0, 1]
1008    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1009    normalized values to a range.
1010    """
1011
1012    axes: Annotated[
1013        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1014    ] = None
1015    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1016    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1017    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1018    To normalize samples indepdencently, leave out the "batch" axis.
1019    Default: Scale all axes jointly."""
1020
1021    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1022    """The lower percentile used to determine the value to align with zero."""
1023
1024    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1025    """The upper percentile used to determine the value to align with one.
1026    Has to be bigger than `min_percentile`.
1027    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1028    accepting percentiles specified in the range 0.0 to 1.0."""
1029
1030    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1031    """Epsilon for numeric stability.
1032    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1033    with `v_lower,v_upper` values at the respective percentiles."""
1034
1035    reference_tensor: Optional[TensorId] = None
1036    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1037    For any tensor in `inputs` only input tensor references are allowed."""
1038
1039    @field_validator("max_percentile", mode="after")
1040    @classmethod
1041    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1042        if (min_p := info.data["min_percentile"]) >= value:
1043            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1044
1045        return value

key word arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] intervall. For other percentiles the normalized values will partially be outside the [0, 1] intervall. Use ScaleRange followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples indepdencently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1039    @field_validator("max_percentile", mode="after")
1040    @classmethod
1041    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1042        if (min_p := info.data["min_percentile"]) >= value:
1043            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1044
1045        return value
class ScaleRangeDescr(ProcessingDescrBase):
1048class ScaleRangeDescr(ProcessingDescrBase):
1049    """Scale with percentiles."""
1050
1051    id: Literal["scale_range"] = "scale_range"
1052    kwargs: ScaleRangeKwargs

Scale with percentiles.

id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1055class ScaleMeanVarianceKwargs(ProcessingKwargs):
1056    """key word arguments for `ScaleMeanVarianceKwargs`"""
1057
1058    reference_tensor: TensorId
1059    """Name of tensor to match."""
1060
1061    axes: Annotated[
1062        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1063    ] = None
1064    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1065    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1066    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1067    To normalize samples independently, leave out the 'batch' axis.
1068    Default: Scale all axes jointly."""
1069
1070    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1071    """Epsilon for numeric stability:
1072    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

key word arguments for ScaleMeanVarianceKwargs

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1075class ScaleMeanVarianceDescr(ProcessingDescrBase):
1076    """Scale a tensor's data distribution to match another tensor's mean/std.
1077    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1078    """
1079
1080    id: Literal["scale_mean_variance"] = "scale_mean_variance"
1081    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

id: Literal['scale_mean_variance']
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1115class TensorDescrBase(Node, Generic[IO_AxisT]):
1116    id: TensorId
1117    """Tensor id. No duplicates are allowed."""
1118
1119    description: Annotated[str, MaxLen(128)] = ""
1120    """free text description"""
1121
1122    axes: NotEmpty[Sequence[IO_AxisT]]
1123    """tensor axes"""
1124
1125    @property
1126    def shape(self):
1127        return tuple(a.size for a in self.axes)
1128
1129    @field_validator("axes", mode="after", check_fields=False)
1130    @classmethod
1131    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1132        batch_axes = [a for a in axes if a.type == "batch"]
1133        if len(batch_axes) > 1:
1134            raise ValueError(
1135                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1136            )
1137
1138        seen_ids: Set[AxisId] = set()
1139        duplicate_axes_ids: Set[AxisId] = set()
1140        for a in axes:
1141            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1142
1143        if duplicate_axes_ids:
1144            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1145
1146        return axes
1147
1148    test_tensor: FileDescr
1149    """An example tensor to use for testing.
1150    Using the model with the test input tensors is expected to yield the test output tensors.
1151    Each test tensor has be a an ndarray in the
1152    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1153    The file extension must be '.npy'."""
1154
1155    sample_tensor: Optional[FileDescr] = None
1156    """A sample tensor to illustrate a possible input/output for the model,
1157    The sample image primarily serves to inform a human user about an example use case
1158    and is typically stored as .hdf5, .png or .tiff.
1159    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1160    (numpy's `.npy` format is not supported).
1161    The image dimensionality has to match the number of axes specified in this tensor description.
1162    """
1163
1164    @model_validator(mode="after")
1165    def _validate_sample_tensor(self) -> Self:
1166        if (
1167            self.sample_tensor is None
1168            or not validation_context_var.get().perform_io_checks
1169        ):
1170            return self
1171
1172        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1173        tensor: NDArray[Any] = imread(
1174            local.path.read_bytes(),
1175            extension=PurePosixPath(local.original_file_name).suffix,
1176        )
1177        n_dims = len(tensor.squeeze().shape)
1178        n_dims_min = n_dims_max = len(self.axes)
1179
1180        for a in self.axes:
1181            if isinstance(a, BatchAxis):
1182                n_dims_min -= 1
1183            elif isinstance(a.size, int):
1184                if a.size == 1:
1185                    n_dims_min -= 1
1186            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1187                if a.size.min == 1:
1188                    n_dims_min -= 1
1189            elif isinstance(a.size, SizeReference):
1190                if a.size.offset < 2:
1191                    # size reference may result in singleton axis
1192                    n_dims_min -= 1
1193            else:
1194                assert_never(a.size)
1195
1196        n_dims_min = max(0, n_dims_min)
1197        if n_dims < n_dims_min or n_dims > n_dims_max:
1198            raise ValueError(
1199                f"Expected sample tensor to have {n_dims_min} to"
1200                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1201            )
1202
1203        return self
1204
1205    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1206        IntervalOrRatioDataDescr()
1207    )
1208    """Description of the tensor's data values, optionally per channel.
1209    If specified per channel, the data `type` needs to match across channels."""
1210
1211    @property
1212    def dtype(
1213        self,
1214    ) -> Literal[
1215        "float32",
1216        "float64",
1217        "uint8",
1218        "int8",
1219        "uint16",
1220        "int16",
1221        "uint32",
1222        "int32",
1223        "uint64",
1224        "int64",
1225        "bool",
1226    ]:
1227        """dtype as specified under `data.type` or `data[i].type`"""
1228        if isinstance(self.data, collections.abc.Sequence):
1229            return self.data[0].type
1230        else:
1231            return self.data.type
1232
1233    @field_validator("data", mode="after")
1234    @classmethod
1235    def _check_data_type_across_channels(
1236        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1237    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1238        if not isinstance(value, list):
1239            return value
1240
1241        dtypes = {t.type for t in value}
1242        if len(dtypes) > 1:
1243            raise ValueError(
1244                "Tensor data descriptions per channel need to agree in their data"
1245                + f" `type`, but found {dtypes}."
1246            )
1247
1248        return value
1249
1250    @model_validator(mode="after")
1251    def _check_data_matches_channelaxis(self) -> Self:
1252        if not isinstance(self.data, (list, tuple)):
1253            return self
1254
1255        for a in self.axes:
1256            if isinstance(a, ChannelAxis):
1257                size = a.size
1258                assert isinstance(size, int)
1259                break
1260        else:
1261            return self
1262
1263        if len(self.data) != size:
1264            raise ValueError(
1265                f"Got tensor data descriptions for {len(self.data)} channels, but"
1266                + f" '{a.id}' axis has size {size}."
1267            )
1268
1269        return self
1270
1271    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1272        if len(array.shape) != len(self.axes):
1273            raise ValueError(
1274                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1275                + f" incompatible with {len(self.axes)} axes."
1276            )
1277        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}

Subpart of a resource description

id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1125    @property
1126    def shape(self):
1127        return tuple(a.size for a in self.axes)
test_tensor: bioimageio.spec._internal.io.FileDescr

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has be a an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Optional[bioimageio.spec._internal.io.FileDescr]

A sample tensor to illustrate a possible input/output for the model, The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1211    @property
1212    def dtype(
1213        self,
1214    ) -> Literal[
1215        "float32",
1216        "float64",
1217        "uint8",
1218        "int8",
1219        "uint16",
1220        "int16",
1221        "uint32",
1222        "int32",
1223        "uint64",
1224        "int64",
1225        "bool",
1226    ]:
1227        """dtype as specified under `data.type` or `data[i].type`"""
1228        if isinstance(self.data, collections.abc.Sequence):
1229            return self.data[0].type
1230        else:
1231            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1271    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1272        if len(array.shape) != len(self.axes):
1273            raise ValueError(
1274                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1275                + f" incompatible with {len(self.axes)} axes."
1276            )
1277        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1280class InputTensorDescr(TensorDescrBase[InputAxis]):
1281    id: TensorId = TensorId("input")
1282    """Input tensor id.
1283    No duplicates are allowed across all inputs and outputs."""
1284
1285    optional: bool = False
1286    """indicates that this tensor may be `None`"""
1287
1288    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1289    """Description of how this input should be preprocessed.
1290
1291    notes:
1292    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1293      to ensure an input tensor's data type matches the input tensor's data description.
1294    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1295      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1296      changing the data type.
1297    """
1298
1299    @model_validator(mode="after")
1300    def _validate_preprocessing_kwargs(self) -> Self:
1301        axes_ids = [a.id for a in self.axes]
1302        for p in self.preprocessing:
1303            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1304            if kwargs_axes is None:
1305                continue
1306
1307            if not isinstance(kwargs_axes, collections.abc.Sequence):
1308                raise ValueError(
1309                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1310                )
1311
1312            if any(a not in axes_ids for a in kwargs_axes):
1313                raise ValueError(
1314                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1315                )
1316
1317        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1318            dtype = self.data.type
1319        else:
1320            dtype = self.data[0].type
1321
1322        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1323        if not self.preprocessing or not isinstance(
1324            self.preprocessing[0], EnsureDtypeDescr
1325        ):
1326            self.preprocessing.insert(
1327                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1328            )
1329
1330        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1331        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1332            self.preprocessing.append(
1333                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1334            )
1335
1336        return self

Subpart of a resource description

id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1339def convert_axes(
1340    axes: str,
1341    *,
1342    shape: Union[
1343        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1344    ],
1345    tensor_type: Literal["input", "output"],
1346    halo: Optional[Sequence[int]],
1347    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1348):
1349    ret: List[AnyAxis] = []
1350    for i, a in enumerate(axes):
1351        axis_type = _AXIS_TYPE_MAP.get(a, a)
1352        if axis_type == "batch":
1353            ret.append(BatchAxis())
1354            continue
1355
1356        scale = 1.0
1357        if isinstance(shape, _ParameterizedInputShape_v0_4):
1358            if shape.step[i] == 0:
1359                size = shape.min[i]
1360            else:
1361                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1362        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1363            ref_t = str(shape.reference_tensor)
1364            if ref_t.count(".") == 1:
1365                t_id, orig_a_id = ref_t.split(".")
1366            else:
1367                t_id = ref_t
1368                orig_a_id = a
1369
1370            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1371            if not (orig_scale := shape.scale[i]):
1372                # old way to insert a new axis dimension
1373                size = int(2 * shape.offset[i])
1374            else:
1375                scale = 1 / orig_scale
1376                if axis_type in ("channel", "index"):
1377                    # these axes no longer have a scale
1378                    offset_from_scale = orig_scale * size_refs.get(
1379                        _TensorName_v0_4(t_id), {}
1380                    ).get(orig_a_id, 0)
1381                else:
1382                    offset_from_scale = 0
1383                size = SizeReference(
1384                    tensor_id=TensorId(t_id),
1385                    axis_id=AxisId(a_id),
1386                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1387                )
1388        else:
1389            size = shape[i]
1390
1391        if axis_type == "time":
1392            if tensor_type == "input":
1393                ret.append(TimeInputAxis(size=size, scale=scale))
1394            else:
1395                assert not isinstance(size, ParameterizedSize)
1396                if halo is None:
1397                    ret.append(TimeOutputAxis(size=size, scale=scale))
1398                else:
1399                    assert not isinstance(size, int)
1400                    ret.append(
1401                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1402                    )
1403
1404        elif axis_type == "index":
1405            if tensor_type == "input":
1406                ret.append(IndexInputAxis(size=size))
1407            else:
1408                if isinstance(size, ParameterizedSize):
1409                    size = DataDependentSize(min=size.min)
1410
1411                ret.append(IndexOutputAxis(size=size))
1412        elif axis_type == "channel":
1413            assert not isinstance(size, ParameterizedSize)
1414            if isinstance(size, SizeReference):
1415                warnings.warn(
1416                    "Conversion of channel size from an implicit output shape may be"
1417                    + " wrong"
1418                )
1419                ret.append(
1420                    ChannelAxis(
1421                        channel_names=[
1422                            Identifier(f"channel{i}") for i in range(size.offset)
1423                        ]
1424                    )
1425                )
1426            else:
1427                ret.append(
1428                    ChannelAxis(
1429                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1430                    )
1431                )
1432        elif axis_type == "space":
1433            if tensor_type == "input":
1434                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1435            else:
1436                assert not isinstance(size, ParameterizedSize)
1437                if halo is None or halo[i] == 0:
1438                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1439                elif isinstance(size, int):
1440                    raise NotImplementedError(
1441                        f"output axis with halo and fixed size (here {size}) not allowed"
1442                    )
1443                else:
1444                    ret.append(
1445                        SpaceOutputAxisWithHalo(
1446                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1447                        )
1448                    )
1449
1450    return ret
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1625class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1626    id: TensorId = TensorId("output")
1627    """Output tensor id.
1628    No duplicates are allowed across all inputs and outputs."""
1629
1630    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1631    """Description of how this output should be postprocessed.
1632
1633    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1634          If not given this is added to cast to this tensor's `data.type`.
1635    """
1636
1637    @model_validator(mode="after")
1638    def _validate_postprocessing_kwargs(self) -> Self:
1639        axes_ids = [a.id for a in self.axes]
1640        for p in self.postprocessing:
1641            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1642            if kwargs_axes is None:
1643                continue
1644
1645            if not isinstance(kwargs_axes, collections.abc.Sequence):
1646                raise ValueError(
1647                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1648                )
1649
1650            if any(a not in axes_ids for a in kwargs_axes):
1651                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1652
1653        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1654            dtype = self.data.type
1655        else:
1656            dtype = self.data[0].type
1657
1658        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1659        if not self.postprocessing or not isinstance(
1660            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1661        ):
1662            self.postprocessing.append(
1663                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1664            )
1665        return self

Subpart of a resource description

id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], numpy.ndarray[Any, numpy.dtype[Any]]]], tensor_origin: str):
1715def validate_tensors(
1716    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
1717    tensor_origin: str,  # for more precise error messages, e.g. 'test_tensor'
1718):
1719    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
1720
1721    def e_msg(d: TensorDescr):
1722        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
1723
1724    for descr, array in tensors.values():
1725        try:
1726            axis_sizes = descr.get_axis_sizes_for_array(array)
1727        except ValueError as e:
1728            raise ValueError(f"{e_msg(descr)} {e}")
1729        else:
1730            all_tensor_axes[descr.id] = {
1731                a.id: (a, axis_sizes[a.id]) for a in descr.axes
1732            }
1733
1734    for descr, array in tensors.values():
1735        if array.dtype.name != descr.dtype:
1736            raise ValueError(
1737                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
1738                + f" match described dtype '{descr.dtype}'"
1739            )
1740
1741        for a in descr.axes:
1742            actual_size = all_tensor_axes[descr.id][a.id][1]
1743            if a.size is None:
1744                continue
1745
1746            if isinstance(a.size, int):
1747                if actual_size != a.size:
1748                    raise ValueError(
1749                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
1750                        + f"has incompatible size {actual_size}, expected {a.size}"
1751                    )
1752            elif isinstance(a.size, ParameterizedSize):
1753                _ = a.size.validate_size(actual_size)
1754            elif isinstance(a.size, DataDependentSize):
1755                _ = a.size.validate_size(actual_size)
1756            elif isinstance(a.size, SizeReference):
1757                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
1758                if ref_tensor_axes is None:
1759                    raise ValueError(
1760                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
1761                        + f" reference '{a.size.tensor_id}'"
1762                    )
1763
1764                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
1765                if ref_axis is None or ref_size is None:
1766                    raise ValueError(
1767                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
1768                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
1769                    )
1770
1771                if a.unit != ref_axis.unit:
1772                    raise ValueError(
1773                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
1774                        + " axis and reference axis to have the same `unit`, but"
1775                        + f" {a.unit}!={ref_axis.unit}"
1776                    )
1777
1778                if actual_size != (
1779                    expected_size := (
1780                        ref_size * ref_axis.scale / a.scale + a.size.offset
1781                    )
1782                ):
1783                    raise ValueError(
1784                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
1785                        + f" {actual_size} invalid for referenced size {ref_size};"
1786                        + f" expected {expected_size}"
1787                    )
1788            else:
1789                assert_never(a.size)
class EnvironmentFileDescr(bioimageio.spec._internal.io.FileDescr):
1792class EnvironmentFileDescr(FileDescr):
1793    source: Annotated[
1794        ImportantFileSource,
1795        WithSuffix((".yaml", ".yml"), case_sensitive=True),
1796        Field(
1797            examples=["environment.yaml"],
1798        ),
1799    ]
1800    """∈📦 Conda environment file.
1801    Allows to specify custom dependencies, see conda docs:
1802    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
1803    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
1804    """

Subpart of a resource description

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f9a7f3b8ea0>), PlainSerializer(func=<function _package at 0x7f9a7f3b9620>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['environment.yaml'])]

∈📦 Conda environment file. Allows to specify custom dependencies, see conda docs:

class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
1815class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
1816    pass

Subpart of a resource description

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
1819class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
1820    import_from: str
1821    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""

Subpart of a resource description

import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>

ArchitectureDescr = typing.Annotated[typing.Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]
class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
1887class WeightsEntryDescrBase(FileDescr):
1888    type: ClassVar[WeightsFormat]
1889    weights_format_name: ClassVar[str]  # human readable
1890
1891    source: ImportantFileSource
1892    """∈📦 The weights file."""
1893
1894    authors: Optional[List[Author]] = None
1895    """Authors
1896    Either the person(s) that have trained this model resulting in the original weights file.
1897        (If this is the initial weights entry, i.e. it does not have a `parent`)
1898    Or the person(s) who have converted the weights to this weights format.
1899        (If this is a child weight, i.e. it has a `parent` field)
1900    """
1901
1902    parent: Annotated[
1903        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
1904    ] = None
1905    """The source weights these weights were converted from.
1906    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
1907    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
1908    All weight entries except one (the initial set of weights resulting from training the model),
1909    need to have this field."""
1910
1911    @model_validator(mode="after")
1912    def check_parent_is_not_self(self) -> Self:
1913        if self.type == self.parent:
1914            raise ValueError("Weights entry can't be it's own parent.")
1915
1916        return self

Subpart of a resource description

type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f9a7f3b8ea0>), PlainSerializer(func=<function _package at 0x7f9a7f3b9620>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors Either the person(s) that have trained this model resulting in the original weights file. (If this is the initial weights entry, i.e. it does not have a parent) Or the person(s) who have converted the weights to this weights format. (If this is a child weight, i.e. it has a parent field)

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, The pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model), need to have this field.

@model_validator(mode='after')
def check_parent_is_not_self(self) -> Self:
1911    @model_validator(mode="after")
1912    def check_parent_is_not_self(self) -> Self:
1913        if self.type == self.parent:
1914            raise ValueError("Weights entry can't be it's own parent.")
1915
1916        return self
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
1919class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
1920    type = "keras_hdf5"
1921    weights_format_name: ClassVar[str] = "Keras HDF5"
1922    tensorflow_version: Version
1923    """TensorFlow version used to create these weights."""

Subpart of a resource description

type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'
tensorflow_version: bioimageio.spec._internal.version_type.Version

TensorFlow version used to create these weights.

class OnnxWeightsDescr(WeightsEntryDescrBase):
1926class OnnxWeightsDescr(WeightsEntryDescrBase):
1927    type = "onnx"
1928    weights_format_name: ClassVar[str] = "ONNX"
1929    opset_version: Annotated[int, Ge(7)]
1930    """ONNX opset version"""

Subpart of a resource description

type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
1933class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
1934    type = "pytorch_state_dict"
1935    weights_format_name: ClassVar[str] = "Pytorch State Dict"
1936    architecture: ArchitectureDescr
1937    pytorch_version: Version
1938    """Version of the PyTorch library used.
1939    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
1940    """
1941    dependencies: Optional[EnvironmentFileDescr] = None
1942    """Custom depencies beyond pytorch.
1943    The conda environment file should include pytorch and any version pinning has to be compatible with
1944    `pytorch_version`.
1945    """

Subpart of a resource description

type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'
architecture: Annotated[Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]
pytorch_version: bioimageio.spec._internal.version_type.Version

Version of the PyTorch library used. If architecture.depencencies is specified it has to include pytorch and any version pinning has to be compatible.

dependencies: Optional[EnvironmentFileDescr]

Custom depencies beyond pytorch. The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
1948class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
1949    type = "tensorflow_js"
1950    weights_format_name: ClassVar[str] = "Tensorflow.js"
1951    tensorflow_version: Version
1952    """Version of the TensorFlow library used."""
1953
1954    source: ImportantFileSource
1955    """∈📦 The multi-file weights.
1956    All required files/folders should be a zip archive."""

Subpart of a resource description

type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'
tensorflow_version: bioimageio.spec._internal.version_type.Version

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f9a7f3b8ea0>), PlainSerializer(func=<function _package at 0x7f9a7f3b9620>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The multi-file weights. All required files/folders should be a zip archive.

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
1959class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
1960    type = "tensorflow_saved_model_bundle"
1961    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
1962    tensorflow_version: Version
1963    """Version of the TensorFlow library used."""
1964
1965    dependencies: Optional[EnvironmentFileDescr] = None
1966    """Custom dependencies beyond tensorflow.
1967    Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""
1968
1969    source: ImportantFileSource
1970    """∈📦 The multi-file weights.
1971    All required files/folders should be a zip archive."""

Subpart of a resource description

type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'
tensorflow_version: bioimageio.spec._internal.version_type.Version

Version of the TensorFlow library used.

dependencies: Optional[EnvironmentFileDescr]

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f9a7f3b8ea0>), PlainSerializer(func=<function _package at 0x7f9a7f3b9620>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The multi-file weights. All required files/folders should be a zip archive.

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
1974class TorchscriptWeightsDescr(WeightsEntryDescrBase):
1975    type = "torchscript"
1976    weights_format_name: ClassVar[str] = "TorchScript"
1977    pytorch_version: Version
1978    """Version of the PyTorch library used."""

Subpart of a resource description

type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'
pytorch_version: bioimageio.spec._internal.version_type.Version

Version of the PyTorch library used.

class WeightsDescr(bioimageio.spec._internal.node.Node):
1981class WeightsDescr(Node):
1982    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
1983    onnx: Optional[OnnxWeightsDescr] = None
1984    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
1985    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
1986    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
1987        None
1988    )
1989    torchscript: Optional[TorchscriptWeightsDescr] = None
1990
1991    @model_validator(mode="after")
1992    def check_entries(self) -> Self:
1993        entries = {wtype for wtype, entry in self if entry is not None}
1994
1995        if not entries:
1996            raise ValueError("Missing weights entry")
1997
1998        entries_wo_parent = {
1999            wtype
2000            for wtype, entry in self
2001            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2002        }
2003        if len(entries_wo_parent) != 1:
2004            issue_warning(
2005                "Exactly one weights entry may not specify the `parent` field (got"
2006                + " {value}). That entry is considered the original set of model weights."
2007                + " Other weight formats are created through conversion of the orignal or"
2008                + " already converted weights. They have to reference the weights format"
2009                + " they were converted from as their `parent`.",
2010                value=len(entries_wo_parent),
2011                field="weights",
2012            )
2013
2014        for wtype, entry in self:
2015            if entry is None:
2016                continue
2017
2018            assert hasattr(entry, "type")
2019            assert hasattr(entry, "parent")
2020            assert wtype == entry.type
2021            if (
2022                entry.parent is not None and entry.parent not in entries
2023            ):  # self reference checked for `parent` field
2024                raise ValueError(
2025                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2026                    + f" formats: {entries}"
2027                )
2028
2029        return self

Subpart of a resource description

keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
1991    @model_validator(mode="after")
1992    def check_entries(self) -> Self:
1993        entries = {wtype for wtype, entry in self if entry is not None}
1994
1995        if not entries:
1996            raise ValueError("Missing weights entry")
1997
1998        entries_wo_parent = {
1999            wtype
2000            for wtype, entry in self
2001            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2002        }
2003        if len(entries_wo_parent) != 1:
2004            issue_warning(
2005                "Exactly one weights entry may not specify the `parent` field (got"
2006                + " {value}). That entry is considered the original set of model weights."
2007                + " Other weight formats are created through conversion of the orignal or"
2008                + " already converted weights. They have to reference the weights format"
2009                + " they were converted from as their `parent`.",
2010                value=len(entries_wo_parent),
2011                field="weights",
2012            )
2013
2014        for wtype, entry in self:
2015            if entry is None:
2016                continue
2017
2018            assert hasattr(entry, "type")
2019            assert hasattr(entry, "parent")
2020            assert wtype == entry.type
2021            if (
2022                entry.parent is not None and entry.parent not in entries
2023            ):  # self reference checked for `parent` field
2024                raise ValueError(
2025                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2026                    + f" formats: {entries}"
2027                )
2028
2029        return self
class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2032class ModelId(ResourceId):
2033    pass

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceNode):
2036class LinkedModel(LinkedResourceNode):
2037    """Reference to a bioimage.io model."""
2038
2039    id: ModelId
2040    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

2062class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"):
2063    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2064    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2065    """
2066
2067    format_version: Literal["0.5.3"] = "0.5.3"
2068    """Version of the bioimage.io model description specification used.
2069    When creating a new model always use the latest micro/patch version described here.
2070    The `format_version` is important for any consumer software to understand how to parse the fields.
2071    """
2072
2073    type: Literal["model"] = "model"
2074    """Specialized resource type 'model'"""
2075
2076    id: Optional[ModelId] = None
2077    """bioimage.io-wide unique resource identifier
2078    assigned by bioimage.io; version **un**specific."""
2079
2080    authors: NotEmpty[List[Author]]
2081    """The authors are the creators of the model RDF and the primary points of contact."""
2082
2083    documentation: Annotated[
2084        DocumentationSource,
2085        Field(
2086            examples=[
2087                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2088                "README.md",
2089            ],
2090        ),
2091    ]
2092    """∈📦 URL or relative path to a markdown file with additional documentation.
2093    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2094    The documentation should include a '#[#] Validation' (sub)section
2095    with details on how to quantitatively validate the model on unseen data."""
2096
2097    @field_validator("documentation", mode="after")
2098    @classmethod
2099    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2100        if not validation_context_var.get().perform_io_checks:
2101            return value
2102
2103        doc_path = download(value).path
2104        doc_content = doc_path.read_text(encoding="utf-8")
2105        assert isinstance(doc_content, str)
2106        if not re.match("#.*[vV]alidation", doc_content):
2107            issue_warning(
2108                "No '# Validation' (sub)section found in {value}.",
2109                value=value,
2110                field="documentation",
2111            )
2112
2113        return value
2114
2115    inputs: NotEmpty[Sequence[InputTensorDescr]]
2116    """Describes the input tensors expected by this model."""
2117
2118    @field_validator("inputs", mode="after")
2119    @classmethod
2120    def _validate_input_axes(
2121        cls, inputs: Sequence[InputTensorDescr]
2122    ) -> Sequence[InputTensorDescr]:
2123        input_size_refs = cls._get_axes_with_independent_size(inputs)
2124
2125        for i, ipt in enumerate(inputs):
2126            valid_independent_refs: Dict[
2127                Tuple[TensorId, AxisId],
2128                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2129            ] = {
2130                **{
2131                    (ipt.id, a.id): (ipt, a, a.size)
2132                    for a in ipt.axes
2133                    if not isinstance(a, BatchAxis)
2134                    and isinstance(a.size, (int, ParameterizedSize))
2135                },
2136                **input_size_refs,
2137            }
2138            for a, ax in enumerate(ipt.axes):
2139                cls._validate_axis(
2140                    "inputs",
2141                    i=i,
2142                    tensor_id=ipt.id,
2143                    a=a,
2144                    axis=ax,
2145                    valid_independent_refs=valid_independent_refs,
2146                )
2147        return inputs
2148
2149    @staticmethod
2150    def _validate_axis(
2151        field_name: str,
2152        i: int,
2153        tensor_id: TensorId,
2154        a: int,
2155        axis: AnyAxis,
2156        valid_independent_refs: Dict[
2157            Tuple[TensorId, AxisId],
2158            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2159        ],
2160    ):
2161        if isinstance(axis, BatchAxis) or isinstance(
2162            axis.size, (int, ParameterizedSize, DataDependentSize)
2163        ):
2164            return
2165        elif not isinstance(axis.size, SizeReference):
2166            assert_never(axis.size)
2167
2168        # validate axis.size SizeReference
2169        ref = (axis.size.tensor_id, axis.size.axis_id)
2170        if ref not in valid_independent_refs:
2171            raise ValueError(
2172                "Invalid tensor axis reference at"
2173                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2174            )
2175        if ref == (tensor_id, axis.id):
2176            raise ValueError(
2177                "Self-referencing not allowed for"
2178                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2179            )
2180        if axis.type == "channel":
2181            if valid_independent_refs[ref][1].type != "channel":
2182                raise ValueError(
2183                    "A channel axis' size may only reference another fixed size"
2184                    + " channel axis."
2185                )
2186            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2187                ref_size = valid_independent_refs[ref][2]
2188                assert isinstance(ref_size, int), (
2189                    "channel axis ref (another channel axis) has to specify fixed"
2190                    + " size"
2191                )
2192                generated_channel_names = [
2193                    Identifier(axis.channel_names.format(i=i))
2194                    for i in range(1, ref_size + 1)
2195                ]
2196                axis.channel_names = generated_channel_names
2197
2198        if (ax_unit := getattr(axis, "unit", None)) != (
2199            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2200        ):
2201            raise ValueError(
2202                "The units of an axis and its reference axis need to match, but"
2203                + f" '{ax_unit}' != '{ref_unit}'."
2204            )
2205        ref_axis = valid_independent_refs[ref][1]
2206        if isinstance(ref_axis, BatchAxis):
2207            raise ValueError(
2208                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2209                + " (a batch axis is not allowed as reference)."
2210            )
2211
2212        if isinstance(axis, WithHalo):
2213            min_size = axis.size.get_size(axis, ref_axis, n=0)
2214            if (min_size - 2 * axis.halo) < 1:
2215                raise ValueError(
2216                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2217                    + f" {axis.halo}."
2218                )
2219
2220            input_halo = axis.halo * axis.scale / ref_axis.scale
2221            if input_halo != int(input_halo) or input_halo % 2 == 1:
2222                raise ValueError(
2223                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2224                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2225                    + f" is not an even integer for {tensor_id}.{axis.id}."
2226                )
2227
2228    @model_validator(mode="after")
2229    def _validate_test_tensors(self) -> Self:
2230        if not validation_context_var.get().perform_io_checks:
2231            return self
2232
2233        test_arrays = [
2234            load_array(descr.test_tensor.download().path)
2235            for descr in chain(self.inputs, self.outputs)
2236        ]
2237        tensors = {
2238            descr.id: (descr, array)
2239            for descr, array in zip(chain(self.inputs, self.outputs), test_arrays)
2240        }
2241        validate_tensors(tensors, tensor_origin="test_tensor")
2242        return self
2243
2244    @model_validator(mode="after")
2245    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2246        ipt_refs = {t.id for t in self.inputs}
2247        out_refs = {t.id for t in self.outputs}
2248        for ipt in self.inputs:
2249            for p in ipt.preprocessing:
2250                ref = p.kwargs.get("reference_tensor")
2251                if ref is None:
2252                    continue
2253                if ref not in ipt_refs:
2254                    raise ValueError(
2255                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2256                        + f" references are: {ipt_refs}."
2257                    )
2258
2259        for out in self.outputs:
2260            for p in out.postprocessing:
2261                ref = p.kwargs.get("reference_tensor")
2262                if ref is None:
2263                    continue
2264
2265                if ref not in ipt_refs and ref not in out_refs:
2266                    raise ValueError(
2267                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2268                        + f" are: {ipt_refs | out_refs}."
2269                    )
2270
2271        return self
2272
2273    # TODO: use validate funcs in validate_test_tensors
2274    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2275
2276    name: Annotated[
2277        Annotated[
2278            str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()")
2279        ],
2280        MinLen(5),
2281        MaxLen(128),
2282        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2283    ]
2284    """A human-readable name of this model.
2285    It should be no longer than 64 characters
2286    and may only contain letter, number, underscore, minus, parentheses and spaces.
2287    We recommend to chose a name that refers to the model's task and image modality.
2288    """
2289
2290    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2291    """Describes the output tensors."""
2292
2293    @field_validator("outputs", mode="after")
2294    @classmethod
2295    def _validate_tensor_ids(
2296        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2297    ) -> Sequence[OutputTensorDescr]:
2298        tensor_ids = [
2299            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2300        ]
2301        duplicate_tensor_ids: List[str] = []
2302        seen: Set[str] = set()
2303        for t in tensor_ids:
2304            if t in seen:
2305                duplicate_tensor_ids.append(t)
2306
2307            seen.add(t)
2308
2309        if duplicate_tensor_ids:
2310            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2311
2312        return outputs
2313
2314    @staticmethod
2315    def _get_axes_with_parameterized_size(
2316        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2317    ):
2318        return {
2319            f"{t.id}.{a.id}": (t, a, a.size)
2320            for t in io
2321            for a in t.axes
2322            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2323        }
2324
2325    @staticmethod
2326    def _get_axes_with_independent_size(
2327        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2328    ):
2329        return {
2330            (t.id, a.id): (t, a, a.size)
2331            for t in io
2332            for a in t.axes
2333            if not isinstance(a, BatchAxis)
2334            and isinstance(a.size, (int, ParameterizedSize))
2335        }
2336
2337    @field_validator("outputs", mode="after")
2338    @classmethod
2339    def _validate_output_axes(
2340        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2341    ) -> List[OutputTensorDescr]:
2342        input_size_refs = cls._get_axes_with_independent_size(
2343            info.data.get("inputs", [])
2344        )
2345        output_size_refs = cls._get_axes_with_independent_size(outputs)
2346
2347        for i, out in enumerate(outputs):
2348            valid_independent_refs: Dict[
2349                Tuple[TensorId, AxisId],
2350                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2351            ] = {
2352                **{
2353                    (out.id, a.id): (out, a, a.size)
2354                    for a in out.axes
2355                    if not isinstance(a, BatchAxis)
2356                    and isinstance(a.size, (int, ParameterizedSize))
2357                },
2358                **input_size_refs,
2359                **output_size_refs,
2360            }
2361            for a, ax in enumerate(out.axes):
2362                cls._validate_axis(
2363                    "outputs",
2364                    i,
2365                    out.id,
2366                    a,
2367                    ax,
2368                    valid_independent_refs=valid_independent_refs,
2369                )
2370
2371        return outputs
2372
2373    packaged_by: List[Author] = Field(default_factory=list)
2374    """The persons that have packaged and uploaded this model.
2375    Only required if those persons differ from the `authors`."""
2376
2377    parent: Optional[LinkedModel] = None
2378    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2379
2380    # todo: add parent self check once we have `id`
2381    # @model_validator(mode="after")
2382    # def validate_parent_is_not_self(self) -> Self:
2383    #     if self.parent is not None and self.parent == self.id:
2384    #         raise ValueError("The model may not reference itself as parent model")
2385
2386    #     return self
2387
2388    run_mode: Annotated[
2389        Optional[RunMode],
2390        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2391    ] = None
2392    """Custom run mode for this model: for more complex prediction procedures like test time
2393    data augmentation that currently cannot be expressed in the specification.
2394    No standard run modes are defined yet."""
2395
2396    timestamp: Datetime = Datetime(datetime.now())
2397    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2398    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2399    (In Python a datetime object is valid, too)."""
2400
2401    training_data: Annotated[
2402        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2403        Field(union_mode="left_to_right"),
2404    ] = None
2405    """The dataset used to train this model"""
2406
2407    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2408    """The weights for this model.
2409    Weights can be given for different formats, but should otherwise be equivalent.
2410    The available weight formats determine which consumers can use this model."""
2411
2412    @model_validator(mode="after")
2413    def _add_default_cover(self) -> Self:
2414        if not validation_context_var.get().perform_io_checks or self.covers:
2415            return self
2416
2417        try:
2418            generated_covers = generate_covers(
2419                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2420                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2421            )
2422        except Exception as e:
2423            issue_warning(
2424                "Failed to generate cover image(s): {e}",
2425                value=self.covers,
2426                msg_context=dict(e=e),
2427                field="covers",
2428            )
2429        else:
2430            self.covers.extend(generated_covers)
2431
2432        return self
2433
2434    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2435        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2436        assert all(isinstance(d, np.ndarray) for d in data)
2437        return data
2438
2439    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2440        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2441        assert all(isinstance(d, np.ndarray) for d in data)
2442        return data
2443
2444    @staticmethod
2445    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2446        batch_size = 1
2447        tensor_with_batchsize: Optional[TensorId] = None
2448        for tid in tensor_sizes:
2449            for aid, s in tensor_sizes[tid].items():
2450                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2451                    continue
2452
2453                if batch_size != 1:
2454                    assert tensor_with_batchsize is not None
2455                    raise ValueError(
2456                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2457                    )
2458
2459                batch_size = s
2460                tensor_with_batchsize = tid
2461
2462        return batch_size
2463
2464    def get_output_tensor_sizes(
2465        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2466    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2467        """Returns the tensor output sizes for given **input_sizes**.
2468        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2469        Otherwise it might be larger than the actual (valid) output"""
2470        batch_size = self.get_batch_size(input_sizes)
2471        ns = self.get_ns(input_sizes)
2472
2473        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2474        return tensor_sizes.outputs
2475
2476    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2477        """get parameter `n` for each parameterized axis
2478        such that the valid input size is >= the given input size"""
2479        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2480        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2481        for tid in input_sizes:
2482            for aid, s in input_sizes[tid].items():
2483                size_descr = axes[tid][aid].size
2484                if isinstance(size_descr, ParameterizedSize):
2485                    ret[(tid, aid)] = size_descr.get_n(s)
2486                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2487                    pass
2488                else:
2489                    assert_never(size_descr)
2490
2491        return ret
2492
2493    def get_tensor_sizes(
2494        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2495    ) -> _TensorSizes:
2496        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2497        return _TensorSizes(
2498            {
2499                t: {
2500                    aa: axis_sizes.inputs[(tt, aa)]
2501                    for tt, aa in axis_sizes.inputs
2502                    if tt == t
2503                }
2504                for t in {tt for tt, _ in axis_sizes.inputs}
2505            },
2506            {
2507                t: {
2508                    aa: axis_sizes.outputs[(tt, aa)]
2509                    for tt, aa in axis_sizes.outputs
2510                    if tt == t
2511                }
2512                for t in {tt for tt, _ in axis_sizes.outputs}
2513            },
2514        )
2515
2516    def get_axis_sizes(
2517        self,
2518        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2519        batch_size: Optional[int] = None,
2520        *,
2521        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2522    ) -> _AxisSizes:
2523        """Determine input and output block shape for scale factors **ns**
2524        of parameterized input sizes.
2525
2526        Args:
2527            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2528                that is parameterized as `size = min + n * step`.
2529            batch_size: The desired size of the batch dimension.
2530                If given **batch_size** overwrites any batch size present in
2531                **max_input_shape**. Default 1.
2532            max_input_shape: Limits the derived block shapes.
2533                Each axis for which the input size, parameterized by `n`, is larger
2534                than **max_input_shape** is set to the minimal value `n_min` for which
2535                this is still true.
2536                Use this for small input samples or large values of **ns**.
2537                Or simply whenever you know the full input shape.
2538
2539        Returns:
2540            Resolved axis sizes for model inputs and outputs.
2541        """
2542        max_input_shape = max_input_shape or {}
2543        if batch_size is None:
2544            for (_t_id, a_id), s in max_input_shape.items():
2545                if a_id == BATCH_AXIS_ID:
2546                    batch_size = s
2547                    break
2548            else:
2549                batch_size = 1
2550
2551        all_axes = {
2552            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2553        }
2554
2555        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2556        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2557
2558        def get_axis_size(a: Union[InputAxis, OutputAxis]):
2559            if isinstance(a, BatchAxis):
2560                if (t_descr.id, a.id) in ns:
2561                    logger.warning(
2562                        "Ignoring unexpected size increment factor (n) for batch axis"
2563                        + " of tensor '{}'.",
2564                        t_descr.id,
2565                    )
2566                return batch_size
2567            elif isinstance(a.size, int):
2568                if (t_descr.id, a.id) in ns:
2569                    logger.warning(
2570                        "Ignoring unexpected size increment factor (n) for fixed size"
2571                        + " axis '{}' of tensor '{}'.",
2572                        a.id,
2573                        t_descr.id,
2574                    )
2575                return a.size
2576            elif isinstance(a.size, ParameterizedSize):
2577                if (t_descr.id, a.id) not in ns:
2578                    raise ValueError(
2579                        "Size increment factor (n) missing for parametrized axis"
2580                        + f" '{a.id}' of tensor '{t_descr.id}'."
2581                    )
2582                n = ns[(t_descr.id, a.id)]
2583                s_max = max_input_shape.get((t_descr.id, a.id))
2584                if s_max is not None:
2585                    n = min(n, a.size.get_n(s_max))
2586
2587                return a.size.get_size(n)
2588
2589            elif isinstance(a.size, SizeReference):
2590                if (t_descr.id, a.id) in ns:
2591                    logger.warning(
2592                        "Ignoring unexpected size increment factor (n) for axis '{}'"
2593                        + " of tensor '{}' with size reference.",
2594                        a.id,
2595                        t_descr.id,
2596                    )
2597                assert not isinstance(a, BatchAxis)
2598                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2599                assert not isinstance(ref_axis, BatchAxis)
2600                ref_key = (a.size.tensor_id, a.size.axis_id)
2601                ref_size = inputs.get(ref_key, outputs.get(ref_key))
2602                assert ref_size is not None, ref_key
2603                assert not isinstance(ref_size, _DataDepSize), ref_key
2604                return a.size.get_size(
2605                    axis=a,
2606                    ref_axis=ref_axis,
2607                    ref_size=ref_size,
2608                )
2609            elif isinstance(a.size, DataDependentSize):
2610                if (t_descr.id, a.id) in ns:
2611                    logger.warning(
2612                        "Ignoring unexpected increment factor (n) for data dependent"
2613                        + " size axis '{}' of tensor '{}'.",
2614                        a.id,
2615                        t_descr.id,
2616                    )
2617                return _DataDepSize(a.size.min, a.size.max)
2618            else:
2619                assert_never(a.size)
2620
2621        # first resolve all , but the `SizeReference` input sizes
2622        for t_descr in self.inputs:
2623            for a in t_descr.axes:
2624                if not isinstance(a.size, SizeReference):
2625                    s = get_axis_size(a)
2626                    assert not isinstance(s, _DataDepSize)
2627                    inputs[t_descr.id, a.id] = s
2628
2629        # resolve all other input axis sizes
2630        for t_descr in self.inputs:
2631            for a in t_descr.axes:
2632                if isinstance(a.size, SizeReference):
2633                    s = get_axis_size(a)
2634                    assert not isinstance(s, _DataDepSize)
2635                    inputs[t_descr.id, a.id] = s
2636
2637        # resolve all output axis sizes
2638        for t_descr in self.outputs:
2639            for a in t_descr.axes:
2640                assert not isinstance(a.size, ParameterizedSize)
2641                s = get_axis_size(a)
2642                outputs[t_descr.id, a.id] = s
2643
2644        return _AxisSizes(inputs=inputs, outputs=outputs)
2645
2646    @model_validator(mode="before")
2647    @classmethod
2648    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
2649        if (
2650            data.get("type") == "model"
2651            and isinstance(fv := data.get("format_version"), str)
2652            and fv.count(".") == 2
2653        ):
2654            fv_parts = fv.split(".")
2655            if any(not p.isdigit() for p in fv_parts):
2656                return data
2657
2658            fv_tuple = tuple(map(int, fv_parts))
2659
2660            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
2661            if fv_tuple[:2] in ((0, 3), (0, 4)):
2662                m04 = _ModelDescr_v0_4.load(data)
2663                if not isinstance(m04, InvalidDescr):
2664                    return _model_conv.convert_as_dict(m04)
2665            elif fv_tuple[:2] == (0, 5):
2666                # bump patch version
2667                data["format_version"] = cls.implemented_format_version
2668
2669        return data

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

format_version: Literal['0.5.3']

Version of the bioimage.io model description specification used. When creating a new model always use the latest micro/patch version described here. The format_version is important for any consumer software to understand how to parse the fields.

type: Literal['model']

Specialized resource type 'model'

id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[Annotated[pathlib.Path, PathType(path_type='file'), Predicate(is_absolute)], bioimageio.spec._internal.io.RelativeFilePath, bioimageio.spec._internal.url.HttpUrl], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function _validate_md_suffix at 0x7f9a7e7b3e20>), PlainSerializer(func=<function _package at 0x7f9a7f3b9620>, return_type=PydanticUndefined, when_used='unless-none'), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f9a6e849c60>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letter, number, underscore, minus, parentheses and spaces. We recommend to chose a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f9a6e8298a0>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in ISO 8601 format with a few restrictions listed here. (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights at 0x7f9a7cb40360>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

def get_input_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
2434    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2435        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2436        assert all(isinstance(d, np.ndarray) for d in data)
2437        return data
def get_output_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
2439    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2440        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2441        assert all(isinstance(d, np.ndarray) for d in data)
2442        return data
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2444    @staticmethod
2445    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2446        batch_size = 1
2447        tensor_with_batchsize: Optional[TensorId] = None
2448        for tid in tensor_sizes:
2449            for aid, s in tensor_sizes[tid].items():
2450                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2451                    continue
2452
2453                if batch_size != 1:
2454                    assert tensor_with_batchsize is not None
2455                    raise ValueError(
2456                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2457                    )
2458
2459                batch_size = s
2460                tensor_with_batchsize = tid
2461
2462        return batch_size
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
2464    def get_output_tensor_sizes(
2465        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2466    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2467        """Returns the tensor output sizes for given **input_sizes**.
2468        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2469        Otherwise it might be larger than the actual (valid) output"""
2470        batch_size = self.get_batch_size(input_sizes)
2471        ns = self.get_ns(input_sizes)
2472
2473        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2474        return tensor_sizes.outputs

Returns the tensor output sizes for given input_sizes. Only if input_sizes has a valid input shape, the tensor output size is exact. Otherwise it might be larger than the actual (valid) output

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2476    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2477        """get parameter `n` for each parameterized axis
2478        such that the valid input size is >= the given input size"""
2479        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2480        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2481        for tid in input_sizes:
2482            for aid, s in input_sizes[tid].items():
2483                size_descr = axes[tid][aid].size
2484                if isinstance(size_descr, ParameterizedSize):
2485                    ret[(tid, aid)] = size_descr.get_n(s)
2486                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2487                    pass
2488                else:
2489                    assert_never(size_descr)
2490
2491        return ret

get parameter n for each parameterized axis such that the valid input size is >= the given input size

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
2493    def get_tensor_sizes(
2494        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2495    ) -> _TensorSizes:
2496        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2497        return _TensorSizes(
2498            {
2499                t: {
2500                    aa: axis_sizes.inputs[(tt, aa)]
2501                    for tt, aa in axis_sizes.inputs
2502                    if tt == t
2503                }
2504                for t in {tt for tt, _ in axis_sizes.inputs}
2505            },
2506            {
2507                t: {
2508                    aa: axis_sizes.outputs[(tt, aa)]
2509                    for tt, aa in axis_sizes.outputs
2510                    if tt == t
2511                }
2512                for t in {tt for tt, _ in axis_sizes.outputs}
2513            },
2514        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
2516    def get_axis_sizes(
2517        self,
2518        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2519        batch_size: Optional[int] = None,
2520        *,
2521        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2522    ) -> _AxisSizes:
2523        """Determine input and output block shape for scale factors **ns**
2524        of parameterized input sizes.
2525
2526        Args:
2527            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2528                that is parameterized as `size = min + n * step`.
2529            batch_size: The desired size of the batch dimension.
2530                If given **batch_size** overwrites any batch size present in
2531                **max_input_shape**. Default 1.
2532            max_input_shape: Limits the derived block shapes.
2533                Each axis for which the input size, parameterized by `n`, is larger
2534                than **max_input_shape** is set to the minimal value `n_min` for which
2535                this is still true.
2536                Use this for small input samples or large values of **ns**.
2537                Or simply whenever you know the full input shape.
2538
2539        Returns:
2540            Resolved axis sizes for model inputs and outputs.
2541        """
2542        max_input_shape = max_input_shape or {}
2543        if batch_size is None:
2544            for (_t_id, a_id), s in max_input_shape.items():
2545                if a_id == BATCH_AXIS_ID:
2546                    batch_size = s
2547                    break
2548            else:
2549                batch_size = 1
2550
2551        all_axes = {
2552            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2553        }
2554
2555        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2556        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2557
2558        def get_axis_size(a: Union[InputAxis, OutputAxis]):
2559            if isinstance(a, BatchAxis):
2560                if (t_descr.id, a.id) in ns:
2561                    logger.warning(
2562                        "Ignoring unexpected size increment factor (n) for batch axis"
2563                        + " of tensor '{}'.",
2564                        t_descr.id,
2565                    )
2566                return batch_size
2567            elif isinstance(a.size, int):
2568                if (t_descr.id, a.id) in ns:
2569                    logger.warning(
2570                        "Ignoring unexpected size increment factor (n) for fixed size"
2571                        + " axis '{}' of tensor '{}'.",
2572                        a.id,
2573                        t_descr.id,
2574                    )
2575                return a.size
2576            elif isinstance(a.size, ParameterizedSize):
2577                if (t_descr.id, a.id) not in ns:
2578                    raise ValueError(
2579                        "Size increment factor (n) missing for parametrized axis"
2580                        + f" '{a.id}' of tensor '{t_descr.id}'."
2581                    )
2582                n = ns[(t_descr.id, a.id)]
2583                s_max = max_input_shape.get((t_descr.id, a.id))
2584                if s_max is not None:
2585                    n = min(n, a.size.get_n(s_max))
2586
2587                return a.size.get_size(n)
2588
2589            elif isinstance(a.size, SizeReference):
2590                if (t_descr.id, a.id) in ns:
2591                    logger.warning(
2592                        "Ignoring unexpected size increment factor (n) for axis '{}'"
2593                        + " of tensor '{}' with size reference.",
2594                        a.id,
2595                        t_descr.id,
2596                    )
2597                assert not isinstance(a, BatchAxis)
2598                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2599                assert not isinstance(ref_axis, BatchAxis)
2600                ref_key = (a.size.tensor_id, a.size.axis_id)
2601                ref_size = inputs.get(ref_key, outputs.get(ref_key))
2602                assert ref_size is not None, ref_key
2603                assert not isinstance(ref_size, _DataDepSize), ref_key
2604                return a.size.get_size(
2605                    axis=a,
2606                    ref_axis=ref_axis,
2607                    ref_size=ref_size,
2608                )
2609            elif isinstance(a.size, DataDependentSize):
2610                if (t_descr.id, a.id) in ns:
2611                    logger.warning(
2612                        "Ignoring unexpected increment factor (n) for data dependent"
2613                        + " size axis '{}' of tensor '{}'.",
2614                        a.id,
2615                        t_descr.id,
2616                    )
2617                return _DataDepSize(a.size.min, a.size.max)
2618            else:
2619                assert_never(a.size)
2620
2621        # first resolve all , but the `SizeReference` input sizes
2622        for t_descr in self.inputs:
2623            for a in t_descr.axes:
2624                if not isinstance(a.size, SizeReference):
2625                    s = get_axis_size(a)
2626                    assert not isinstance(s, _DataDepSize)
2627                    inputs[t_descr.id, a.id] = s
2628
2629        # resolve all other input axis sizes
2630        for t_descr in self.inputs:
2631            for a in t_descr.axes:
2632                if isinstance(a.size, SizeReference):
2633                    s = get_axis_size(a)
2634                    assert not isinstance(s, _DataDepSize)
2635                    inputs[t_descr.id, a.id] = s
2636
2637        # resolve all output axis sizes
2638        for t_descr in self.outputs:
2639            for a in t_descr.axes:
2640                assert not isinstance(a.size, ParameterizedSize)
2641                s = get_axis_size(a)
2642                outputs[t_descr.id, a.id] = s
2643
2644        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.

implemented_format_version: ClassVar[str] = '0.5.3'
implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 3)
def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
124                    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
125                        """We need to both initialize private attributes and call the user-defined model_post_init
126                        method.
127                        """
128                        init_private_attributes(self, context)
129                        original_model_post_init(self, context)

We need to both initialize private attributes and call the user-defined model_post_init method.

def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[Any, numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[Any, numpy.dtype[Any]]]]) -> List[pathlib.Path]:
2894def generate_covers(
2895    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
2896    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
2897) -> List[Path]:
2898    def squeeze(
2899        data: NDArray[Any], axes: Sequence[AnyAxis]
2900    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
2901        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
2902        if data.ndim != len(axes):
2903            raise ValueError(
2904                f"tensor shape {data.shape} does not match described axes"
2905                + f" {[a.id for a in axes]}"
2906            )
2907
2908        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
2909        return data.squeeze(), axes
2910
2911    def normalize(
2912        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
2913    ) -> NDArray[np.float32]:
2914        data = data.astype("float32")
2915        data -= data.min(axis=axis, keepdims=True)
2916        data /= data.max(axis=axis, keepdims=True) + eps
2917        return data
2918
2919    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
2920        original_shape = data.shape
2921        data, axes = squeeze(data, axes)
2922
2923        # take slice fom any batch or index axis if needed
2924        # and convert the first channel axis and take a slice from any additional channel axes
2925        slices: Tuple[slice, ...] = ()
2926        ndim = data.ndim
2927        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
2928        has_c_axis = False
2929        for i, a in enumerate(axes):
2930            s = data.shape[i]
2931            assert s > 1
2932            if (
2933                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
2934                and ndim > ndim_need
2935            ):
2936                data = data[slices + (slice(s // 2 - 1, s // 2),)]
2937                ndim -= 1
2938            elif isinstance(a, ChannelAxis):
2939                if has_c_axis:
2940                    # second channel axis
2941                    data = data[slices + (slice(0, 1),)]
2942                    ndim -= 1
2943                else:
2944                    has_c_axis = True
2945                    if s == 2:
2946                        # visualize two channels with cyan and magenta
2947                        data = np.concatenate(
2948                            [
2949                                data[slices + (slice(1, 2),)],
2950                                data[slices + (slice(0, 1),)],
2951                                (
2952                                    data[slices + (slice(0, 1),)]
2953                                    + data[slices + (slice(1, 2),)]
2954                                )
2955                                / 2,  # TODO: take maximum instead?
2956                            ],
2957                            axis=i,
2958                        )
2959                    elif data.shape[i] == 3:
2960                        pass  # visualize 3 channels as RGB
2961                    else:
2962                        # visualize first 3 channels as RGB
2963                        data = data[slices + (slice(3),)]
2964
2965                    assert data.shape[i] == 3
2966
2967            slices += (slice(None),)
2968
2969        data, axes = squeeze(data, axes)
2970        assert len(axes) == ndim
2971        # take slice from z axis if needed
2972        slices = ()
2973        if ndim > ndim_need:
2974            for i, a in enumerate(axes):
2975                s = data.shape[i]
2976                if a.id == AxisId("z"):
2977                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
2978                    data, axes = squeeze(data, axes)
2979                    ndim -= 1
2980                    break
2981
2982            slices += (slice(None),)
2983
2984        # take slice from any space or time axis
2985        slices = ()
2986
2987        for i, a in enumerate(axes):
2988            if ndim <= ndim_need:
2989                break
2990
2991            s = data.shape[i]
2992            assert s > 1
2993            if isinstance(
2994                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
2995            ):
2996                data = data[slices + (slice(s // 2 - 1, s // 2),)]
2997                ndim -= 1
2998
2999            slices += (slice(None),)
3000
3001        del slices
3002        data, axes = squeeze(data, axes)
3003        assert len(axes) == ndim
3004
3005        if (has_c_axis and ndim != 3) or ndim != 2:
3006            raise ValueError(
3007                f"Failed to construct cover image from shape {original_shape}"
3008            )
3009
3010        if not has_c_axis:
3011            assert ndim == 2
3012            data = np.repeat(data[:, :, None], 3, axis=2)
3013            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3014            ndim += 1
3015
3016        assert ndim == 3
3017
3018        # transpose axis order such that longest axis comes first...
3019        axis_order = list(np.argsort(list(data.shape)))
3020        axis_order.reverse()
3021        # ... and channel axis is last
3022        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3023        axis_order.append(axis_order.pop(c))
3024        axes = [axes[ao] for ao in axis_order]
3025        data = data.transpose(axis_order)
3026
3027        # h, w = data.shape[:2]
3028        # if h / w  in (1.0 or 2.0):
3029        #     pass
3030        # elif h / w < 2:
3031        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3032
3033        norm_along = (
3034            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3035        )
3036        # normalize the data and map to 8 bit
3037        data = normalize(data, norm_along)
3038        data = (data * 255).astype("uint8")
3039
3040        return data
3041
3042    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3043        assert im0.dtype == im1.dtype == np.uint8
3044        assert im0.shape == im1.shape
3045        assert im0.ndim == 3
3046        N, M, C = im0.shape
3047        assert C == 3
3048        out = np.ones((N, M, C), dtype="uint8")
3049        for c in range(C):
3050            outc = np.tril(im0[..., c])
3051            mask = outc == 0
3052            outc[mask] = np.triu(im1[..., c])[mask]
3053            out[..., c] = outc
3054
3055        return out
3056
3057    ipt_descr, ipt = inputs[0]
3058    out_descr, out = outputs[0]
3059
3060    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3061    out_img = to_2d_image(out, out_descr.axes)
3062
3063    cover_folder = Path(mkdtemp())
3064    if ipt_img.shape == out_img.shape:
3065        covers = [cover_folder / "cover.png"]
3066        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3067    else:
3068        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3069        imwrite(covers[0], ipt_img)
3070        imwrite(covers[1], out_img)
3071
3072    return covers
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):

Subpart of a resource description

class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):

Subpart of a resource description