bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from itertools import chain
  10from math import ceil
  11from pathlib import Path, PurePosixPath
  12from tempfile import mkdtemp
  13from typing import (
  14    TYPE_CHECKING,
  15    Any,
  16    ClassVar,
  17    Dict,
  18    Generic,
  19    List,
  20    Literal,
  21    Mapping,
  22    NamedTuple,
  23    Optional,
  24    Sequence,
  25    Set,
  26    Tuple,
  27    Type,
  28    TypeVar,
  29    Union,
  30    cast,
  31)
  32
  33import numpy as np
  34from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  35from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  36from loguru import logger
  37from numpy.typing import NDArray
  38from pydantic import (
  39    AfterValidator,
  40    Discriminator,
  41    Field,
  42    RootModel,
  43    Tag,
  44    ValidationInfo,
  45    WrapSerializer,
  46    field_validator,
  47    model_validator,
  48)
  49from typing_extensions import Annotated, Self, assert_never, get_args
  50
  51from .._internal.common_nodes import (
  52    InvalidDescr,
  53    Node,
  54    NodeWithExplicitlySetFields,
  55)
  56from .._internal.constants import DTYPE_LIMITS
  57from .._internal.field_warning import issue_warning, warn
  58from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  59from .._internal.io import FileDescr as FileDescr
  60from .._internal.io import WithSuffix, YamlValue, download
  61from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
  62from .._internal.io_basics import Sha256 as Sha256
  63from .._internal.io_utils import load_array
  64from .._internal.node_converter import Converter
  65from .._internal.types import (
  66    AbsoluteTolerance,
  67    ImportantFileSource,
  68    LowerCaseIdentifier,
  69    LowerCaseIdentifierAnno,
  70    MismatchedElementsPerMillion,
  71    RelativeTolerance,
  72    SiUnit,
  73)
  74from .._internal.types import Datetime as Datetime
  75from .._internal.types import Identifier as Identifier
  76from .._internal.types import NotEmpty as NotEmpty
  77from .._internal.url import HttpUrl as HttpUrl
  78from .._internal.validation_context import get_validation_context
  79from .._internal.validator_annotations import RestrictCharacters
  80from .._internal.version_type import Version as Version
  81from .._internal.warning_levels import INFO
  82from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  83from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
  84from ..dataset.v0_3 import DatasetDescr as DatasetDescr
  85from ..dataset.v0_3 import DatasetId as DatasetId
  86from ..dataset.v0_3 import LinkedDataset as LinkedDataset
  87from ..dataset.v0_3 import Uploader as Uploader
  88from ..generic.v0_3 import (
  89    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
  90)
  91from ..generic.v0_3 import Author as Author
  92from ..generic.v0_3 import BadgeDescr as BadgeDescr
  93from ..generic.v0_3 import CiteEntry as CiteEntry
  94from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
  95from ..generic.v0_3 import (
  96    DocumentationSource,
  97    GenericModelDescrBase,
  98    LinkedResourceBase,
  99    _author_conv,  # pyright: ignore[reportPrivateUsage]
 100    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
 101)
 102from ..generic.v0_3 import Doi as Doi
 103from ..generic.v0_3 import LicenseId as LicenseId
 104from ..generic.v0_3 import LinkedResource as LinkedResource
 105from ..generic.v0_3 import Maintainer as Maintainer
 106from ..generic.v0_3 import OrcidId as OrcidId
 107from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 108from ..generic.v0_3 import ResourceId as ResourceId
 109from .v0_4 import Author as _Author_v0_4
 110from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 111from .v0_4 import CallableFromDepencency as CallableFromDepencency
 112from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 113from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 114from .v0_4 import ClipDescr as _ClipDescr_v0_4
 115from .v0_4 import ClipKwargs as ClipKwargs
 116from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 117from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 118from .v0_4 import KnownRunMode as KnownRunMode
 119from .v0_4 import ModelDescr as _ModelDescr_v0_4
 120from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 121from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 122from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 123from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 124from .v0_4 import ProcessingKwargs as ProcessingKwargs
 125from .v0_4 import RunMode as RunMode
 126from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 127from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 128from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 129from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 130from .v0_4 import TensorName as _TensorName_v0_4
 131from .v0_4 import WeightsFormat as WeightsFormat
 132from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 133from .v0_4 import package_weights
 134
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

# The axis `type` discriminator values used by the axis classes below.
AxisType = Literal["batch", "channel", "index", "time", "space"]

# Maps single-letter axis names (as used e.g. in v0.4 `axes` strings) to axis types.
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# Maps single-letter axis ids to their canonical long-form ids.
# "x", "y" and "z" are intentionally absent: they remain valid axis ids as-is.
_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}
 210
 211
class TensorId(LowerCaseIdentifier):
    # A lowercase identifier of at most 32 characters, used to reference tensors.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]
 216
 217
 218def _normalize_axis_id(a: str):
 219    a = str(a)
 220    normalized = _AXIS_ID_MAP.get(a, a)
 221    if a != normalized:
 222        logger.opt(depth=3).warning(
 223            "Normalized axis id from '{}' to '{}'.", a, normalized
 224        )
 225    return normalized
 226
 227
class AxisId(LowerCaseIdentifier):
    # A lowercase identifier of at most 16 characters; single-letter ids
    # b/t/i/c are normalized to their long form by `_normalize_axis_id`.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
 236
 237
 238def _is_batch(a: str) -> bool:
 239    return str(a) == "batch"
 240
 241
 242def _is_not_batch(a: str) -> bool:
 243    return not _is_batch(a)
 244
 245
# An axis id that is explicitly not the reserved "batch" id.
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

# Valid `id` values of postprocessing step descriptions.
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "zero_mean_unit_variance",
]
# Valid `id` values of preprocessing step descriptions.
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]


# Sentinel value meaning "defaults to the tensor's data type".
SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""
 277
 278
 279class ParameterizedSize(Node):
 280    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
 281
 282    - **min** and **step** are given by the model description.
 283    - All blocksize paramters n = 0,1,2,... yield a valid `size`.
 284    - A greater blocksize paramter n = 0,1,2,... results in a greater **size**.
 285      This allows to adjust the axis size more generically.
 286    """
 287
 288    N: ClassVar[Type[int]] = ParameterizedSize_N
 289    """Positive integer to parameterize this axis"""
 290
 291    min: Annotated[int, Gt(0)]
 292    step: Annotated[int, Gt(0)]
 293
 294    def validate_size(self, size: int) -> int:
 295        if size < self.min:
 296            raise ValueError(f"size {size} < {self.min}")
 297        if (size - self.min) % self.step != 0:
 298            raise ValueError(
 299                f"axis of size {size} is not parameterized by `min + n*step` ="
 300                + f" `{self.min} + n*{self.step}`"
 301            )
 302
 303        return size
 304
 305    def get_size(self, n: ParameterizedSize_N) -> int:
 306        return self.min + self.step * n
 307
 308    def get_n(self, s: int) -> ParameterizedSize_N:
 309        """return smallest n parameterizing a size greater or equal than `s`"""
 310        return ceil((s - self.min) / self.step)
 311
 312
 313class DataDependentSize(Node):
 314    min: Annotated[int, Gt(0)] = 1
 315    max: Annotated[Optional[int], Gt(1)] = None
 316
 317    @model_validator(mode="after")
 318    def _validate_max_gt_min(self):
 319        if self.max is not None and self.min >= self.max:
 320            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 321
 322        return self
 323
 324    def validate_size(self, size: int) -> int:
 325        if size < self.min:
 326            raise ValueError(f"size {size} < {self.min}")
 327
 328        if self.max is not None and size > self.max:
 329            raise ValueError(f"size {size} > {self.max}")
 330
 331        return size
 332
 333
class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
        `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: int = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this `SizeReference` is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        # sanity checks: this object must be `axis`'s size, and `ref_axis` must be
        # the axis this reference points to
        assert (
            axis.size == self
        ), "Given `axis.size` is not defined by this `SizeReference`"

        assert (
            ref_axis.id == self.axis_id
        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            # derive the reference size from the reference axis itself
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                # concrete size of the parameterized reference axis for blocksize `n`
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # NOTE(review): int() truncates toward zero; this equals the documented
        # "rounded down" for the non-negative sizes expected here — confirm no
        # caller relies on negative intermediate results.
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        # helper returning the `unit` attribute shared by all referenced axis types
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        return axis.unit
 458
 459
class AxisBase(NodeWithExplicitlySetFields):
    # Fields common to every axis description.
    id: AxisId
    """An axis id unique across all axes of one tensor."""

    # Free-text description, limited to 128 characters.
    description: Annotated[str, MaxLen(128)] = ""
 465
 466
class WithHalo(Node):
    # Mixin for output axes whose size is a `SizeReference` and that carry a halo.
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see `SizeReference`)"""
 485
 486
# The reserved axis id identifying the batch axis.
BATCH_AXIS_ID = AxisId("batch")
 488
 489
class BatchAxis(AxisBase):
    implemented_type: ClassVar[Literal["batch"]] = "batch"
    # `type` has a default only for static type checkers; at runtime it is a
    # required field — NOTE(review): presumably so NodeWithExplicitlySetFields
    # can track explicitly set fields; confirm against `common_nodes`.
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self) -> float:
        # batch axes carry no physical scaling
        return 1.0

    @property
    def concatenable(self) -> bool:
        # samples can always be concatenated along the batch axis
        return True

    @property
    def unit(self) -> None:
        return None
 513
 514
class ChannelAxis(AxisBase):
    implemented_type: ClassVar[Literal["channel"]] = "channel"
    # `type` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")
    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        # axis length is implied by the number of channel names
        return len(self.channel_names)

    @property
    def concatenable(self) -> bool:
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self) -> None:
        return None
 540
 541
class IndexAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["index"]] = "index"
    # `type` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        # index axes have no physical scaling
        return 1.0

    @property
    def unit(self) -> None:
        return None
 558
 559
class _WithInputAxisSize(Node):
    # Mixin providing the `size` field shared by all non-batch input axes.
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes (`ParameterizedSize`)
    - reference to another axis with an optional offset (`SizeReference`)
    """
 578
 579
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    # Index input axis: inherits `size` from `_WithInputAxisSize`.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a `SizeReference` to a concatenable
    input axis.
    """
 587
 588
class IndexOutputAxis(IndexAxisBase):
    # Unlike input axes, an output index axis may have a data dependent size,
    # but never a `ParameterizedSize`.
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (`SizeReference`)
    - data dependent size using `DataDependentSize` (size is only known after model inference)
    """
 606
 607
class TimeAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["time"]] = "time"
    # `type` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    # Optional physical unit and per-element scale — cf. OME-Zarr axes spec.
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0
 618
 619
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    # Time input axis: inherits `size` from `_WithInputAxisSize`.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a `SizeReference` to a concatenable
    input axis.
    """
 627
 628
class SpaceAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["space"]] = "space"
    # `type` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    # Optional physical unit and per-pixel scale — cf. OME-Zarr axes spec.
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0
 639
 640
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    # Space input axis: inherits `size` from `_WithInputAxisSize`.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a `SizeReference` to a concatenable
    input axis.
    """
 648
 649
# Tuple mirror of `_InputAxisUnion` below, for runtime isinstance checks.
INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# Union discriminated by the `type` field of each axis class.
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 663
 664
class _WithOutputAxisSize(Node):
    # Mixin providing the `size` field for output axes without a halo.
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see `SizeReference`)
    """
 681
 682
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    # Time output axis without a halo; all fields are inherited unchanged.
    pass
 685
 686
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    # Time output axis with a halo (size is a `SizeReference`, see `WithHalo`).
    pass
 689
 690
 691def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 692    if isinstance(v, dict):
 693        return "with_halo" if "halo" in v else "wo_halo"
 694    else:
 695        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 696
 697
# Tag-discriminated union: time output axes with vs. without a halo.
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]
 705
 706
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    # Space output axis without a halo; all fields are inherited unchanged.
    pass
 709
 710
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    # Space output axis with a halo (size is a `SizeReference`, see `WithHalo`).
    pass
 713
 714
# Tag-discriminated union: space output axes with vs. without a halo.
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]
 722
 723
_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
# Union discriminated by the `type` field of each axis class.
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

# Tuple mirror of `_OutputAxisUnion` above, for runtime isinstance checks.
OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# Homogeneous, non-empty lists of tensor values (see `NominalOrOrdinalDataDescr`).
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


# Data types allowed for nominal or ordinal data descriptions.
NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]
 767
 768
class NominalOrOrdinalDataDescr(Node):
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        """Ensure every entry of `values` is representable by `type`."""
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    # numeric value outside the dtype's representable range
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    # string labels are only allowed for unsigned integer types
                    or (isinstance(v, str) and "uint" not in self.type)
                    # fractional values are not allowed for integer types
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                # unknown dtype: treat every value as incompatible
                incompatible.append(v)

            # report at most 5 offending values, then truncate with "..."
            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        """Return (min, max) of the described values (label index range for strings)."""
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)
 833
 834
# Data types allowed for interval or ratio data descriptions (no "bool").
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]
 847
 848
class IntervalOrRatioDataDescr(Node):
    # Describes numeric tensor data on an interval or ratio scale.
    type: Annotated[  # todo: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""
 867
 868
# Union of the two supported tensor data descriptions.
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 870
 871
# Abstract base class of all pre-/postprocessing step descriptions below.
class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
    """processing base class"""
 874
 875
class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    # Single threshold applied to the whole tensor.
    threshold: float
    """The fixed threshold"""
 881
 882
class BinarizeAlongAxisKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    # One threshold per entry along `axis`.
    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""
 891
 892
class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed threshold.

    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...   kwargs=BinarizeAlongAxisKwargs(
        ...       axis=AxisId('channel'),
        ...       threshold=[0.25, 0.5, 0.75],
        ...   )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    # `id` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 923
 924
class ClipDescr(ProcessingDescrBase):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    # `id` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs
 938
 939
class EnsureDtypeKwargs(ProcessingKwargs):
    """key word arguments for `EnsureDtypeDescr`"""

    # Target data type to cast to.
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
 956
 957
class EnsureDtypeDescr(ProcessingDescrBase):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
            preprocessing:
            - id: scale_range
                kwargs:
                axes: ['y', 'x']
                max_percentile: 99.8
                min_percentile: 5.0
            - id: clip
                kwargs:
                min: 0.0
                max: 1.0
            - id: ensure_dtype
                kwargs:
                dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...           axes= (AxisId('y'), AxisId('x')),
            ...           max_percentile= 99.8,
            ...           min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    # `id` has a default only for static type checkers; at runtime it is required.
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs
1008
1009
class ScaleLinearKwargs(ProcessingKwargs):
    """Key word arguments for `ScaleLinearDescr`"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Reject the no-op scaling `out = 1.0 * tensor + 0.0`.

        Raises:
            ValueError: if both `gain` is 1.0 and `offset` is 0.0.
        """
        if self.gain == 1.0 and self.offset == 0.0:
            # fix: error message typo "allowd" -> "allowed"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self
1028
1029
class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
    """Key word arguments for `ScaleLinearDescr`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Harmonize `gain` and `offset` into equally long lists.

        Raises:
            ValueError: if list lengths of `gain` and `offset` differ,
                if both are scalars (then no `axis` should be given), or
                if the scaling is an overall no-op (all gains 1.0 and all
                offsets 0.0).
        """
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                # broadcast scalar offset to the length of the gain list
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            # broadcast scalar gain to the length of the offset list
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            # fix: error message typo "allowd" -> "allowed"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self
1067
1068
class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        # static-only default: lets type checkers accept construction without `id`
        id: Literal["scale_linear"] = "scale_linear"
    else:
        # no runtime default; presumably populated from `implemented_id` by the
        # base class — TODO confirm against ProcessingDescrBase
        id: Literal["scale_linear"]
    # scalar kwargs, or per-axis-entry kwargs when an `axis` is specified
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1114
1115
class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        # static-only default: lets type checkers accept construction without `id`
        id: Literal["sigmoid"] = "sigmoid"
    else:
        # no runtime default; presumably populated from `implemented_id` by the
        # base class — TODO confirm against ProcessingDescrBase
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        # sigmoid takes no parameters; an empty kwargs object keeps the
        # interface consistent with the other processing descriptions
        return ProcessingKwargs()
1139
1140
class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""

    mean: float
    """The mean value to normalize with."""

    # lower bound presumably guards against division by (near) zero during
    # normalization — TODO confirm against the processing implementation
    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""
1149
1150
class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Require `mean` and `std` to be lists of equal length."""
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self
1174
1175
class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...     axis=AxisId("channel"),
        ...     mean=[101.5, 102.5, 103.5],
        ...     std=[11.7, 12.7, 13.7],
        ...   )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # static-only default: lets type checkers accept construction without `id`
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        # no runtime default; presumably populated from `implemented_id` by the
        # base class — TODO confirm against ProcessingDescrBase
        id: Literal["fixed_zero_mean_unit_variance"]

    # scalar kwargs, or per-axis-entry kwargs when an `axis` is specified
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]
1230
1231
class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    # bounded to (0, 0.1]: must be positive for numeric stability, yet small
    # enough not to dominate the denominator
    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1246
1247
class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by standard deviation.

    (`out = (tensor - mean) / (std + eps)`, see `ZeroMeanUnitVarianceKwargs.eps`.)

    Examples:
        Normalize with the tensor's own mean and standard deviation
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # static-only default: lets type checkers accept construction without `id`
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        # no runtime default; presumably populated from `implemented_id` by the
        # base class — TODO confirm against ProcessingDescrBase
        id: Literal["zero_mean_unit_variance"]

    # kwargs are optional: all fields of ZeroMeanUnitVarianceKwargs have defaults
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs
    )
1273
1274
class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` < `max_percentile`.

        If `min_percentile` failed its own validation it is absent from
        `info.data`; pydantic already reports that error, so the cross-field
        check is skipped instead of raising a spurious `KeyError`.
        """
        if (min_p := info.data.get("min_percentile")) is not None and min_p >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value
1319
1320
class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [ScaleRangeDescr(
        ...   kwargs=ScaleRangeKwargs(
        ...       axes= (AxisId('y'), AxisId('x')),
        ...       max_percentile= 99.8,
        ...       min_percentile= 5.0,
        ...   )
        ... )]

      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...           axes= (AxisId('y'), AxisId('x')),
        ...           max_percentile= 99.8,
        ...           min_percentile= 5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        # static-only default: lets type checkers accept construction without `id`
        id: Literal["scale_range"] = "scale_range"
    else:
        # no runtime default; presumably populated from `implemented_id` by the
        # base class — TODO confirm against ProcessingDescrBase
        id: Literal["scale_range"]
    kwargs: ScaleRangeKwargs
1384
1385
class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1404
1405
class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        # static-only default: lets type checkers accept construction without `id`
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        # no runtime default; presumably populated from `implemented_id` by the
        # base class — TODO confirm against ProcessingDescrBase
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs
1417
1418
# Discriminated unions over all processing steps; pydantic dispatches on the
# `id` field when validating/(de)serializing.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        FixedZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("id"),
]
# postprocessing additionally allows `scale_mean_variance`
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        FixedZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("id"),
]

# constrains TensorDescrBase to carry either input axes or output axes
IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1448
1449
class TensorDescrBase(Node, Generic[IO_AxisT]):
    """Common base for input/output tensor descriptions: id, axes, test/sample
    tensors and a description of the tensor's data values."""

    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        """The `size` entry of each axis, in axis order."""
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Allow at most one batch axis and require unique axis ids."""
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            # route each id into `seen_ids` on first sight, `duplicate_axes_ids` after
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FileDescr
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: Optional[FileDescr] = None
    """A sample tensor to illustrate a possible input/output for the model,
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """If IO checks are enabled, download the sample tensor and verify that
        its dimensionality is compatible with the declared axes."""
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(
            local.path.read_bytes(),
            extension=PurePosixPath(local.original_file_name).suffix,
        )
        # compare against the squeezed shape: any axis that may be a singleton
        # can be missing from the sample image, so the acceptable number of
        # dimensions is a range [n_dims_min, n_dims_max]
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        if isinstance(self.data, collections.abc.Sequence):
            # per-channel entries agree in `type`
            # (enforced by `_check_data_type_across_channels`)
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """Per-channel data descriptions must agree in their data `type`."""
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """If `data` is given per channel, require one entry per channel."""
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            # no channel axis present: nothing to check
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map each axis id to the corresponding dimension size of `array`.

        Raises:
            ValueError: if `array` does not have exactly one dimension per axis.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1610
1611
class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Check every `kwargs.axes` entry against this tensor's axes and pad
        `preprocessing` with `ensure_dtype` steps (see notes on `preprocessing`).

        Note: mutates `self.preprocessing` in place by inserting/appending
        `EnsureDtypeDescr` entries.
        """
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions share the same `type`
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self
1669
1670
def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Convert a v0.4 axes letter string into a list of v0.5 axis descriptions.

    Args:
        axes: one letter per axis; letters are mapped to axis types via
            `_AXIS_TYPE_MAP`.
        shape: explicit per-axis sizes, a v0.4 parameterized input shape, or a
            v0.4 implicit output shape defined relative to a reference tensor.
        tensor_type: whether to create input or output axis objects.
        halo: per-axis halo values (only used for output tensors), or `None`.
        size_refs: known axis sizes of other tensors by tensor name; used to
            fold the v0.4 scale of channel/index axes into a v0.5 offset.

    Returns:
        list of v0.5 axis descriptions, one per letter in `axes`.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            # batch axes carry no size information
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            if shape.step[i] == 0:
                # step 0 means this axis is fixed at its minimum size
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            # the v0.4 reference may be "tensor.axis" or just a tensor name
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            # NOTE(review): fallback is `a` (current letter), not `orig_a_id` —
            # confirm this is intended when the reference names another axis
            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            # explicit shape: plain integer size
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    # index output axes express variability as data-dependent size
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                # best effort: use the reference's offset as the channel count
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret
1783
1784
def _axes_letters_to_ids(
    axes: Optional[str],
) -> Optional[List[AxisId]]:
    """Turn a v0.4 axes letter string (e.g. "bczyx") into a list of `AxisId`s.

    `None` is passed through unchanged.
    """
    if axes is None:
        return None

    return list(map(AxisId, axes))
1792
1793
def _get_complement_v04_axis(
    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
) -> Optional[AxisId]:
    """Find the single tensor axis not covered by `axes` (batch "b" is ignored).

    Returns `None` if `axes` is `None` or no axis remains.

    Raises:
        ValueError: if more than one complement axis remains.
    """
    if axes is None:
        return None

    excluded = {*axes, "b"}
    complement_axes = [axis for axis in tensor_axes if axis not in excluded]

    if len(complement_axes) > 1:
        raise ValueError(
            f"Expected none or a single complement axis, but axes '{axes}' "
            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
        )

    return AxisId(complement_axes[0]) if complement_axes else None
1809
1810
def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a v0.4 pre-/postprocessing descriptor to its v0.5 counterpart.

    Args:
        p: processing descriptor following the 0.4 specification.
        tensor_axes: axes letters of the tensor `p` applies to; used to resolve
            the single axis some v0.4 kwargs implicitly operate along.

    Raises:
        ValueError: if `p.kwargs.axes` leaves more than one complement axis
            where a single axis is required.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        # `_get_complement_v04_axis` already returns None for axes=None,
        # so no separate None-guard is needed (removed an unused
        # `_axes_letters_to_ids` result here as well).
        axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
        if axis is None:
            # scalar gain/offset applied to the whole tensor
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs(
                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
                    )
                )
            else:
                # per-axis mean/std: normalize scalars to single-element lists
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # per_dataset statistics additionally aggregate over the batch
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)
1893
1894
class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        ImportantFileSource,
        Optional[ImportantFileSource],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: ImportantFileSource,
        sample_tensor: Optional[ImportantFileSource],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        """Build a v0.5 input tensor description from a v0.4 one."""
        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )

        preprocessing: List[PreprocessingDescr] = []
        for proc in src.preprocessing:
            converted = _convert_proc(proc, src.axes)
            # scale_mean_variance is not a valid *pre*processing step in v0.5
            assert not isinstance(converted, ScaleMeanVarianceDescr)
            preprocessing.append(converted)

        # append an explicit float32 cast (added during the 0.4 -> 0.5 conversion)
        preprocessing.append(
            EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
        )

        if sample_tensor is None:
            sample = None
        else:
            sample = FileDescr(source=sample_tensor)

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample,
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=preprocessing,
        )
1937
1938
# module-level converter instance for v0.4 -> v0.5 input tensor conversion
_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1940
1941
class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
          If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        """Check postprocessing `axes` kwargs and append a final dtype cast if missing."""
        # any axes referenced in postprocessing kwargs must exist on this tensor
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        # determine this tensor's dtype (first entry if data is a sequence of descrs)
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self
1983
1984
class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        ImportantFileSource,
        Optional[ImportantFileSource],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: ImportantFileSource,
        sample_tensor: Optional[ImportantFileSource],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        """Build a v0.5 output tensor description from a v0.4 one."""
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="output",
            halo=src.halo,
            size_refs=size_refs,
        )

        data_descr: Dict[str, Any] = {"type": src.data_type}
        if src.data_type == "bool":
            data_descr["values"] = [False, True]

        if sample_tensor is None:
            sample = None
        else:
            sample = FileDescr(source=sample_tensor)

        postprocessing = [
            _convert_proc(proc, src.axes) for proc in src.postprocessing
        ]

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample,
            data=data_descr,  # pyright: ignore[reportArgumentType]
            postprocessing=postprocessing,
        )
2024
2025
# module-level converter instance for v0.4 -> v0.5 output tensor conversion
_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2027
2028
# convenience alias: any tensor description (input or output)
TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2030
2031
def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
    tensor_origin: Literal[
        "test_tensor"
    ],  # for more precise error messages, e.g. 'test_tensor'
):
    """Validate given arrays against their tensor descriptions.

    Checks, per tensor:
    - the array dtype is compatible with the described dtype,
    - the values span enough of a range to allow meaningful comparison,
    - every axis size is valid, including parameterized, data-dependent and
      referenced (`SizeReference`) sizes.

    Args:
        tensors: maps tensor id to (description, array) pairs.
        tensor_origin: name of the array source; only used in error messages.

    Raises:
        ValueError: if any array does not match its description.
    """
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}

    def e_msg(d: TensorDescr):
        # error message prefix identifying the offending tensor
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    # first pass: resolve all axis sizes (needed for `SizeReference` checks below)
    for descr, array in tensors.values():
        try:
            axis_sizes = descr.get_axis_sizes_for_array(array)
        except ValueError as e:
            raise ValueError(f"{e_msg(descr)} {e}") from e
        else:
            all_tensor_axes[descr.id] = {
                a.id: (a, axis_sizes[a.id]) for a in descr.axes
            }

    for descr, array in tensors.values():
        if descr.dtype in ("float32", "float64"):
            # float tensors may be provided as any exact (u)int/float format
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
                + f" match described dtype '{descr.dtype}'"
            )

        # all values within (-1e-4, 1e-4) are too close to zero for a meaningful
        # closeness comparison of reproduced tensors
        # (fixed error message: it previously claimed thresholds of 1e5)
        if array.min() > -1e-4 and array.max() < 1e-4:
            raise ValueError(
                f"{e_msg(descr)}.{tensor_origin}: values are too small for reliable"
                + " testing. Values <=-1e-4 or >=1e-4 must be present."
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]
            if a.size is None:
                continue

            if isinstance(a.size, int):
                if actual_size != a.size:
                    raise ValueError(
                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
                        + f"has incompatible size {actual_size}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                # both size types validate the concrete size themselves
                _ = a.size.validate_size(actual_size)
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}'"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                if actual_size != (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ):
                    raise ValueError(
                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {actual_size} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)
2131
2132
class EnvironmentFileDescr(FileDescr):
    # restricts `FileDescr.source` to conda environment yaml/yml files
    source: Annotated[
        ImportantFileSource,
        WithSuffix((".yaml", ".yml"), case_sensitive=True),
        Field(
            examples=["environment.yaml"],
        ),
    ]
    """∈📦 Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
    """
2146
2147
class _ArchitectureCallableDescr(Node):
    # shared base for architecture descriptions: a named callable plus its kwargs
    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(default_factory=dict)
    """key word arguments for the `callable`"""
2154
2155
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    # architecture callable defined in a source file (combines callable/kwargs
    # from `_ArchitectureCallableDescr` with source/sha256 from `FileDescr`)
    pass
2158
2159
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    # architecture callable importable from an installed library/module
    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2163
2164
# union resolved left to right: try file-based description first,
# then library-based (see pydantic `union_mode="left_to_right"`)
ArchitectureDescr = Annotated[
    Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr],
    Field(union_mode="left_to_right"),
]
2169
2170
class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        """Split a v0.4 "<source>:<callable>" string into a v0.5 description."""
        is_url = src.startswith("http")
        n_colons = src.count(":")
        if is_url and n_colons == 2:
            # e.g. "https://host/file.py:MyNet" -> rejoin the url scheme
            scheme, rest, callable_ = src.split(":")
            source = f"{scheme}:{rest}"
        elif not is_url and n_colons == 1:
            source, callable_ = src.split(":")
        else:
            # no recognizable "<source>:<callable>" pattern; use src for both
            source = str(src)
            callable_ = str(src)
        return tgt(
            callable=Identifier(callable_),
            source=cast(ImportantFileSource, source),
            sha256=sha256,
            kwargs=kwargs,
        )
2200
2201
# module-level converter instance for v0.4 file-based architecture callables
_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2203
2204
class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        """Split a dotted path "pkg.mod.callable" into import path and callable name."""
        parts = src.split(".")
        module_path = ".".join(parts[:-1])
        callable_name = parts[-1]
        return tgt(
            import_from=module_path,
            callable=Identifier(callable_name),
            kwargs=kwargs,
        )
2221
2222
# module-level converter instance for v0.4 dependency-based architecture callables
_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)
2226
2227
class WeightsEntryDescrBase(FileDescr):
    # set by each concrete subclass, e.g. "onnx", "torchscript"
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: ImportantFileSource
    """∈📦 The weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
        (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
        (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        """Reject a weights entry that names its own format as conversion source."""
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self
2261
2262
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # weights stored in Keras HDF5 format
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""
2268
2269
class OnnxWeightsDescr(WeightsEntryDescrBase):
    # weights stored as an ONNX model
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""
2275
2276
class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    # PyTorch state dict; requires an architecture description to rebuild the module
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: ArchitectureDescr
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[EnvironmentFileDescr] = None
    """Custom dependencies beyond pytorch.
    The conda environment file should include pytorch and any version pinning has to be compatible with
    `pytorch_version`.
    """
2290
2291
class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    # weights for TensorFlow.js; distributed as a zipped multi-file bundle
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    source: ImportantFileSource
    """∈📦 The multi-file weights.
    All required files/folders should be a zip archive."""
2301
2302
class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    # TensorFlow SavedModel bundle; distributed as a zipped multi-file bundle
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[EnvironmentFileDescr] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""

    source: ImportantFileSource
    """∈📦 The multi-file weights.
    All required files/folders should be a zip archive."""
2316
2317
class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # TorchScript (scripted/traced) weights file
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""
2323
2324
class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        """Validate that entries exist and their `parent` references are consistent."""
        present = {fmt for fmt, descr in self if descr is not None}

        if not present:
            raise ValueError("Missing weights entry")

        # entries without a `parent` are considered the originally trained weights
        originals = {
            fmt
            for fmt, descr in self
            if descr is not None and hasattr(descr, "parent") and descr.parent is None
        }
        if len(originals) != 1:
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the orignal or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(originals),
                field="weights",
            )

        for fmt, descr in self:
            if descr is None:
                continue

            assert hasattr(descr, "type")
            assert hasattr(descr, "parent")
            assert fmt == descr.type
            # self reference is already rejected by the `parent` field validator
            if descr.parent is not None and descr.parent not in present:
                raise ValueError(
                    f"`weights.{fmt}.parent={descr.parent} not in specified weight"
                    + f" formats: {present}"
                )

        return self

    def __getitem__(
        self,
        key: Literal[
            "keras_hdf5",
            "onnx",
            "pytorch_state_dict",
            "tensorflow_js",
            "tensorflow_saved_model_bundle",
            "torchscript",
        ],
    ):
        """Return the weights entry for `key`; raise `KeyError` if absent/unknown."""
        if key not in (
            "keras_hdf5",
            "onnx",
            "pytorch_state_dict",
            "tensorflow_js",
            "tensorflow_saved_model_bundle",
            "torchscript",
        ):
            raise KeyError(key)

        entry = getattr(self, key)
        if entry is None:
            raise KeyError(key)

        return entry

    @property
    def available_formats(self):
        """Mapping of weights format name to the entries present on this model."""
        return {
            fmt: entry
            for fmt in (
                "keras_hdf5",
                "onnx",
                "pytorch_state_dict",
                "tensorflow_js",
                "tensorflow_saved_model_bundle",
                "torchscript",
            )
            if (entry := getattr(self, fmt)) is not None
        }

    @property
    def missing_formats(self):
        """Weights formats defined by the spec, but absent from this model."""
        return set(get_args(WeightsFormat)) - set(self.available_formats)
2436
2437
class ModelId(ResourceId):
    # dedicated id type for model resources (no additional constraints)
    pass
2440
2441
class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    # uniqueness/validity of the id is enforced by the `ModelId` type
    id: ModelId
    """A valid model `id` from the bioimage.io collection."""
2447
2448
class _DataDepSize(NamedTuple):
    # min/max bounds of a data-dependent axis size;
    # max=None presumably means unbounded — TODO confirm against DataDependentSize
    min: int
    max: Optional[int]
2452
2453
class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    inputs: Dict[Tuple[TensorId, AxisId], int]
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2459
2460
class _TensorSizes(NamedTuple):
    """Same information as `_AxisSizes`, but as nested dicts
    (tensor id -> axis id -> size)."""

    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2466
2467
class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-4
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 0
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""
2497
2498
class BioimageioConfig(Node, extra="allow"):
    # extra="allow": additional, unspecified config entries are preserved
    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """
2505
2506
class Config(Node, extra="allow"):
    # bioimage.io-specific settings live under the `bioimageio` key;
    # extra="allow" lets other tools add their own top-level entries
    bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2509
2510
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
    if TYPE_CHECKING:
        # static checkers may assume the default; at runtime the field is
        # required so the format version is stated explicitly in the RDF
        format_version: Literal["0.5.4"] = "0.5.4"
    else:
        format_version: Literal["0.5.4"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[List[Author]]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        DocumentationSource,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """∈📦 URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""
2553
2554    @field_validator("documentation", mode="after")
2555    @classmethod
2556    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2557        if not get_validation_context().perform_io_checks:
2558            return value
2559
2560        doc_path = download(value).path
2561        doc_content = doc_path.read_text(encoding="utf-8")
2562        assert isinstance(doc_content, str)
2563        if not re.search("#.*[vV]alidation", doc_content):
2564            issue_warning(
2565                "No '# Validation' (sub)section found in {value}.",
2566                value=value,
2567                field="documentation",
2568            )
2569
2570        return value
2571
    # at least one input tensor is required (`NotEmpty`)
    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""
2574
2575    @field_validator("inputs", mode="after")
2576    @classmethod
2577    def _validate_input_axes(
2578        cls, inputs: Sequence[InputTensorDescr]
2579    ) -> Sequence[InputTensorDescr]:
2580        input_size_refs = cls._get_axes_with_independent_size(inputs)
2581
2582        for i, ipt in enumerate(inputs):
2583            valid_independent_refs: Dict[
2584                Tuple[TensorId, AxisId],
2585                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2586            ] = {
2587                **{
2588                    (ipt.id, a.id): (ipt, a, a.size)
2589                    for a in ipt.axes
2590                    if not isinstance(a, BatchAxis)
2591                    and isinstance(a.size, (int, ParameterizedSize))
2592                },
2593                **input_size_refs,
2594            }
2595            for a, ax in enumerate(ipt.axes):
2596                cls._validate_axis(
2597                    "inputs",
2598                    i=i,
2599                    tensor_id=ipt.id,
2600                    a=a,
2601                    axis=ax,
2602                    valid_independent_refs=valid_independent_refs,
2603                )
2604        return inputs
2605
    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        """Validate a single axis of tensor `field_name[i]`.

        Args:
            field_name: "inputs" or "outputs"; used in error messages only.
            i: index of the tensor within `field_name`; error messages only.
            tensor_id: id of the tensor the axis belongs to.
            a: index of the axis within the tensor's axes; error messages only.
            axis: the axis description to validate.
            valid_independent_refs: maps (tensor id, axis id) to
                (tensor, axis, size) for every axis a `SizeReference` may
                legally point to.

        Note: may mutate **axis** (expands a '{i}' `channel_names` template
        into an explicit list of channel names).
        """
        # batch axes and axes with an independent size need no further checks
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            # a channel axis may only reference another channel axis
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                # expand the '{i}' name template using the referenced fixed size
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        # the axis and its reference axis must agree on the unit (or both lack one)
        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # the minimal axis size (n=0) must still accommodate the halo on both sides
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # halo scaled to the reference axis must be an even integer
            input_halo = axis.halo * axis.scale / ref_axis.scale
            if input_halo != int(input_halo) or input_halo % 2 == 1:
                raise ValueError(
                    f"input_halo {input_halo} (output_halo {axis.halo} *"
                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
                    + f"     {tensor_id}.{axis.id}."
                )
2684
    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        """Load all test tensors and check them against their tensor descriptions.

        Additionally ensures that a configured absolute reproducibility
        tolerance does not exceed 1% of the maximum value of the corresponding
        output test tensor.
        Skipped entirely when the validation context disables IO checks.
        """
        if not get_validation_context().perform_io_checks:
            return self

        # download and load every test tensor (outputs first, inputs second)
        test_output_arrays = [
            load_array(descr.test_tensor.download().path) for descr in self.outputs
        ]
        test_input_arrays = [
            load_array(descr.test_tensor.download().path) for descr in self.inputs
        ]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            # restrict the check to the configured output ids, if any are given
            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                # the absolute tolerance must stay below 1% of the tensor's maximum
                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self
2730
2731    @model_validator(mode="after")
2732    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2733        ipt_refs = {t.id for t in self.inputs}
2734        out_refs = {t.id for t in self.outputs}
2735        for ipt in self.inputs:
2736            for p in ipt.preprocessing:
2737                ref = p.kwargs.get("reference_tensor")
2738                if ref is None:
2739                    continue
2740                if ref not in ipt_refs:
2741                    raise ValueError(
2742                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2743                        + f" references are: {ipt_refs}."
2744                    )
2745
2746        for out in self.outputs:
2747            for p in out.postprocessing:
2748                ref = p.kwargs.get("reference_tensor")
2749                if ref is None:
2750                    continue
2751
2752                if ref not in ipt_refs and ref not in out_refs:
2753                    raise ValueError(
2754                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2755                        + f" are: {ipt_refs | out_refs}."
2756                    )
2757
2758        return self
2759
    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        Annotated[
            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
        ],
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
    We recommend to choose a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""
2779
2780    @field_validator("outputs", mode="after")
2781    @classmethod
2782    def _validate_tensor_ids(
2783        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2784    ) -> Sequence[OutputTensorDescr]:
2785        tensor_ids = [
2786            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2787        ]
2788        duplicate_tensor_ids: List[str] = []
2789        seen: Set[str] = set()
2790        for t in tensor_ids:
2791            if t in seen:
2792                duplicate_tensor_ids.append(t)
2793
2794            seen.add(t)
2795
2796        if duplicate_tensor_ids:
2797            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2798
2799        return outputs
2800
2801    @staticmethod
2802    def _get_axes_with_parameterized_size(
2803        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2804    ):
2805        return {
2806            f"{t.id}.{a.id}": (t, a, a.size)
2807            for t in io
2808            for a in t.axes
2809            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2810        }
2811
2812    @staticmethod
2813    def _get_axes_with_independent_size(
2814        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2815    ):
2816        return {
2817            (t.id, a.id): (t, a, a.size)
2818            for t in io
2819            for a in t.axes
2820            if not isinstance(a, BatchAxis)
2821            and isinstance(a.size, (int, ParameterizedSize))
2822        }
2823
2824    @field_validator("outputs", mode="after")
2825    @classmethod
2826    def _validate_output_axes(
2827        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2828    ) -> List[OutputTensorDescr]:
2829        input_size_refs = cls._get_axes_with_independent_size(
2830            info.data.get("inputs", [])
2831        )
2832        output_size_refs = cls._get_axes_with_independent_size(outputs)
2833
2834        for i, out in enumerate(outputs):
2835            valid_independent_refs: Dict[
2836                Tuple[TensorId, AxisId],
2837                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2838            ] = {
2839                **{
2840                    (out.id, a.id): (out, a, a.size)
2841                    for a in out.axes
2842                    if not isinstance(a, BatchAxis)
2843                    and isinstance(a.size, (int, ParameterizedSize))
2844                },
2845                **input_size_refs,
2846                **output_size_refs,
2847            }
2848            for a, ax in enumerate(out.axes):
2849                cls._validate_axis(
2850                    "outputs",
2851                    i,
2852                    out.id,
2853                    a,
2854                    ax,
2855                    valid_independent_refs=valid_independent_refs,
2856                )
2857
2858        return outputs
2859
    packaged_by: List[Author] = Field(default_factory=list)
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights.
    May not reference this model description itself."""
2866
2867    @model_validator(mode="after")
2868    def _validate_parent_is_not_self(self) -> Self:
2869        if self.parent is not None and self.parent.id == self.id:
2870            raise ValueError("A model description may not reference itself as parent.")
2871
2872        return self
2873
    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    # additional configuration, e.g. `config.bioimageio.reproducibility_tolerance`
    config: Config = Field(default_factory=Config)
2899
    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        """Generate cover image(s) from the test tensors if none are given.

        Best-effort: a failure to generate covers only issues a warning.
        Skipped when IO checks are disabled or covers are already present.
        """
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self
2921
2922    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2923        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2924        assert all(isinstance(d, np.ndarray) for d in data)
2925        return data
2926
2927    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2928        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2929        assert all(isinstance(d, np.ndarray) for d in data)
2930        return data
2931
2932    @staticmethod
2933    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2934        batch_size = 1
2935        tensor_with_batchsize: Optional[TensorId] = None
2936        for tid in tensor_sizes:
2937            for aid, s in tensor_sizes[tid].items():
2938                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2939                    continue
2940
2941                if batch_size != 1:
2942                    assert tensor_with_batchsize is not None
2943                    raise ValueError(
2944                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2945                    )
2946
2947                batch_size = s
2948                tensor_with_batchsize = tid
2949
2950        return batch_size
2951
2952    def get_output_tensor_sizes(
2953        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2954    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2955        """Returns the tensor output sizes for given **input_sizes**.
2956        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2957        Otherwise it might be larger than the actual (valid) output"""
2958        batch_size = self.get_batch_size(input_sizes)
2959        ns = self.get_ns(input_sizes)
2960
2961        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2962        return tensor_sizes.outputs
2963
2964    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2965        """get parameter `n` for each parameterized axis
2966        such that the valid input size is >= the given input size"""
2967        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2968        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2969        for tid in input_sizes:
2970            for aid, s in input_sizes[tid].items():
2971                size_descr = axes[tid][aid].size
2972                if isinstance(size_descr, ParameterizedSize):
2973                    ret[(tid, aid)] = size_descr.get_n(s)
2974                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2975                    pass
2976                else:
2977                    assert_never(size_descr)
2978
2979        return ret
2980
2981    def get_tensor_sizes(
2982        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2983    ) -> _TensorSizes:
2984        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2985        return _TensorSizes(
2986            {
2987                t: {
2988                    aa: axis_sizes.inputs[(tt, aa)]
2989                    for tt, aa in axis_sizes.inputs
2990                    if tt == t
2991                }
2992                for t in {tt for tt, _ in axis_sizes.inputs}
2993            },
2994            {
2995                t: {
2996                    aa: axis_sizes.outputs[(tt, aa)]
2997                    for tt, aa in axis_sizes.outputs
2998                    if tt == t
2999                }
3000                for t in {tt for tt, _ in axis_sizes.outputs}
3001            },
3002        )
3003
    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            # fall back to a batch size found in **max_input_shape**, default 1
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            # resolve a single axis size; reads `t_descr` from the enclosing loop
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                # clip n such that the resulting size stays within max_input_shape
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                # the referenced axis must have been resolved before this call
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all input sizes, except the `SizeReference` ones
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)
3133
    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Pre-validation hook: convert raw data from older format versions in place."""
        cls.convert_from_old_format_wo_validation(data)
        return data
3139
    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.

        Mutates **data** in place; leaves it untouched if no conversion applies
        (or the 0.4 -> 0.5 conversion fails).
        """
        # only convert model descriptions with a plain numeric `major.minor.patch`
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                # load as model 0.4 and convert the result to 0.5
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        # conversion of an invalid description may fail
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
3181
3182
3183class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3184    def _convert(
3185        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3186    ) -> "ModelDescr | dict[str, Any]":
3187        name = "".join(
3188            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3189            for c in src.name
3190        )
3191
3192        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3193            conv = (
3194                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3195            )
3196            return None if auths is None else [conv(a) for a in auths]
3197
3198        if TYPE_CHECKING:
3199            arch_file_conv = _arch_file_conv.convert
3200            arch_lib_conv = _arch_lib_conv.convert
3201        else:
3202            arch_file_conv = _arch_file_conv.convert_as_dict
3203            arch_lib_conv = _arch_lib_conv.convert_as_dict
3204
3205        input_size_refs = {
3206            ipt.name: {
3207                a: s
3208                for a, s in zip(
3209                    ipt.axes,
3210                    (
3211                        ipt.shape.min
3212                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3213                        else ipt.shape
3214                    ),
3215                )
3216            }
3217            for ipt in src.inputs
3218            if ipt.shape
3219        }
3220        output_size_refs = {
3221            **{
3222                out.name: {a: s for a, s in zip(out.axes, out.shape)}
3223                for out in src.outputs
3224                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3225            },
3226            **input_size_refs,
3227        }
3228
3229        return tgt(
3230            attachments=(
3231                []
3232                if src.attachments is None
3233                else [FileDescr(source=f) for f in src.attachments.files]
3234            ),
3235            authors=[
3236                _author_conv.convert_as_dict(a) for a in src.authors
3237            ],  # pyright: ignore[reportArgumentType]
3238            cite=[
3239                {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3240            ],  # pyright: ignore[reportArgumentType]
3241            config=src.config,  # pyright: ignore[reportArgumentType]
3242            covers=src.covers,
3243            description=src.description,
3244            documentation=src.documentation,
3245            format_version="0.5.4",
3246            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3247            icon=src.icon,
3248            id=None if src.id is None else ModelId(src.id),
3249            id_emoji=src.id_emoji,
3250            license=src.license,  # type: ignore
3251            links=src.links,
3252            maintainers=[
3253                _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3254            ],  # pyright: ignore[reportArgumentType]
3255            name=name,
3256            tags=src.tags,
3257            type=src.type,
3258            uploader=src.uploader,
3259            version=src.version,
3260            inputs=[  # pyright: ignore[reportArgumentType]
3261                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3262                for ipt, tt, st, in zip(
3263                    src.inputs,
3264                    src.test_inputs,
3265                    src.sample_inputs or [None] * len(src.test_inputs),
3266                )
3267            ],
3268            outputs=[  # pyright: ignore[reportArgumentType]
3269                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3270                for out, tt, st, in zip(
3271                    src.outputs,
3272                    src.test_outputs,
3273                    src.sample_outputs or [None] * len(src.test_outputs),
3274                )
3275            ],
3276            parent=(
3277                None
3278                if src.parent is None
3279                else LinkedModel(
3280                    id=ModelId(
3281                        str(src.parent.id)
3282                        + (
3283                            ""
3284                            if src.parent.version_number is None
3285                            else f"/{src.parent.version_number}"
3286                        )
3287                    )
3288                )
3289            ),
3290            training_data=(
3291                None
3292                if src.training_data is None
3293                else (
3294                    LinkedDataset(
3295                        id=DatasetId(
3296                            str(src.training_data.id)
3297                            + (
3298                                ""
3299                                if src.training_data.version_number is None
3300                                else f"/{src.training_data.version_number}"
3301                            )
3302                        )
3303                    )
3304                    if isinstance(src.training_data, LinkedDataset02)
3305                    else src.training_data
3306                )
3307            ),
3308            packaged_by=[
3309                _author_conv.convert_as_dict(a) for a in src.packaged_by
3310            ],  # pyright: ignore[reportArgumentType]
3311            run_mode=src.run_mode,
3312            timestamp=src.timestamp,
3313            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3314                keras_hdf5=(w := src.weights.keras_hdf5)
3315                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3316                    authors=conv_authors(w.authors),
3317                    source=w.source,
3318                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3319                    parent=w.parent,
3320                ),
3321                onnx=(w := src.weights.onnx)
3322                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3323                    source=w.source,
3324                    authors=conv_authors(w.authors),
3325                    parent=w.parent,
3326                    opset_version=w.opset_version or 15,
3327                ),
3328                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3329                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3330                    source=w.source,
3331                    authors=conv_authors(w.authors),
3332                    parent=w.parent,
3333                    architecture=(
3334                        arch_file_conv(
3335                            w.architecture,
3336                            w.architecture_sha256,
3337                            w.kwargs,
3338                        )
3339                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3340                        else arch_lib_conv(w.architecture, w.kwargs)
3341                    ),
3342                    pytorch_version=w.pytorch_version or Version("1.10"),
3343                    dependencies=(
3344                        None
3345                        if w.dependencies is None
3346                        else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3347                            source=cast(
3348                                ImportantFileSource,
3349                                str(deps := w.dependencies)[
3350                                    (
3351                                        len("conda:")
3352                                        if str(deps).startswith("conda:")
3353                                        else 0
3354                                    ) :
3355                                ],
3356                            )
3357                        )
3358                    ),
3359                ),
3360                tensorflow_js=(w := src.weights.tensorflow_js)
3361                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3362                    source=w.source,
3363                    authors=conv_authors(w.authors),
3364                    parent=w.parent,
3365                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3366                ),
3367                tensorflow_saved_model_bundle=(
3368                    w := src.weights.tensorflow_saved_model_bundle
3369                )
3370                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3371                    authors=conv_authors(w.authors),
3372                    parent=w.parent,
3373                    source=w.source,
3374                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3375                    dependencies=(
3376                        None
3377                        if w.dependencies is None
3378                        else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3379                            source=cast(
3380                                ImportantFileSource,
3381                                (
3382                                    str(w.dependencies)[len("conda:") :]
3383                                    if str(w.dependencies).startswith("conda:")
3384                                    else str(w.dependencies)
3385                                ),
3386                            )
3387                        )
3388                    ),
3389                ),
3390                torchscript=(w := src.weights.torchscript)
3391                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3392                    source=w.source,
3393                    authors=conv_authors(w.authors),
3394                    parent=w.parent,
3395                    pytorch_version=w.pytorch_version or Version("1.10"),
3396                ),
3397            ),
3398        )
3399
3400
# singleton converter used to translate model descriptions from format v0.4 to v0.5
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3402
3403
3404# create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Create cover image file(s) for a model from its test tensors.

    The first input and the first output tensor are each reduced to a 2D
    uint8 RGB image. If both images end up with the same shape, a single,
    diagonally split cover is written; otherwise separate input and output
    images are written.

    Args:
        inputs: pairs of (input tensor description, test tensor data).
        outputs: pairs of (output tensor description, test tensor data).

    Returns:
        Paths to the created PNG file(s) in a newly created temporary folder.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        # drop every singleton axis together with its description
        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        """min-max normalize `data` to [0, 1) along `axis` (globally if `axis` is None)"""
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
        """reduce `data` to a (h, w, 3) uint8 RGB image by slicing/converting its axes"""
        original_shape = data.shape
        data, axes = squeeze(data, axes)

        # take slice from any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        # target rank: 3 (h, w, c) if there is a channel axis, else 2 (h, w)
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                # take a central slice to drop this axis
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis: keep only its first channel
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

            slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        # fix: with a channel axis a reduced rank of 3 is valid; the previous
        # condition `(has_c_axis and ndim != 3) or ndim != 2` checked
        # `ndim != 2` unconditionally and therefore raised for every tensor
        # that has a channel axis (where ndim == 3 is the expected result)
        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape}"
            )

        if not has_c_axis:
            assert ndim == 2
            # grayscale -> RGB by repeating the single channel three times
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        axis_order.append(axis_order.pop(c))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
        """combine two equally shaped RGB images: lower-left triangle from `im0`,
        upper-right triangle from `im1`"""
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            # zeroed upper triangle of im0 is filled from im1
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible with the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible with the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
class TensorId(bioimageio.spec._internal.types.LowerCaseIdentifier):
213class TensorId(LowerCaseIdentifier):
214    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
215        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
216    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

class AxisId(bioimageio.spec._internal.types.LowerCaseIdentifier):
229class AxisId(LowerCaseIdentifier):
230    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
231        Annotated[
232            LowerCaseIdentifierAnno,
233            MaxLen(16),
234            AfterValidator(_normalize_axis_id),
235        ]
236    ]

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

the pydantic root model to validate the string

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'zero_mean_unit_variance']
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'scale_linear', 'sigmoid', 'zero_mean_unit_variance', 'scale_range']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>

Annotates an integer to calculate a concrete axis size from a ParameterizedSize.

class ParameterizedSize(bioimageio.spec._internal.node.Node):
280class ParameterizedSize(Node):
281    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
282
283    - **min** and **step** are given by the model description.
284    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
285    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
286      This allows to adjust the axis size more generically.
287    """
288
289    N: ClassVar[Type[int]] = ParameterizedSize_N
290    """Positive integer to parameterize this axis"""
291
292    min: Annotated[int, Gt(0)]
293    step: Annotated[int, Gt(0)]
294
295    def validate_size(self, size: int) -> int:
296        if size < self.min:
297            raise ValueError(f"size {size} < {self.min}")
298        if (size - self.min) % self.step != 0:
299            raise ValueError(
300                f"axis of size {size} is not parameterized by `min + n*step` ="
301                + f" `{self.min} + n*{self.step}`"
302            )
303
304        return size
305
306    def get_size(self, n: ParameterizedSize_N) -> int:
307        return self.min + self.step * n
308
309    def get_n(self, s: int) -> ParameterizedSize_N:
310        """return smallest n parameterizing a size greater than or equal to `s`"""
311        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

  • min and step are given by the model description.
  • All blocksize parameters n = 0,1,2,... yield a valid size.
  • A greater blocksize parameter n = 0,1,2,... results in a greater size. This allows adjusting the axis size more generically.
N: ClassVar[Type[int]] = <class 'int'>

Positive integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
295    def validate_size(self, size: int) -> int:
296        if size < self.min:
297            raise ValueError(f"size {size} < {self.min}")
298        if (size - self.min) % self.step != 0:
299            raise ValueError(
300                f"axis of size {size} is not parameterized by `min + n*step` ="
301                + f" `{self.min} + n*{self.step}`"
302            )
303
304        return size
def get_size(self, n: int) -> int:
306    def get_size(self, n: ParameterizedSize_N) -> int:
307        return self.min + self.step * n
def get_n(self, s: int) -> int:
309    def get_n(self, s: int) -> ParameterizedSize_N:
310        """return smallest n parameterizing a size greater than or equal to `s`"""
311        return ceil((s - self.min) / self.step)

return smallest n parameterizing a size greater than or equal to s

class DataDependentSize(bioimageio.spec._internal.node.Node):
314class DataDependentSize(Node):
315    min: Annotated[int, Gt(0)] = 1
316    max: Annotated[Optional[int], Gt(1)] = None
317
318    @model_validator(mode="after")
319    def _validate_max_gt_min(self):
320        if self.max is not None and self.min >= self.max:
321            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
322
323        return self
324
325    def validate_size(self, size: int) -> int:
326        if size < self.min:
327            raise ValueError(f"size {size} < {self.min}")
328
329        if self.max is not None and size > self.max:
330            raise ValueError(f"size {size} > {self.max}")
331
332        return size
min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
325    def validate_size(self, size: int) -> int:
326        if size < self.min:
327            raise ValueError(f"size {size} < {self.min}")
328
329        if self.max is not None and size > self.max:
330            raise ValueError(f"size {size} > {self.max}")
331
332        return size
class SizeReference(bioimageio.spec._internal.node.Node):
335class SizeReference(Node):
336    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
337
338    `axis.size = reference.size * reference.scale / axis.scale + offset`
339
340    Note:
341    1. The axis and the referenced axis need to have the same unit (or no unit).
342    2. Batch axes may not be referenced.
343    3. Fractions are rounded down.
344    4. If the reference axis is `concatenable` the referencing axis is assumed to be
345        `concatenable` as well with the same block order.
346
347    Example:
348    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
349    Let's assume that we want to express the image height h in relation to its width w
350    instead of only accepting input images of exactly 100*49 pixels
351    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
352
353    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
354    >>> h = SpaceInputAxis(
355    ...     id=AxisId("h"),
356    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
357    ...     unit="millimeter",
358    ...     scale=4,
359    ... )
360    >>> print(h.size.get_size(h, w))
361    49
362
363    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
364    """
365
366    tensor_id: TensorId
367    """tensor id of the reference axis"""
368
369    axis_id: AxisId
370    """axis id of the reference axis"""
371
372    offset: int = 0
373
374    def get_size(
375        self,
376        axis: Union[
377            ChannelAxis,
378            IndexInputAxis,
379            IndexOutputAxis,
380            TimeInputAxis,
381            SpaceInputAxis,
382            TimeOutputAxis,
383            TimeOutputAxisWithHalo,
384            SpaceOutputAxis,
385            SpaceOutputAxisWithHalo,
386        ],
387        ref_axis: Union[
388            ChannelAxis,
389            IndexInputAxis,
390            IndexOutputAxis,
391            TimeInputAxis,
392            SpaceInputAxis,
393            TimeOutputAxis,
394            TimeOutputAxisWithHalo,
395            SpaceOutputAxis,
396            SpaceOutputAxisWithHalo,
397        ],
398        n: ParameterizedSize_N = 0,
399        ref_size: Optional[int] = None,
400    ):
401        """Compute the concrete size for a given axis and its reference axis.
402
403        Args:
404            axis: The axis this `SizeReference` is the size of.
405            ref_axis: The reference axis to compute the size from.
406            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
407                and no fixed **ref_size** is given,
408                **n** is used to compute the size of the parameterized **ref_axis**.
409            ref_size: Overwrite the reference size instead of deriving it from
410                **ref_axis**
411                (**ref_axis.scale** is still used; any given **n** is ignored).
412        """
413        assert (
414            axis.size == self
415        ), "Given `axis.size` is not defined by this `SizeReference`"
416
417        assert (
418            ref_axis.id == self.axis_id
419        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
420
421        assert axis.unit == ref_axis.unit, (
422            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
423            f" but {axis.unit}!={ref_axis.unit}"
424        )
425        if ref_size is None:
426            if isinstance(ref_axis.size, (int, float)):
427                ref_size = ref_axis.size
428            elif isinstance(ref_axis.size, ParameterizedSize):
429                ref_size = ref_axis.size.get_size(n)
430            elif isinstance(ref_axis.size, DataDependentSize):
431                raise ValueError(
432                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
433                )
434            elif isinstance(ref_axis.size, SizeReference):
435                raise ValueError(
436                    "Reference axis referenced in `SizeReference` may not be sized by a"
437                    + " `SizeReference` itself."
438                )
439            else:
440                assert_never(ref_axis.size)
441
442        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
443
444    @staticmethod
445    def _get_unit(
446        axis: Union[
447            ChannelAxis,
448            IndexInputAxis,
449            IndexOutputAxis,
450            TimeInputAxis,
451            SpaceInputAxis,
452            TimeOutputAxis,
453            TimeOutputAxisWithHalo,
454            SpaceOutputAxis,
455            SpaceOutputAxisWithHalo,
456        ],
457    ):
458        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 100*49 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: int
374    def get_size(
375        self,
376        axis: Union[
377            ChannelAxis,
378            IndexInputAxis,
379            IndexOutputAxis,
380            TimeInputAxis,
381            SpaceInputAxis,
382            TimeOutputAxis,
383            TimeOutputAxisWithHalo,
384            SpaceOutputAxis,
385            SpaceOutputAxisWithHalo,
386        ],
387        ref_axis: Union[
388            ChannelAxis,
389            IndexInputAxis,
390            IndexOutputAxis,
391            TimeInputAxis,
392            SpaceInputAxis,
393            TimeOutputAxis,
394            TimeOutputAxisWithHalo,
395            SpaceOutputAxis,
396            SpaceOutputAxisWithHalo,
397        ],
398        n: ParameterizedSize_N = 0,
399        ref_size: Optional[int] = None,
400    ):
401        """Compute the concrete size for a given axis and its reference axis.
402
403        Args:
404            axis: The axis this `SizeReference` is the size of.
405            ref_axis: The reference axis to compute the size from.
406            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
407                and no fixed **ref_size** is given,
408                **n** is used to compute the size of the parameterized **ref_axis**.
409            ref_size: Overwrite the reference size instead of deriving it from
410                **ref_axis**
411                (**ref_axis.scale** is still used; any given **n** is ignored).
412        """
413        assert (
414            axis.size == self
415        ), "Given `axis.size` is not defined by this `SizeReference`"
416
417        assert (
418            ref_axis.id == self.axis_id
419        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
420
421        assert axis.unit == ref_axis.unit, (
422            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
423            f" but {axis.unit}!={ref_axis.unit}"
424        )
425        if ref_size is None:
426            if isinstance(ref_axis.size, (int, float)):
427                ref_size = ref_axis.size
428            elif isinstance(ref_axis.size, ParameterizedSize):
429                ref_size = ref_axis.size.get_size(n)
430            elif isinstance(ref_axis.size, DataDependentSize):
431                raise ValueError(
432                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
433                )
434            elif isinstance(ref_axis.size, SizeReference):
435                raise ValueError(
436                    "Reference axis referenced in `SizeReference` may not be sized by a"
437                    + " `SizeReference` itself."
438                )
439            else:
440                assert_never(ref_axis.size)
441
442        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
class AxisBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields):
461class AxisBase(NodeWithExplicitlySetFields):
462    id: AxisId
463    """An axis id unique across all axes of one tensor."""
464
465    description: Annotated[str, MaxLen(128)] = ""
id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]
class WithHalo(bioimageio.spec._internal.node.Node):
468class WithHalo(Node):
469    halo: Annotated[int, Ge(1)]
470    """The halo should be cropped from the output tensor to avoid boundary effects.
471    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
472    To document a halo that is already cropped by the model use `size.offset` instead."""
473
474    size: Annotated[
475        SizeReference,
476        Field(
477            examples=[
478                10,
479                SizeReference(
480                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
481                ).model_dump(mode="json"),
482            ]
483        ),
484    ]
485    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
491class BatchAxis(AxisBase):
492    implemented_type: ClassVar[Literal["batch"]] = "batch"
493    if TYPE_CHECKING:
494        type: Literal["batch"] = "batch"
495    else:
496        type: Literal["batch"]
497
498    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
499    size: Optional[Literal[1]] = None
500    """The batch size may be fixed to 1,
501    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
502
503    @property
504    def scale(self):
505        return 1.0
506
507    @property
508    def concatenable(self):
509        return True
510
511    @property
512    def unit(self):
513        return None
implemented_type: ClassVar[Literal['batch']] = 'batch'
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
503    @property
504    def scale(self):
505        return 1.0
concatenable
507    @property
508    def concatenable(self):
509        return True
unit
511    @property
512    def unit(self):
513        return None
type: Literal['batch']
Inherited Members
AxisBase
description
class ChannelAxis(AxisBase):
516class ChannelAxis(AxisBase):
517    implemented_type: ClassVar[Literal["channel"]] = "channel"
518    if TYPE_CHECKING:
519        type: Literal["channel"] = "channel"
520    else:
521        type: Literal["channel"]
522
523    id: NonBatchAxisId = AxisId("channel")
524    channel_names: NotEmpty[List[Identifier]]
525
526    @property
527    def size(self) -> int:
528        return len(self.channel_names)
529
530    @property
531    def concatenable(self):
532        return False
533
534    @property
535    def scale(self) -> float:
536        return 1.0
537
538    @property
539    def unit(self):
540        return None
implemented_type: ClassVar[Literal['channel']] = 'channel'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
526    @property
527    def size(self) -> int:
528        return len(self.channel_names)
concatenable
530    @property
531    def concatenable(self):
532        return False
scale: float
534    @property
535    def scale(self) -> float:
536        return 1.0
unit
538    @property
539    def unit(self):
540        return None
type: Literal['channel']
Inherited Members
AxisBase
description
class IndexAxisBase(AxisBase):
543class IndexAxisBase(AxisBase):
544    implemented_type: ClassVar[Literal["index"]] = "index"
545    if TYPE_CHECKING:
546        type: Literal["index"] = "index"
547    else:
548        type: Literal["index"]
549
550    id: NonBatchAxisId = AxisId("index")
551
552    @property
553    def scale(self) -> float:
554        return 1.0
555
556    @property
557    def unit(self):
558        return None
implemented_type: ClassVar[Literal['index']] = 'index'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
552    @property
553    def scale(self) -> float:
554        return 1.0
unit
556    @property
557    def unit(self):
558        return None
type: Literal['index']
Inherited Members
AxisBase
description
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
581class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
582    concatenable: bool = False
583    """If a model has a `concatenable` input axis, it can be processed blockwise,
584    splitting a longer sample axis into blocks matching its input tensor description.
585    Output axes are concatenable if they have a `SizeReference` to a concatenable
586    input axis.
587    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

type: Literal['index']
class IndexOutputAxis(IndexAxisBase):
590class IndexOutputAxis(IndexAxisBase):
591    size: Annotated[
592        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
593        Field(
594            examples=[
595                10,
596                SizeReference(
597                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
598                ).model_dump(mode="json"),
599            ]
600        ),
601    ]
602    """The size/length of this axis can be specified as
603    - fixed integer
604    - reference to another axis with an optional offset (`SizeReference`)
605    - data dependent size using `DataDependentSize` (size is only known after model inference)
606    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
type: Literal['index']
class TimeAxisBase(AxisBase):
609class TimeAxisBase(AxisBase):
610    implemented_type: ClassVar[Literal["time"]] = "time"
611    if TYPE_CHECKING:
612        type: Literal["time"] = "time"
613    else:
614        type: Literal["time"]
615
616    id: NonBatchAxisId = AxisId("time")
617    unit: Optional[TimeUnit] = None
618    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['time']] = 'time'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
type: Literal['time']
Inherited Members
AxisBase
description
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
621class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
622    concatenable: bool = False
623    """If a model has a `concatenable` input axis, it can be processed blockwise,
624    splitting a longer sample axis into blocks matching its input tensor description.
625    Output axes are concatenable if they have a `SizeReference` to a concatenable
626    input axis.
627    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

type: Literal['time']
class SpaceAxisBase(AxisBase):
630class SpaceAxisBase(AxisBase):
631    implemented_type: ClassVar[Literal["space"]] = "space"
632    if TYPE_CHECKING:
633        type: Literal["space"] = "space"
634    else:
635        type: Literal["space"]
636
637    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
638    unit: Optional[SpaceUnit] = None
639    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['space']] = 'space'
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
type: Literal['space']
Inherited Members
AxisBase
description
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
642class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
643    concatenable: bool = False
644    """If a model has a `concatenable` input axis, it can be processed blockwise,
645    splitting a longer sample axis into blocks matching its input tensor description.
646    Output axes are concatenable if they have a `SizeReference` to a concatenable
647    input axis.
648    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

type: Literal['space']
INPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>)

intended for isinstance comparisons in py<3.10

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
684class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
685    pass
type: Literal['time']
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
688class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
689    pass
type: Literal['time']
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
708class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
709    pass
type: Literal['space']
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
712class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
713    pass
type: Literal['space']
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
OUTPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
ANY_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>, <class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
770class NominalOrOrdinalDataDescr(Node):
771    values: TVs
772    """A fixed set of nominal or an ascending sequence of ordinal values.
773    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
774    String `values` are interpreted as labels for tensor values 0, ..., N.
775    Note: as YAML 1.2 does not natively support a "set" datatype,
776    nominal values should be given as a sequence (aka list/array) as well.
777    """
778
779    type: Annotated[
780        NominalOrOrdinalDType,
781        Field(
782            examples=[
783                "float32",
784                "uint8",
785                "uint16",
786                "int64",
787                "bool",
788            ],
789        ),
790    ] = "uint8"
791
792    @model_validator(mode="after")
793    def _validate_values_match_type(
794        self,
795    ) -> Self:
796        incompatible: List[Any] = []
797        for v in self.values:
798            if self.type == "bool":
799                if not isinstance(v, bool):
800                    incompatible.append(v)
801            elif self.type in DTYPE_LIMITS:
802                if (
803                    isinstance(v, (int, float))
804                    and (
805                        v < DTYPE_LIMITS[self.type].min
806                        or v > DTYPE_LIMITS[self.type].max
807                    )
808                    or (isinstance(v, str) and "uint" not in self.type)
809                    or (isinstance(v, float) and "int" in self.type)
810                ):
811                    incompatible.append(v)
812            else:
813                incompatible.append(v)
814
815            if len(incompatible) == 5:
816                incompatible.append("...")
817                break
818
819        if incompatible:
820            raise ValueError(
821                f"data type '{self.type}' incompatible with values {incompatible}"
822            )
823
824        return self
825
826    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
827
828    @property
829    def range(self):
830        if isinstance(self.values[0], str):
831            return 0, len(self.values) - 1
832        else:
833            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data.type is required to be an unsigned integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
828    @property
829    def range(self):
830        if isinstance(self.values[0], str):
831            return 0, len(self.values) - 1
832        else:
833            return min(self.values), max(self.values)
IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
850class IntervalOrRatioDataDescr(Node):
851    type: Annotated[  # todo: rename to dtype
852        IntervalOrRatioDType,
853        Field(
854            examples=["float32", "float64", "uint8", "uint16"],
855        ),
856    ] = "float32"
857    range: Tuple[Optional[float], Optional[float]] = (
858        None,
859        None,
860    )
861    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
862    `None` corresponds to min/max of what can be expressed by **type**."""
863    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
864    scale: float = 1.0
865    """Scale for data on an interval (or ratio) scale."""
866    offset: Optional[float] = None
867    """Offset for data on a ratio scale."""
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
873class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
874    """processing base class"""

processing base class

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
877class BinarizeKwargs(ProcessingKwargs):
878    """key word arguments for `BinarizeDescr`"""
879
880    threshold: float
881    """The fixed threshold"""

key word arguments for BinarizeDescr

threshold: float

The fixed threshold

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
884class BinarizeAlongAxisKwargs(ProcessingKwargs):
885    """key word arguments for `BinarizeDescr`"""
886
887    threshold: NotEmpty[List[float]]
888    """The fixed threshold values along `axis`"""
889
890    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
891    """The `threshold` axis"""

key word arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

class BinarizeDescr(ProcessingDescrBase):
894class BinarizeDescr(ProcessingDescrBase):
895    """Binarize the tensor with a fixed threshold.
896
897    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
898    will be set to one, values below the threshold to zero.
899
900    Examples:
901    - in YAML
902        ```yaml
903        postprocessing:
904          - id: binarize
905            kwargs:
906              axis: 'channel'
907              threshold: [0.25, 0.5, 0.75]
908        ```
909    - in Python:
910        >>> postprocessing = [BinarizeDescr(
911        ...   kwargs=BinarizeAlongAxisKwargs(
912        ...       axis=AxisId('channel'),
913        ...       threshold=[0.25, 0.5, 0.75],
914        ...   )
915        ... )]
916    """
917
918    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
919    if TYPE_CHECKING:
920        id: Literal["binarize"] = "binarize"
921    else:
922        id: Literal["binarize"]
923    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
implemented_id: ClassVar[Literal['binarize']] = 'binarize'
id: Literal['binarize']
class ClipDescr(ProcessingDescrBase):
926class ClipDescr(ProcessingDescrBase):
927    """Set tensor values below min to min and above max to max.
928
929    See `ScaleRangeDescr` for examples.
930    """
931
932    implemented_id: ClassVar[Literal["clip"]] = "clip"
933    if TYPE_CHECKING:
934        id: Literal["clip"] = "clip"
935    else:
936        id: Literal["clip"]
937
938    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

implemented_id: ClassVar[Literal['clip']] = 'clip'
id: Literal['clip']
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
941class EnsureDtypeKwargs(ProcessingKwargs):
942    """key word arguments for `EnsureDtypeDescr`"""
943
944    dtype: Literal[
945        "float32",
946        "float64",
947        "uint8",
948        "int8",
949        "uint16",
950        "int16",
951        "uint32",
952        "int32",
953        "uint64",
954        "int64",
955        "bool",
956    ]

key word arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class EnsureDtypeDescr(ProcessingDescrBase):
 959class EnsureDtypeDescr(ProcessingDescrBase):
 960    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 961
 962    This can for example be used to ensure the inner neural network model gets a
 963    different input tensor data type than the fully described bioimage.io model does.
 964
 965    Examples:
 966        The described bioimage.io model (incl. preprocessing) accepts any
 967        float32-compatible tensor, normalizes it with percentiles and clipping and then
 968        casts it to uint8, which is what the neural network in this example expects.
 969        - in YAML
 970            ```yaml
 971            inputs:
 972            - data:
 973                type: float32  # described bioimage.io model is compatible with any float32 input tensor
 974            preprocessing:
 975            - id: scale_range
 976                kwargs:
 977                axes: ['y', 'x']
 978                max_percentile: 99.8
 979                min_percentile: 5.0
 980            - id: clip
 981                kwargs:
 982                min: 0.0
 983                max: 1.0
 984            - id: ensure_dtype
 985                kwargs:
 986                dtype: uint8
 987            ```
 988        - in Python:
 989            >>> preprocessing = [
 990            ...     ScaleRangeDescr(
 991            ...         kwargs=ScaleRangeKwargs(
 992            ...           axes= (AxisId('y'), AxisId('x')),
 993            ...           max_percentile= 99.8,
 994            ...           min_percentile= 5.0,
 995            ...         )
 996            ...     ),
 997            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
 998            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
 999            ... ]
1000    """
1001
1002    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1003    if TYPE_CHECKING:
1004        id: Literal["ensure_dtype"] = "ensure_dtype"
1005    else:
1006        id: Literal["ensure_dtype"]
1007
1008    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
preprocessing:
- id: scale_range
    kwargs:
    axes: ['y', 'x']
    max_percentile: 99.8
    min_percentile: 5.0
- id: clip
    kwargs:
    min: 0.0
    max: 1.0
- id: ensure_dtype
    kwargs:
    dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
implemented_id: ClassVar[Literal['ensure_dtype']] = 'ensure_dtype'
id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1011class ScaleLinearKwargs(ProcessingKwargs):
1012    """Key word arguments for `ScaleLinearDescr`"""
1013
1014    gain: float = 1.0
1015    """multiplicative factor"""
1016
1017    offset: float = 0.0
1018    """additive term"""
1019
1020    @model_validator(mode="after")
1021    def _validate(self) -> Self:
1022        if self.gain == 1.0 and self.offset == 0.0:
1023            raise ValueError(
1024                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1025                + " != 0.0."
1026            )
1027
1028        return self

Key word arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1031class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1032    """Key word arguments for `ScaleLinearDescr`"""
1033
1034    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1035    """The axis of gain and offset values."""
1036
1037    gain: Union[float, NotEmpty[List[float]]] = 1.0
1038    """multiplicative factor"""
1039
1040    offset: Union[float, NotEmpty[List[float]]] = 0.0
1041    """additive term"""
1042
1043    @model_validator(mode="after")
1044    def _validate(self) -> Self:
1045
1046        if isinstance(self.gain, list):
1047            if isinstance(self.offset, list):
1048                if len(self.gain) != len(self.offset):
1049                    raise ValueError(
1050                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1051                    )
1052            else:
1053                self.offset = [float(self.offset)] * len(self.gain)
1054        elif isinstance(self.offset, list):
1055            self.gain = [float(self.gain)] * len(self.offset)
1056        else:
1057            raise ValueError(
1058                "Do not specify an `axis` for scalar gain and offset values."
1059            )
1060
1061        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1062            raise ValueError(
1063                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1064                + " != 0.0."
1065            )
1066
1067        return self

Key word arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of gain and offset values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

class ScaleLinearDescr(ProcessingDescrBase):
1070class ScaleLinearDescr(ProcessingDescrBase):
1071    """Fixed linear scaling.
1072
1073    Examples:
1074      1. Scale with scalar gain and offset
1075        - in YAML
1076        ```yaml
1077        preprocessing:
1078          - id: scale_linear
1079            kwargs:
1080              gain: 2.0
1081              offset: 3.0
1082        ```
1083        - in Python:
1084        >>> preprocessing = [
1085        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1086        ... ]
1087
1088      2. Independent scaling along an axis
1089        - in YAML
1090        ```yaml
1091        preprocessing:
1092          - id: scale_linear
1093            kwargs:
1094              axis: 'channel'
1095              gain: [1.0, 2.0, 3.0]
1096        ```
1097        - in Python:
1098        >>> preprocessing = [
1099        ...     ScaleLinearDescr(
1100        ...         kwargs=ScaleLinearAlongAxisKwargs(
1101        ...             axis=AxisId("channel"),
1102        ...             gain=[1.0, 2.0, 3.0],
1103        ...         )
1104        ...     )
1105        ... ]
1106
1107    """
1108
1109    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1110    if TYPE_CHECKING:
1111        id: Literal["scale_linear"] = "scale_linear"
1112    else:
1113        id: Literal["scale_linear"]
1114    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
implemented_id: ClassVar[Literal['scale_linear']] = 'scale_linear'
id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
1117class SigmoidDescr(ProcessingDescrBase):
1118    """The logistic sigmoid function, a.k.a. expit function.
1119
1120    Examples:
1121    - in YAML
1122        ```yaml
1123        postprocessing:
1124          - id: sigmoid
1125        ```
1126    - in Python:
1127        >>> postprocessing = [SigmoidDescr()]
1128    """
1129
1130    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1131    if TYPE_CHECKING:
1132        id: Literal["sigmoid"] = "sigmoid"
1133    else:
1134        id: Literal["sigmoid"]
1135
1136    @property
1137    def kwargs(self) -> ProcessingKwargs:
1138        """empty kwargs"""
1139        return ProcessingKwargs()

The logistic sigmoid function, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
implemented_id: ClassVar[Literal['sigmoid']] = 'sigmoid'
1136    @property
1137    def kwargs(self) -> ProcessingKwargs:
1138        """empty kwargs"""
1139        return ProcessingKwargs()

empty kwargs

id: Literal['sigmoid']
class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1142class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1143    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1144
1145    mean: float
1146    """The mean value to normalize with."""
1147
1148    std: Annotated[float, Ge(1e-6)]
1149    """The standard deviation value to normalize with."""

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1152class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1153    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1154
1155    mean: NotEmpty[List[float]]
1156    """The mean value(s) to normalize with."""
1157
1158    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1159    """The standard deviation value(s) to normalize with.
1160    Size must match `mean` values."""
1161
1162    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1163    """The axis of the mean/std values to normalize each entry along that dimension
1164    separately."""
1165
1166    @model_validator(mode="after")
1167    def _mean_and_std_match(self) -> Self:
1168        if len(self.mean) != len(self.std):
1169            raise ValueError(
1170                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1171                + " must match."
1172            )
1173
1174        return self

key word arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1177class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1178    """Subtract a given mean and divide by the standard deviation.
1179
1180    Normalize with fixed, precomputed values for
1181    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1182    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1183    axes.
1184
1185    Examples:
1186    1. scalar value for whole tensor
1187        - in YAML
1188        ```yaml
1189        preprocessing:
1190          - id: fixed_zero_mean_unit_variance
1191            kwargs:
1192              mean: 103.5
1193              std: 13.7
1194        ```
1195        - in Python
1196        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1197        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1198        ... )]
1199
1200    2. independently along an axis
1201        - in YAML
1202        ```yaml
1203        preprocessing:
1204          - id: fixed_zero_mean_unit_variance
1205            kwargs:
1206              axis: channel
1207              mean: [101.5, 102.5, 103.5]
1208              std: [11.7, 12.7, 13.7]
1209        ```
1210        - in Python
1211        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1212        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1213        ...     axis=AxisId("channel"),
1214        ...     mean=[101.5, 102.5, 103.5],
1215        ...     std=[11.7, 12.7, 13.7],
1216        ...   )
1217        ... )]
1218    """
1219
1220    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1221        "fixed_zero_mean_unit_variance"
1222    )
1223    if TYPE_CHECKING:
1224        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1225    else:
1226        id: Literal["fixed_zero_mean_unit_variance"]
1227
1228    kwargs: Union[
1229        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1230    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std. Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor
    • in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      mean: 103.5
      std: 13.7
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
      ... )]
  2. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['fixed_zero_mean_unit_variance']] = 'fixed_zero_mean_unit_variance'
id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1233class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1234    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1235
1236    axes: Annotated[
1237        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1238    ] = None
1239    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1240    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1241    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1242    To normalize each sample independently leave out the 'batch' axis.
1243    Default: Scale all axes jointly."""
1244
1245    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1246    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

key word arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1249class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1250    """Subtract mean and divide by variance.
1251
1252    Examples:
1253        Subtract tensor mean and variance
1254        - in YAML
1255        ```yaml
1256        preprocessing:
1257          - id: zero_mean_unit_variance
1258        ```
1259        - in Python
1260        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1261    """
1262
1263    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1264        "zero_mean_unit_variance"
1265    )
1266    if TYPE_CHECKING:
1267        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1268    else:
1269        id: Literal["zero_mean_unit_variance"]
1270
1271    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1272        default_factory=ZeroMeanUnitVarianceKwargs
1273    )

Subtract mean and divide by variance.

Examples:

Subtract tensor mean and variance

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
implemented_id: ClassVar[Literal['zero_mean_unit_variance']] = 'zero_mean_unit_variance'
id: Literal['zero_mean_unit_variance']
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1276class ScaleRangeKwargs(ProcessingKwargs):
1277    """key word arguments for `ScaleRangeDescr`
1278
1279    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1280    this processing step normalizes data to the [0, 1] interval.
1281    For other percentiles the normalized values will partially be outside the [0, 1]
1282    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1283    normalized values to a range.
1284    """
1285
1286    axes: Annotated[
1287        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1288    ] = None
1289    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1290    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1291    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1292    To normalize samples independently, leave out the "batch" axis.
1293    Default: Scale all axes jointly."""
1294
1295    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1296    """The lower percentile used to determine the value to align with zero."""
1297
1298    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1299    """The upper percentile used to determine the value to align with one.
1300    Has to be bigger than `min_percentile`.
1301    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1302    accepting percentiles specified in the range 0.0 to 1.0."""
1303
1304    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1305    """Epsilon for numeric stability.
1306    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1307    with `v_lower,v_upper` values at the respective percentiles."""
1308
1309    reference_tensor: Optional[TensorId] = None
1310    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1311    For any tensor in `inputs` only input tensor references are allowed."""
1312
1313    @field_validator("max_percentile", mode="after")
1314    @classmethod
1315    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1316        if (min_p := info.data["min_percentile"]) >= value:
1317            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1318
1319        return value

key word arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use ScaleRangeDescr followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1313    @field_validator("max_percentile", mode="after")
1314    @classmethod
1315    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1316        if (min_p := info.data["min_percentile"]) >= value:
1317            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1318
1319        return value
class ScaleRangeDescr(ProcessingDescrBase):
1322class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles.
1324
1325    Examples:
1326    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1327        - in YAML
1328        ```yaml
1329        preprocessing:
1330          - id: scale_range
1331            kwargs:
1332              axes: ['y', 'x']
1333              max_percentile: 99.8
1334              min_percentile: 5.0
1335        ```
1336        - in Python
1337        >>> preprocessing = [
1338        ...     ScaleRangeDescr(
1339        ...         kwargs=ScaleRangeKwargs(
1340        ...           axes= (AxisId('y'), AxisId('x')),
1341        ...           max_percentile= 99.8,
1342        ...           min_percentile= 5.0,
1343        ...         )
1344        ...     ),
1345        ...     ClipDescr(
1346        ...         kwargs=ClipKwargs(
1347        ...             min=0.0,
1348        ...             max=1.0,
1349        ...         )
1350        ...     ),
1351        ... ]
1352
1353      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1354        - in YAML
1355        ```yaml
1356        preprocessing:
1357          - id: scale_range
1358            kwargs:
1359              axes: ['y', 'x']
1360              max_percentile: 99.8
1361              min_percentile: 5.0
1362          - id: clip
1363            kwargs:
1364              min: 0.0
1365              max: 1.0
1366
1367        ```
1368        - in Python
1369        >>> preprocessing = [ScaleRangeDescr(
1370        ...   kwargs=ScaleRangeKwargs(
1371        ...       axes= (AxisId('y'), AxisId('x')),
1372        ...       max_percentile= 99.8,
1373        ...       min_percentile= 5.0,
1374        ...   )
1375        ... )]
1376
1377    """
1378
1379    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1380    if TYPE_CHECKING:
1381        id: Literal["scale_range"] = "scale_range"
1382    else:
1383        id: Literal["scale_range"]
1384    kwargs: ScaleRangeKwargs

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
    • in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
- in Python >>> preprocessing = [ ... ScaleRangeDescr( ... kwargs=ScaleRangeKwargs( ... axes= (AxisId('y'), AxisId('x')), ... max_percentile= 99.8, ... min_percentile= 5.0, ... ) ... ), ... ClipDescr( ... kwargs=ClipKwargs( ... min=0.0, ... max=1.0, ... ) ... ), ... ]
  2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
      - id: clip
        kwargs:
          min: 0.0
          max: 1.0
    
    • in Python
      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['scale_range']] = 'scale_range'
id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1387class ScaleMeanVarianceKwargs(ProcessingKwargs):
1388    """key word arguments for `ScaleMeanVarianceDescr`"""
1389
1390    reference_tensor: TensorId
1391    """Name of tensor to match."""
1392
1393    axes: Annotated[
1394        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1395    ] = None
1396    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1397    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1398    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1399    To normalize samples independently, leave out the 'batch' axis.
1400    Default: Scale all axes jointly."""
1401
1402    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1403    """Epsilon for numeric stability:
1404    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

key word arguments for ScaleMeanVarianceDescr

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1407class ScaleMeanVarianceDescr(ProcessingDescrBase):
1408    """Scale a tensor's data distribution to match another tensor's mean/std.
1409    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1410    """
1411
1412    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1413    if TYPE_CHECKING:
1414        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1415    else:
1416        id: Literal["scale_mean_variance"]
1417    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

implemented_id: ClassVar[Literal['scale_mean_variance']] = 'scale_mean_variance'
id: Literal['scale_mean_variance']
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1451class TensorDescrBase(Node, Generic[IO_AxisT]):
1452    id: TensorId
1453    """Tensor id. No duplicates are allowed."""
1454
1455    description: Annotated[str, MaxLen(128)] = ""
1456    """free text description"""
1457
1458    axes: NotEmpty[Sequence[IO_AxisT]]
1459    """tensor axes"""
1460
1461    @property
1462    def shape(self):
1463        return tuple(a.size for a in self.axes)
1464
1465    @field_validator("axes", mode="after", check_fields=False)
1466    @classmethod
1467    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1468        batch_axes = [a for a in axes if a.type == "batch"]
1469        if len(batch_axes) > 1:
1470            raise ValueError(
1471                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1472            )
1473
1474        seen_ids: Set[AxisId] = set()
1475        duplicate_axes_ids: Set[AxisId] = set()
1476        for a in axes:
1477            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1478
1479        if duplicate_axes_ids:
1480            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1481
1482        return axes
1483
1484    test_tensor: FileDescr
1485    """An example tensor to use for testing.
1486    Using the model with the test input tensors is expected to yield the test output tensors.
1487    Each test tensor has to be an ndarray in the
1488    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1489    The file extension must be '.npy'."""
1490
1491    sample_tensor: Optional[FileDescr] = None
1492    """A sample tensor to illustrate a possible input/output for the model.
1493    The sample image primarily serves to inform a human user about an example use case
1494    and is typically stored as .hdf5, .png or .tiff.
1495    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1496    (numpy's `.npy` format is not supported).
1497    The image dimensionality has to match the number of axes specified in this tensor description.
1498    """
1499
1500    @model_validator(mode="after")
1501    def _validate_sample_tensor(self) -> Self:
1502        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1503            return self
1504
1505        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1506        tensor: NDArray[Any] = imread(
1507            local.path.read_bytes(),
1508            extension=PurePosixPath(local.original_file_name).suffix,
1509        )
1510        n_dims = len(tensor.squeeze().shape)
1511        n_dims_min = n_dims_max = len(self.axes)
1512
1513        for a in self.axes:
1514            if isinstance(a, BatchAxis):
1515                n_dims_min -= 1
1516            elif isinstance(a.size, int):
1517                if a.size == 1:
1518                    n_dims_min -= 1
1519            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1520                if a.size.min == 1:
1521                    n_dims_min -= 1
1522            elif isinstance(a.size, SizeReference):
1523                if a.size.offset < 2:
1524                    # size reference may result in singleton axis
1525                    n_dims_min -= 1
1526            else:
1527                assert_never(a.size)
1528
1529        n_dims_min = max(0, n_dims_min)
1530        if n_dims < n_dims_min or n_dims > n_dims_max:
1531            raise ValueError(
1532                f"Expected sample tensor to have {n_dims_min} to"
1533                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1534            )
1535
1536        return self
1537
1538    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1539        IntervalOrRatioDataDescr()
1540    )
1541    """Description of the tensor's data values, optionally per channel.
1542    If specified per channel, the data `type` needs to match across channels."""
1543
1544    @property
1545    def dtype(
1546        self,
1547    ) -> Literal[
1548        "float32",
1549        "float64",
1550        "uint8",
1551        "int8",
1552        "uint16",
1553        "int16",
1554        "uint32",
1555        "int32",
1556        "uint64",
1557        "int64",
1558        "bool",
1559    ]:
1560        """dtype as specified under `data.type` or `data[i].type`"""
1561        if isinstance(self.data, collections.abc.Sequence):
1562            return self.data[0].type
1563        else:
1564            return self.data.type
1565
1566    @field_validator("data", mode="after")
1567    @classmethod
1568    def _check_data_type_across_channels(
1569        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1570    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1571        if not isinstance(value, list):
1572            return value
1573
1574        dtypes = {t.type for t in value}
1575        if len(dtypes) > 1:
1576            raise ValueError(
1577                "Tensor data descriptions per channel need to agree in their data"
1578                + f" `type`, but found {dtypes}."
1579            )
1580
1581        return value
1582
1583    @model_validator(mode="after")
1584    def _check_data_matches_channelaxis(self) -> Self:
1585        if not isinstance(self.data, (list, tuple)):
1586            return self
1587
1588        for a in self.axes:
1589            if isinstance(a, ChannelAxis):
1590                size = a.size
1591                assert isinstance(size, int)
1592                break
1593        else:
1594            return self
1595
1596        if len(self.data) != size:
1597            raise ValueError(
1598                f"Got tensor data descriptions for {len(self.data)} channels, but"
1599                + f" '{a.id}' axis has size {size}."
1600            )
1601
1602        return self
1603
1604    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1605        if len(array.shape) != len(self.axes):
1606            raise ValueError(
1607                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1608                + f" incompatible with {len(self.axes)} axes."
1609            )
1610        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1461    @property
1462    def shape(self):
1463        return tuple(a.size for a in self.axes)
test_tensor: bioimageio.spec._internal.io.FileDescr

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.

sample_tensor: Optional[bioimageio.spec._internal.io.FileDescr]

A sample tensor to illustrate a possible input/output for the model. The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1544    @property
1545    def dtype(
1546        self,
1547    ) -> Literal[
1548        "float32",
1549        "float64",
1550        "uint8",
1551        "int8",
1552        "uint16",
1553        "int16",
1554        "uint32",
1555        "int32",
1556        "uint64",
1557        "int64",
1558        "bool",
1559    ]:
1560        """dtype as specified under `data.type` or `data[i].type`"""
1561        if isinstance(self.data, collections.abc.Sequence):
1562            return self.data[0].type
1563        else:
1564            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[typing.Any, numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1604    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1605        if len(array.shape) != len(self.axes):
1606            raise ValueError(
1607                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1608                + f" incompatible with {len(self.axes)} axes."
1609            )
1610        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1613class InputTensorDescr(TensorDescrBase[InputAxis]):
1614    id: TensorId = TensorId("input")
1615    """Input tensor id.
1616    No duplicates are allowed across all inputs and outputs."""
1617
1618    optional: bool = False
1619    """indicates that this tensor may be `None`"""
1620
1621    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1622    """Description of how this input should be preprocessed.
1623
1624    notes:
1625    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1626      to ensure an input tensor's data type matches the input tensor's data description.
1627    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1628      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1629      changing the data type.
1630    """
1631
1632    @model_validator(mode="after")
1633    def _validate_preprocessing_kwargs(self) -> Self:
1634        axes_ids = [a.id for a in self.axes]
1635        for p in self.preprocessing:
1636            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1637            if kwargs_axes is None:
1638                continue
1639
1640            if not isinstance(kwargs_axes, collections.abc.Sequence):
1641                raise ValueError(
1642                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1643                )
1644
1645            if any(a not in axes_ids for a in kwargs_axes):
1646                raise ValueError(
1647                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1648                )
1649
1650        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1651            dtype = self.data.type
1652        else:
1653            dtype = self.data[0].type
1654
1655        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1656        if not self.preprocessing or not isinstance(
1657            self.preprocessing[0], EnsureDtypeDescr
1658        ):
1659            self.preprocessing.insert(
1660                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1661            )
1662
1663        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1664        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1665            self.preprocessing.append(
1666                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1667            )
1668
1669        return self
id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1672def convert_axes(
1673    axes: str,
1674    *,
1675    shape: Union[
1676        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1677    ],
1678    tensor_type: Literal["input", "output"],
1679    halo: Optional[Sequence[int]],
1680    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1681):
1682    ret: List[AnyAxis] = []
1683    for i, a in enumerate(axes):
1684        axis_type = _AXIS_TYPE_MAP.get(a, a)
1685        if axis_type == "batch":
1686            ret.append(BatchAxis())
1687            continue
1688
1689        scale = 1.0
1690        if isinstance(shape, _ParameterizedInputShape_v0_4):
1691            if shape.step[i] == 0:
1692                size = shape.min[i]
1693            else:
1694                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1695        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1696            ref_t = str(shape.reference_tensor)
1697            if ref_t.count(".") == 1:
1698                t_id, orig_a_id = ref_t.split(".")
1699            else:
1700                t_id = ref_t
1701                orig_a_id = a
1702
1703            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1704            if not (orig_scale := shape.scale[i]):
1705                # old way to insert a new axis dimension
1706                size = int(2 * shape.offset[i])
1707            else:
1708                scale = 1 / orig_scale
1709                if axis_type in ("channel", "index"):
1710                    # these axes no longer have a scale
1711                    offset_from_scale = orig_scale * size_refs.get(
1712                        _TensorName_v0_4(t_id), {}
1713                    ).get(orig_a_id, 0)
1714                else:
1715                    offset_from_scale = 0
1716                size = SizeReference(
1717                    tensor_id=TensorId(t_id),
1718                    axis_id=AxisId(a_id),
1719                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1720                )
1721        else:
1722            size = shape[i]
1723
1724        if axis_type == "time":
1725            if tensor_type == "input":
1726                ret.append(TimeInputAxis(size=size, scale=scale))
1727            else:
1728                assert not isinstance(size, ParameterizedSize)
1729                if halo is None:
1730                    ret.append(TimeOutputAxis(size=size, scale=scale))
1731                else:
1732                    assert not isinstance(size, int)
1733                    ret.append(
1734                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1735                    )
1736
1737        elif axis_type == "index":
1738            if tensor_type == "input":
1739                ret.append(IndexInputAxis(size=size))
1740            else:
1741                if isinstance(size, ParameterizedSize):
1742                    size = DataDependentSize(min=size.min)
1743
1744                ret.append(IndexOutputAxis(size=size))
1745        elif axis_type == "channel":
1746            assert not isinstance(size, ParameterizedSize)
1747            if isinstance(size, SizeReference):
1748                warnings.warn(
1749                    "Conversion of channel size from an implicit output shape may be"
1750                    + " wrong"
1751                )
1752                ret.append(
1753                    ChannelAxis(
1754                        channel_names=[
1755                            Identifier(f"channel{i}") for i in range(size.offset)
1756                        ]
1757                    )
1758                )
1759            else:
1760                ret.append(
1761                    ChannelAxis(
1762                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1763                    )
1764                )
1765        elif axis_type == "space":
1766            if tensor_type == "input":
1767                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1768            else:
1769                assert not isinstance(size, ParameterizedSize)
1770                if halo is None or halo[i] == 0:
1771                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1772                elif isinstance(size, int):
1773                    raise NotImplementedError(
1774                        f"output axis with halo and fixed size (here {size}) not allowed"
1775                    )
1776                else:
1777                    ret.append(
1778                        SpaceOutputAxisWithHalo(
1779                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1780                        )
1781                    )
1782
1783    return ret
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1943class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1944    id: TensorId = TensorId("output")
1945    """Output tensor id.
1946    No duplicates are allowed across all inputs and outputs."""
1947
1948    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1949    """Description of how this output should be postprocessed.
1950
1951    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1952          If not given this is added to cast to this tensor's `data.type`.
1953    """
1954
1955    @model_validator(mode="after")
1956    def _validate_postprocessing_kwargs(self) -> Self:
1957        axes_ids = [a.id for a in self.axes]
1958        for p in self.postprocessing:
1959            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1960            if kwargs_axes is None:
1961                continue
1962
1963            if not isinstance(kwargs_axes, collections.abc.Sequence):
1964                raise ValueError(
1965                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1966                )
1967
1968            if any(a not in axes_ids for a in kwargs_axes):
1969                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1970
1971        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1972            dtype = self.data.type
1973        else:
1974            dtype = self.data[0].type
1975
1976        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1977        if not self.postprocessing or not isinstance(
1978            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1979        ):
1980            self.postprocessing.append(
1981                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1982            )
1983        return self
id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], numpy.ndarray[Any, numpy.dtype[Any]]]], tensor_origin: Literal['test_tensor']):
2033def validate_tensors(
2034    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2035    tensor_origin: Literal[
2036        "test_tensor"
2037    ],  # for more precise error messages, e.g. 'test_tensor'
2038):
2039    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2040
2041    def e_msg(d: TensorDescr):
2042        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2043
2044    for descr, array in tensors.values():
2045        try:
2046            axis_sizes = descr.get_axis_sizes_for_array(array)
2047        except ValueError as e:
2048            raise ValueError(f"{e_msg(descr)} {e}")
2049        else:
2050            all_tensor_axes[descr.id] = {
2051                a.id: (a, axis_sizes[a.id]) for a in descr.axes
2052            }
2053
2054    for descr, array in tensors.values():
2055        if descr.dtype in ("float32", "float64"):
2056            invalid_test_tensor_dtype = array.dtype.name not in (
2057                "float32",
2058                "float64",
2059                "uint8",
2060                "int8",
2061                "uint16",
2062                "int16",
2063                "uint32",
2064                "int32",
2065                "uint64",
2066                "int64",
2067            )
2068        else:
2069            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2070
2071        if invalid_test_tensor_dtype:
2072            raise ValueError(
2073                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2074                + f" match described dtype '{descr.dtype}'"
2075            )
2076
2077        if array.min() > -1e-4 and array.max() < 1e-4:
2078            raise ValueError(
2079                "Output values are too small for reliable testing."
2080                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2081            )
2082
2083        for a in descr.axes:
2084            actual_size = all_tensor_axes[descr.id][a.id][1]
2085            if a.size is None:
2086                continue
2087
2088            if isinstance(a.size, int):
2089                if actual_size != a.size:
2090                    raise ValueError(
2091                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2092                        + f"has incompatible size {actual_size}, expected {a.size}"
2093                    )
2094            elif isinstance(a.size, ParameterizedSize):
2095                _ = a.size.validate_size(actual_size)
2096            elif isinstance(a.size, DataDependentSize):
2097                _ = a.size.validate_size(actual_size)
2098            elif isinstance(a.size, SizeReference):
2099                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2100                if ref_tensor_axes is None:
2101                    raise ValueError(
2102                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2103                        + f" reference '{a.size.tensor_id}'"
2104                    )
2105
2106                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2107                if ref_axis is None or ref_size is None:
2108                    raise ValueError(
2109                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2110                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2111                    )
2112
2113                if a.unit != ref_axis.unit:
2114                    raise ValueError(
2115                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2116                        + " axis and reference axis to have the same `unit`, but"
2117                        + f" {a.unit}!={ref_axis.unit}"
2118                    )
2119
2120                if actual_size != (
2121                    expected_size := (
2122                        ref_size * ref_axis.scale / a.scale + a.size.offset
2123                    )
2124                ):
2125                    raise ValueError(
2126                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2127                        + f" {actual_size} invalid for referenced size {ref_size};"
2128                        + f" expected {expected_size}"
2129                    )
2130            else:
2131                assert_never(a.size)
class EnvironmentFileDescr(bioimageio.spec._internal.io.FileDescr):
2134class EnvironmentFileDescr(FileDescr):
2135    source: Annotated[
2136        ImportantFileSource,
2137        WithSuffix((".yaml", ".yml"), case_sensitive=True),
2138        Field(
2139            examples=["environment.yaml"],
2140        ),
2141    ]
2142    """∈📦 Conda environment file.
2143    Allows specifying custom dependencies, see conda docs:
2144    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2145    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2146    """
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f2602536ca0>), PlainSerializer(func=<function _package at 0x7f2602535e40>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['environment.yaml'])]

∈📦 Conda environment file. Allows specifying custom dependencies, see conda docs:

class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2157class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2158    pass
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2161class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2162    import_from: str
2163    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>

ArchitectureDescr = typing.Annotated[typing.Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]
class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2229class WeightsEntryDescrBase(FileDescr):
2230    type: ClassVar[WeightsFormat]
2231    weights_format_name: ClassVar[str]  # human readable
2232
2233    source: ImportantFileSource
2234    """∈📦 The weights file."""
2235
2236    authors: Optional[List[Author]] = None
2237    """Authors
2238    Either the person(s) that have trained this model resulting in the original weights file.
2239        (If this is the initial weights entry, i.e. it does not have a `parent`)
2240    Or the person(s) who have converted the weights to this weights format.
2241        (If this is a child weight, i.e. it has a `parent` field)
2242    """
2243
2244    parent: Annotated[
2245        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2246    ] = None
2247    """The source weights these weights were converted from.
2248    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2249    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2250    All weight entries except one (the initial set of weights resulting from training the model),
2251    need to have this field."""
2252
2253    comment: str = ""
2254    """A comment about this weights entry, for example how these weights were created."""
2255
2256    @model_validator(mode="after")
2257    def check_parent_is_not_self(self) -> Self:
2258        if self.type == self.parent:
2259            raise ValueError("Weights entry can't be its own parent.")
2260
2261        return self
type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f2602536ca0>), PlainSerializer(func=<function _package at 0x7f2602535e40>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors Either the person(s) that have trained this model resulting in the original weights file. (If this is the initial weights entry, i.e. it does not have a parent) Or the person(s) who have converted the weights to this weights format. (If this is a child weight, i.e. it has a parent field)

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, The pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model), need to have this field.

comment: str

A comment about this weights entry, for example how these weights were created.

@model_validator(mode='after')
def check_parent_is_not_self(self) -> Self:
2256    @model_validator(mode="after")
2257    def check_parent_is_not_self(self) -> Self:
2258        if self.type == self.parent:
2259        raise ValueError("Weights entry can't be its own parent.")
2260
2261        return self
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2264class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2265    type = "keras_hdf5"
2266    weights_format_name: ClassVar[str] = "Keras HDF5"
2267    tensorflow_version: Version
2268    """TensorFlow version used to create these weights."""
type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'
tensorflow_version: bioimageio.spec._internal.version_type.Version

TensorFlow version used to create these weights.

class OnnxWeightsDescr(WeightsEntryDescrBase):
2271class OnnxWeightsDescr(WeightsEntryDescrBase):
2272    type = "onnx"
2273    weights_format_name: ClassVar[str] = "ONNX"
2274    opset_version: Annotated[int, Ge(7)]
2275    """ONNX opset version"""
type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2278class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2279    type = "pytorch_state_dict"
2280    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2281    architecture: ArchitectureDescr
2282    pytorch_version: Version
2283    """Version of the PyTorch library used.
2284    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
2285    """
2286    dependencies: Optional[EnvironmentFileDescr] = None
2287    """Custom dependencies beyond pytorch.
2288    The conda environment file should include pytorch and any version pinning has to be compatible with
2289    `pytorch_version`.
2290    """
type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'
architecture: Annotated[Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]
pytorch_version: bioimageio.spec._internal.version_type.Version

Version of the PyTorch library used. If architecture.dependencies is specified it has to include pytorch and any version pinning has to be compatible.

dependencies: Optional[EnvironmentFileDescr]

Custom dependencies beyond pytorch. The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2293class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2294    type = "tensorflow_js"
2295    weights_format_name: ClassVar[str] = "Tensorflow.js"
2296    tensorflow_version: Version
2297    """Version of the TensorFlow library used."""
2298
2299    source: ImportantFileSource
2300    """∈📦 The multi-file weights.
2301    All required files/folders should be a zip archive."""
type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'
tensorflow_version: bioimageio.spec._internal.version_type.Version

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f2602536ca0>), PlainSerializer(func=<function _package at 0x7f2602535e40>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The multi-file weights. All required files/folders should be a zip archive.

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2304class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2305    type = "tensorflow_saved_model_bundle"
2306    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2307    tensorflow_version: Version
2308    """Version of the TensorFlow library used."""
2309
2310    dependencies: Optional[EnvironmentFileDescr] = None
2311    """Custom dependencies beyond tensorflow.
2312    Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""
2313
2314    source: ImportantFileSource
2315    """∈📦 The multi-file weights.
2316    All required files/folders should be a zip archive."""
type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'
tensorflow_version: bioimageio.spec._internal.version_type.Version

Version of the TensorFlow library used.

dependencies: Optional[EnvironmentFileDescr]

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7f2602536ca0>), PlainSerializer(func=<function _package at 0x7f2602535e40>, return_type=PydanticUndefined, when_used='unless-none')]

∈📦 The multi-file weights. All required files/folders should be a zip archive.

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2319class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2320    type = "torchscript"
2321    weights_format_name: ClassVar[str] = "TorchScript"
2322    pytorch_version: Version
2323    """Version of the PyTorch library used."""
type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'
pytorch_version: bioimageio.spec._internal.version_type.Version

Version of the PyTorch library used.

class WeightsDescr(bioimageio.spec._internal.node.Node):
2326class WeightsDescr(Node):
2327    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2328    onnx: Optional[OnnxWeightsDescr] = None
2329    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2330    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2331    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2332        None
2333    )
2334    torchscript: Optional[TorchscriptWeightsDescr] = None
2335
2336    @model_validator(mode="after")
2337    def check_entries(self) -> Self:
2338        entries = {wtype for wtype, entry in self if entry is not None}
2339
2340        if not entries:
2341            raise ValueError("Missing weights entry")
2342
2343        entries_wo_parent = {
2344            wtype
2345            for wtype, entry in self
2346            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2347        }
2348        if len(entries_wo_parent) != 1:
2349            issue_warning(
2350                "Exactly one weights entry may not specify the `parent` field (got"
2351                + " {value}). That entry is considered the original set of model weights."
2352                + " Other weight formats are created through conversion of the original or"
2353                + " already converted weights. They have to reference the weights format"
2354                + " they were converted from as their `parent`.",
2355                value=len(entries_wo_parent),
2356                field="weights",
2357            )
2358
2359        for wtype, entry in self:
2360            if entry is None:
2361                continue
2362
2363            assert hasattr(entry, "type")
2364            assert hasattr(entry, "parent")
2365            assert wtype == entry.type
2366            if (
2367                entry.parent is not None and entry.parent not in entries
2368            ):  # self reference checked for `parent` field
2369                raise ValueError(
2370                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2371                    + f" formats: {entries}"
2372                )
2373
2374        return self
2375
2376    def __getitem__(
2377        self,
2378        key: Literal[
2379            "keras_hdf5",
2380            "onnx",
2381            "pytorch_state_dict",
2382            "tensorflow_js",
2383            "tensorflow_saved_model_bundle",
2384            "torchscript",
2385        ],
2386    ):
2387        if key == "keras_hdf5":
2388            ret = self.keras_hdf5
2389        elif key == "onnx":
2390            ret = self.onnx
2391        elif key == "pytorch_state_dict":
2392            ret = self.pytorch_state_dict
2393        elif key == "tensorflow_js":
2394            ret = self.tensorflow_js
2395        elif key == "tensorflow_saved_model_bundle":
2396            ret = self.tensorflow_saved_model_bundle
2397        elif key == "torchscript":
2398            ret = self.torchscript
2399        else:
2400            raise KeyError(key)
2401
2402        if ret is None:
2403            raise KeyError(key)
2404
2405        return ret
2406
2407    @property
2408    def available_formats(self):
2409        return {
2410            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2411            **({} if self.onnx is None else {"onnx": self.onnx}),
2412            **(
2413                {}
2414                if self.pytorch_state_dict is None
2415                else {"pytorch_state_dict": self.pytorch_state_dict}
2416            ),
2417            **(
2418                {}
2419                if self.tensorflow_js is None
2420                else {"tensorflow_js": self.tensorflow_js}
2421            ),
2422            **(
2423                {}
2424                if self.tensorflow_saved_model_bundle is None
2425                else {
2426                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2427                }
2428            ),
2429            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2430        }
2431
2432    @property
2433    def missing_formats(self):
2434        return {
2435            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2436        }
keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
2336    @model_validator(mode="after")
2337    def check_entries(self) -> Self:
2338        entries = {wtype for wtype, entry in self if entry is not None}
2339
2340        if not entries:
2341            raise ValueError("Missing weights entry")
2342
2343        entries_wo_parent = {
2344            wtype
2345            for wtype, entry in self
2346            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2347        }
2348        if len(entries_wo_parent) != 1:
2349            issue_warning(
2350                "Exactly one weights entry may not specify the `parent` field (got"
2351                + " {value}). That entry is considered the original set of model weights."
2352                + " Other weight formats are created through conversion of the original or"
2353                + " already converted weights. They have to reference the weights format"
2354                + " they were converted from as their `parent`.",
2355                value=len(entries_wo_parent),
2356                field="weights",
2357            )
2358
2359        for wtype, entry in self:
2360            if entry is None:
2361                continue
2362
2363            assert hasattr(entry, "type")
2364            assert hasattr(entry, "parent")
2365            assert wtype == entry.type
2366            if (
2367                entry.parent is not None and entry.parent not in entries
2368            ):  # self reference checked for `parent` field
2369                raise ValueError(
2370                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2371                    + f" formats: {entries}"
2372                )
2373
2374        return self
available_formats
2407    @property
2408    def available_formats(self):
2409        return {
2410            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2411            **({} if self.onnx is None else {"onnx": self.onnx}),
2412            **(
2413                {}
2414                if self.pytorch_state_dict is None
2415                else {"pytorch_state_dict": self.pytorch_state_dict}
2416            ),
2417            **(
2418                {}
2419                if self.tensorflow_js is None
2420                else {"tensorflow_js": self.tensorflow_js}
2421            ),
2422            **(
2423                {}
2424                if self.tensorflow_saved_model_bundle is None
2425                else {
2426                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2427                }
2428            ),
2429            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2430        }
missing_formats
2432    @property
2433    def missing_formats(self):
2434        return {
2435            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2436        }
class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2439class ModelId(ResourceId):
2440    pass

A bioimage.io model resource identifier. A specialization of ResourceId (itself a validated string subtype) used to reference models in the bioimage.io collection; it adds no fields or behavior beyond its base class.

class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2443class LinkedModel(LinkedResourceBase):
2444    """Reference to a bioimage.io model."""
2445
2446    id: ModelId
2447    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

class ReproducibilityTolerance(bioimageio.spec._internal.node.Node):
2469class ReproducibilityTolerance(Node, extra="allow"):
2470    """Describes what small numerical differences -- if any -- may be tolerated
2471    in the generated output when executing in different environments.
2472
2473    A tensor element *output* is considered mismatched to the **test_tensor** if
2474    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2475    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2476
2477    Motivation:
2478        For testing we can request the respective deep learning frameworks to be as
2479        reproducible as possible by setting seeds and choosing deterministic algorithms,
2480        but differences in operating systems, available hardware and installed drivers
2481        may still lead to numerical differences.
2482    """
2483
2484    relative_tolerance: RelativeTolerance = 1e-3
2485    """Maximum relative tolerance of reproduced test tensor."""
2486
2487    absolute_tolerance: AbsoluteTolerance = 1e-4
2488    """Maximum absolute tolerance of reproduced test tensor."""
2489
2490    mismatched_elements_per_million: MismatchedElementsPerMillion = 0
2491    """Maximum number of mismatched elements/pixels per million to tolerate."""
2492
2493    output_ids: Sequence[TensorId] = ()
2494    """Limits the output tensor IDs these reproducibility details apply to."""
2495
2496    weights_formats: Sequence[WeightsFormat] = ()
2497    """Limits the weights formats these details apply to."""

Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.

A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)

Motivation:

For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.

relative_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=0.01)]

Maximum relative tolerance of reproduced test tensor.

absolute_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=None)]

Maximum absolute tolerance of reproduced test tensor.

mismatched_elements_per_million: Annotated[int, Interval(gt=None, ge=0, lt=None, le=100)]

Maximum number of mismatched elements/pixels per million to tolerate.

output_ids: Sequence[TensorId]

Limits the output tensor IDs these reproducibility details apply to.

weights_formats: Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]

Limits the weights formats these details apply to.

class BioimageioConfig(bioimageio.spec._internal.node.Node):
2500class BioimageioConfig(Node, extra="allow"):
2501    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2502    """Tolerances to allow when reproducing the model's test outputs
2503    from the model's test inputs.
2504    Only the first entry matching tensor id and weights format is considered.
2505    """
reproducibility_tolerance: Sequence[ReproducibilityTolerance]

Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.

class Config(bioimageio.spec._internal.node.Node):
2508class Config(Node, extra="allow"):
2509    bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
bioimageio: BioimageioConfig
2512class ModelDescr(GenericModelDescrBase):
2513    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2514    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2515    """
2516
2517    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2518    if TYPE_CHECKING:
2519        format_version: Literal["0.5.4"] = "0.5.4"
2520    else:
2521        format_version: Literal["0.5.4"]
2522        """Version of the bioimage.io model description specification used.
2523        When creating a new model always use the latest micro/patch version described here.
2524        The `format_version` is important for any consumer software to understand how to parse the fields.
2525        """
2526
2527    implemented_type: ClassVar[Literal["model"]] = "model"
2528    if TYPE_CHECKING:
2529        type: Literal["model"] = "model"
2530    else:
2531        type: Literal["model"]
2532        """Specialized resource type 'model'"""
2533
2534    id: Optional[ModelId] = None
2535    """bioimage.io-wide unique resource identifier
2536    assigned by bioimage.io; version **un**specific."""
2537
2538    authors: NotEmpty[List[Author]]
2539    """The authors are the creators of the model RDF and the primary points of contact."""
2540
2541    documentation: Annotated[
2542        DocumentationSource,
2543        Field(
2544            examples=[
2545                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2546                "README.md",
2547            ],
2548        ),
2549    ]
2550    """∈📦 URL or relative path to a markdown file with additional documentation.
2551    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2552    The documentation should include a '#[#] Validation' (sub)section
2553    with details on how to quantitatively validate the model on unseen data."""
2554
2555    @field_validator("documentation", mode="after")
2556    @classmethod
2557    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2558        if not get_validation_context().perform_io_checks:
2559            return value
2560
2561        doc_path = download(value).path
2562        doc_content = doc_path.read_text(encoding="utf-8")
2563        assert isinstance(doc_content, str)
2564        if not re.search("#.*[vV]alidation", doc_content):
2565            issue_warning(
2566                "No '# Validation' (sub)section found in {value}.",
2567                value=value,
2568                field="documentation",
2569            )
2570
2571        return value
2572
2573    inputs: NotEmpty[Sequence[InputTensorDescr]]
2574    """Describes the input tensors expected by this model."""
2575
2576    @field_validator("inputs", mode="after")
2577    @classmethod
2578    def _validate_input_axes(
2579        cls, inputs: Sequence[InputTensorDescr]
2580    ) -> Sequence[InputTensorDescr]:
2581        input_size_refs = cls._get_axes_with_independent_size(inputs)
2582
2583        for i, ipt in enumerate(inputs):
2584            valid_independent_refs: Dict[
2585                Tuple[TensorId, AxisId],
2586                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2587            ] = {
2588                **{
2589                    (ipt.id, a.id): (ipt, a, a.size)
2590                    for a in ipt.axes
2591                    if not isinstance(a, BatchAxis)
2592                    and isinstance(a.size, (int, ParameterizedSize))
2593                },
2594                **input_size_refs,
2595            }
2596            for a, ax in enumerate(ipt.axes):
2597                cls._validate_axis(
2598                    "inputs",
2599                    i=i,
2600                    tensor_id=ipt.id,
2601                    a=a,
2602                    axis=ax,
2603                    valid_independent_refs=valid_independent_refs,
2604                )
2605        return inputs
2606
2607    @staticmethod
2608    def _validate_axis(
2609        field_name: str,
2610        i: int,
2611        tensor_id: TensorId,
2612        a: int,
2613        axis: AnyAxis,
2614        valid_independent_refs: Dict[
2615            Tuple[TensorId, AxisId],
2616            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2617        ],
2618    ):
2619        if isinstance(axis, BatchAxis) or isinstance(
2620            axis.size, (int, ParameterizedSize, DataDependentSize)
2621        ):
2622            return
2623        elif not isinstance(axis.size, SizeReference):
2624            assert_never(axis.size)
2625
2626        # validate axis.size SizeReference
2627        ref = (axis.size.tensor_id, axis.size.axis_id)
2628        if ref not in valid_independent_refs:
2629            raise ValueError(
2630                "Invalid tensor axis reference at"
2631                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2632            )
2633        if ref == (tensor_id, axis.id):
2634            raise ValueError(
2635                "Self-referencing not allowed for"
2636                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2637            )
2638        if axis.type == "channel":
2639            if valid_independent_refs[ref][1].type != "channel":
2640                raise ValueError(
2641                    "A channel axis' size may only reference another fixed size"
2642                    + " channel axis."
2643                )
2644            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2645                ref_size = valid_independent_refs[ref][2]
2646                assert isinstance(ref_size, int), (
2647                    "channel axis ref (another channel axis) has to specify fixed"
2648                    + " size"
2649                )
2650                generated_channel_names = [
2651                    Identifier(axis.channel_names.format(i=i))
2652                    for i in range(1, ref_size + 1)
2653                ]
2654                axis.channel_names = generated_channel_names
2655
2656        if (ax_unit := getattr(axis, "unit", None)) != (
2657            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2658        ):
2659            raise ValueError(
2660                "The units of an axis and its reference axis need to match, but"
2661                + f" '{ax_unit}' != '{ref_unit}'."
2662            )
2663        ref_axis = valid_independent_refs[ref][1]
2664        if isinstance(ref_axis, BatchAxis):
2665            raise ValueError(
2666                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2667                + " (a batch axis is not allowed as reference)."
2668            )
2669
2670        if isinstance(axis, WithHalo):
2671            min_size = axis.size.get_size(axis, ref_axis, n=0)
2672            if (min_size - 2 * axis.halo) < 1:
2673                raise ValueError(
2674                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2675                    + f" {axis.halo}."
2676                )
2677
2678            input_halo = axis.halo * axis.scale / ref_axis.scale
2679            if input_halo != int(input_halo) or input_halo % 2 == 1:
2680                raise ValueError(
2681                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2682                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2683                    + f"     {tensor_id}.{axis.id}."
2684                )
2685
2686    @model_validator(mode="after")
2687    def _validate_test_tensors(self) -> Self:
2688        if not get_validation_context().perform_io_checks:
2689            return self
2690
2691        test_output_arrays = [
2692            load_array(descr.test_tensor.download().path) for descr in self.outputs
2693        ]
2694        test_input_arrays = [
2695            load_array(descr.test_tensor.download().path) for descr in self.inputs
2696        ]
2697
2698        tensors = {
2699            descr.id: (descr, array)
2700            for descr, array in zip(
2701                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2702            )
2703        }
2704        validate_tensors(tensors, tensor_origin="test_tensor")
2705
2706        output_arrays = {
2707            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2708        }
2709        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2710            if not rep_tol.absolute_tolerance:
2711                continue
2712
2713            if rep_tol.output_ids:
2714                out_arrays = {
2715                    oid: a
2716                    for oid, a in output_arrays.items()
2717                    if oid in rep_tol.output_ids
2718                }
2719            else:
2720                out_arrays = output_arrays
2721
2722            for out_id, array in out_arrays.items():
2723                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2724                    raise ValueError(
2725                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2726                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2727                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2728                    )
2729
2730        return self
2731
2732    @model_validator(mode="after")
2733    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2734        ipt_refs = {t.id for t in self.inputs}
2735        out_refs = {t.id for t in self.outputs}
2736        for ipt in self.inputs:
2737            for p in ipt.preprocessing:
2738                ref = p.kwargs.get("reference_tensor")
2739                if ref is None:
2740                    continue
2741                if ref not in ipt_refs:
2742                    raise ValueError(
2743                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2744                        + f" references are: {ipt_refs}."
2745                    )
2746
2747        for out in self.outputs:
2748            for p in out.postprocessing:
2749                ref = p.kwargs.get("reference_tensor")
2750                if ref is None:
2751                    continue
2752
2753                if ref not in ipt_refs and ref not in out_refs:
2754                    raise ValueError(
2755                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2756                        + f" are: {ipt_refs | out_refs}."
2757                    )
2758
2759        return self
2760
2761    # TODO: use validate funcs in validate_test_tensors
2762    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2763
2764    name: Annotated[
2765        Annotated[
2766            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2767        ],
2768        MinLen(5),
2769        MaxLen(128),
2770        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2771    ]
2772    """A human-readable name of this model.
2773    It should be no longer than 64 characters
2774    and may only contain letter, number, underscore, minus, parentheses and spaces.
2775    We recommend to chose a name that refers to the model's task and image modality.
2776    """
2777
2778    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2779    """Describes the output tensors."""
2780
2781    @field_validator("outputs", mode="after")
2782    @classmethod
2783    def _validate_tensor_ids(
2784        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2785    ) -> Sequence[OutputTensorDescr]:
2786        tensor_ids = [
2787            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2788        ]
2789        duplicate_tensor_ids: List[str] = []
2790        seen: Set[str] = set()
2791        for t in tensor_ids:
2792            if t in seen:
2793                duplicate_tensor_ids.append(t)
2794
2795            seen.add(t)
2796
2797        if duplicate_tensor_ids:
2798            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2799
2800        return outputs
2801
2802    @staticmethod
2803    def _get_axes_with_parameterized_size(
2804        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2805    ):
2806        return {
2807            f"{t.id}.{a.id}": (t, a, a.size)
2808            for t in io
2809            for a in t.axes
2810            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2811        }
2812
2813    @staticmethod
2814    def _get_axes_with_independent_size(
2815        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2816    ):
2817        return {
2818            (t.id, a.id): (t, a, a.size)
2819            for t in io
2820            for a in t.axes
2821            if not isinstance(a, BatchAxis)
2822            and isinstance(a.size, (int, ParameterizedSize))
2823        }
2824
2825    @field_validator("outputs", mode="after")
2826    @classmethod
2827    def _validate_output_axes(
2828        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2829    ) -> List[OutputTensorDescr]:
2830        input_size_refs = cls._get_axes_with_independent_size(
2831            info.data.get("inputs", [])
2832        )
2833        output_size_refs = cls._get_axes_with_independent_size(outputs)
2834
2835        for i, out in enumerate(outputs):
2836            valid_independent_refs: Dict[
2837                Tuple[TensorId, AxisId],
2838                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2839            ] = {
2840                **{
2841                    (out.id, a.id): (out, a, a.size)
2842                    for a in out.axes
2843                    if not isinstance(a, BatchAxis)
2844                    and isinstance(a.size, (int, ParameterizedSize))
2845                },
2846                **input_size_refs,
2847                **output_size_refs,
2848            }
2849            for a, ax in enumerate(out.axes):
2850                cls._validate_axis(
2851                    "outputs",
2852                    i,
2853                    out.id,
2854                    a,
2855                    ax,
2856                    valid_independent_refs=valid_independent_refs,
2857                )
2858
2859        return outputs
2860
2861    packaged_by: List[Author] = Field(default_factory=list)
2862    """The persons that have packaged and uploaded this model.
2863    Only required if those persons differ from the `authors`."""
2864
2865    parent: Optional[LinkedModel] = None
2866    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2867
2868    @model_validator(mode="after")
2869    def _validate_parent_is_not_self(self) -> Self:
2870        if self.parent is not None and self.parent.id == self.id:
2871            raise ValueError("A model description may not reference itself as parent.")
2872
2873        return self
2874
2875    run_mode: Annotated[
2876        Optional[RunMode],
2877        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2878    ] = None
2879    """Custom run mode for this model: for more complex prediction procedures like test time
2880    data augmentation that currently cannot be expressed in the specification.
2881    No standard run modes are defined yet."""
2882
2883    timestamp: Datetime = Field(default_factory=Datetime.now)
2884    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2885    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2886    (In Python a datetime object is valid, too)."""
2887
2888    training_data: Annotated[
2889        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2890        Field(union_mode="left_to_right"),
2891    ] = None
2892    """The dataset used to train this model"""
2893
2894    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2895    """The weights for this model.
2896    Weights can be given for different formats, but should otherwise be equivalent.
2897    The available weight formats determine which consumers can use this model."""
2898
2899    config: Config = Field(default_factory=Config)
2900
2901    @model_validator(mode="after")
2902    def _add_default_cover(self) -> Self:
2903        if not get_validation_context().perform_io_checks or self.covers:
2904            return self
2905
2906        try:
2907            generated_covers = generate_covers(
2908                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2909                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2910            )
2911        except Exception as e:
2912            issue_warning(
2913                "Failed to generate cover image(s): {e}",
2914                value=self.covers,
2915                msg_context=dict(e=e),
2916                field="covers",
2917            )
2918        else:
2919            self.covers.extend(generated_covers)
2920
2921        return self
2922
2923    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2924        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2925        assert all(isinstance(d, np.ndarray) for d in data)
2926        return data
2927
2928    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2929        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2930        assert all(isinstance(d, np.ndarray) for d in data)
2931        return data
2932
2933    @staticmethod
2934    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2935        batch_size = 1
2936        tensor_with_batchsize: Optional[TensorId] = None
2937        for tid in tensor_sizes:
2938            for aid, s in tensor_sizes[tid].items():
2939                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2940                    continue
2941
2942                if batch_size != 1:
2943                    assert tensor_with_batchsize is not None
2944                    raise ValueError(
2945                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2946                    )
2947
2948                batch_size = s
2949                tensor_with_batchsize = tid
2950
2951        return batch_size
2952
2953    def get_output_tensor_sizes(
2954        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2955    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2956        """Returns the tensor output sizes for given **input_sizes**.
2957        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2958        Otherwise it might be larger than the actual (valid) output"""
2959        batch_size = self.get_batch_size(input_sizes)
2960        ns = self.get_ns(input_sizes)
2961
2962        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2963        return tensor_sizes.outputs
2964
2965    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2966        """get parameter `n` for each parameterized axis
2967        such that the valid input size is >= the given input size"""
2968        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2969        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2970        for tid in input_sizes:
2971            for aid, s in input_sizes[tid].items():
2972                size_descr = axes[tid][aid].size
2973                if isinstance(size_descr, ParameterizedSize):
2974                    ret[(tid, aid)] = size_descr.get_n(s)
2975                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2976                    pass
2977                else:
2978                    assert_never(size_descr)
2979
2980        return ret
2981
2982    def get_tensor_sizes(
2983        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2984    ) -> _TensorSizes:
2985        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2986        return _TensorSizes(
2987            {
2988                t: {
2989                    aa: axis_sizes.inputs[(tt, aa)]
2990                    for tt, aa in axis_sizes.inputs
2991                    if tt == t
2992                }
2993                for t in {tt for tt, _ in axis_sizes.inputs}
2994            },
2995            {
2996                t: {
2997                    aa: axis_sizes.outputs[(tt, aa)]
2998                    for tt, aa in axis_sizes.outputs
2999                    if tt == t
3000                }
3001                for t in {tt for tt, _ in axis_sizes.outputs}
3002            },
3003        )
3004
3005    def get_axis_sizes(
3006        self,
3007        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3008        batch_size: Optional[int] = None,
3009        *,
3010        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3011    ) -> _AxisSizes:
3012        """Determine input and output block shape for scale factors **ns**
3013        of parameterized input sizes.
3014
3015        Args:
3016            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3017                that is parameterized as `size = min + n * step`.
3018            batch_size: The desired size of the batch dimension.
3019                If given **batch_size** overwrites any batch size present in
3020                **max_input_shape**. Default 1.
3021            max_input_shape: Limits the derived block shapes.
3022                Each axis for which the input size, parameterized by `n`, is larger
3023                than **max_input_shape** is set to the minimal value `n_min` for which
3024                this is still true.
3025                Use this for small input samples or large values of **ns**.
3026                Or simply whenever you know the full input shape.
3027
3028        Returns:
3029            Resolved axis sizes for model inputs and outputs.
3030        """
3031        max_input_shape = max_input_shape or {}
3032        if batch_size is None:
3033            for (_t_id, a_id), s in max_input_shape.items():
3034                if a_id == BATCH_AXIS_ID:
3035                    batch_size = s
3036                    break
3037            else:
3038                batch_size = 1
3039
3040        all_axes = {
3041            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3042        }
3043
3044        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3045        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3046
3047        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3048            if isinstance(a, BatchAxis):
3049                if (t_descr.id, a.id) in ns:
3050                    logger.warning(
3051                        "Ignoring unexpected size increment factor (n) for batch axis"
3052                        + " of tensor '{}'.",
3053                        t_descr.id,
3054                    )
3055                return batch_size
3056            elif isinstance(a.size, int):
3057                if (t_descr.id, a.id) in ns:
3058                    logger.warning(
3059                        "Ignoring unexpected size increment factor (n) for fixed size"
3060                        + " axis '{}' of tensor '{}'.",
3061                        a.id,
3062                        t_descr.id,
3063                    )
3064                return a.size
3065            elif isinstance(a.size, ParameterizedSize):
3066                if (t_descr.id, a.id) not in ns:
3067                    raise ValueError(
3068                        "Size increment factor (n) missing for parametrized axis"
3069                        + f" '{a.id}' of tensor '{t_descr.id}'."
3070                    )
3071                n = ns[(t_descr.id, a.id)]
3072                s_max = max_input_shape.get((t_descr.id, a.id))
3073                if s_max is not None:
3074                    n = min(n, a.size.get_n(s_max))
3075
3076                return a.size.get_size(n)
3077
3078            elif isinstance(a.size, SizeReference):
3079                if (t_descr.id, a.id) in ns:
3080                    logger.warning(
3081                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3082                        + " of tensor '{}' with size reference.",
3083                        a.id,
3084                        t_descr.id,
3085                    )
3086                assert not isinstance(a, BatchAxis)
3087                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3088                assert not isinstance(ref_axis, BatchAxis)
3089                ref_key = (a.size.tensor_id, a.size.axis_id)
3090                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3091                assert ref_size is not None, ref_key
3092                assert not isinstance(ref_size, _DataDepSize), ref_key
3093                return a.size.get_size(
3094                    axis=a,
3095                    ref_axis=ref_axis,
3096                    ref_size=ref_size,
3097                )
3098            elif isinstance(a.size, DataDependentSize):
3099                if (t_descr.id, a.id) in ns:
3100                    logger.warning(
3101                        "Ignoring unexpected increment factor (n) for data dependent"
3102                        + " size axis '{}' of tensor '{}'.",
3103                        a.id,
3104                        t_descr.id,
3105                    )
3106                return _DataDepSize(a.size.min, a.size.max)
3107            else:
3108                assert_never(a.size)
3109
3110        # first resolve all , but the `SizeReference` input sizes
3111        for t_descr in self.inputs:
3112            for a in t_descr.axes:
3113                if not isinstance(a.size, SizeReference):
3114                    s = get_axis_size(a)
3115                    assert not isinstance(s, _DataDepSize)
3116                    inputs[t_descr.id, a.id] = s
3117
3118        # resolve all other input axis sizes
3119        for t_descr in self.inputs:
3120            for a in t_descr.axes:
3121                if isinstance(a.size, SizeReference):
3122                    s = get_axis_size(a)
3123                    assert not isinstance(s, _DataDepSize)
3124                    inputs[t_descr.id, a.id] = s
3125
3126        # resolve all output axis sizes
3127        for t_descr in self.outputs:
3128            for a in t_descr.axes:
3129                assert not isinstance(a.size, ParameterizedSize)
3130                s = get_axis_size(a)
3131                outputs[t_descr.id, a.id] = s
3132
3133        return _AxisSizes(inputs=inputs, outputs=outputs)
3134
3135    @model_validator(mode="before")
3136    @classmethod
3137    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3138        cls.convert_from_old_format_wo_validation(data)
3139        return data
3140
3141    @classmethod
3142    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3143        """Convert metadata following an older format version to this classes' format
3144        without validating the result.
3145        """
3146        if (
3147            data.get("type") == "model"
3148            and isinstance(fv := data.get("format_version"), str)
3149            and fv.count(".") == 2
3150        ):
3151            fv_parts = fv.split(".")
3152            if any(not p.isdigit() for p in fv_parts):
3153                return
3154
3155            fv_tuple = tuple(map(int, fv_parts))
3156
3157            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3158            if fv_tuple[:2] in ((0, 3), (0, 4)):
3159                m04 = _ModelDescr_v0_4.load(data)
3160                if isinstance(m04, InvalidDescr):
3161                    try:
3162                        updated = _model_conv.convert_as_dict(
3163                            m04  # pyright: ignore[reportArgumentType]
3164                        )
3165                    except Exception as e:
3166                        logger.error(
3167                            "Failed to convert from invalid model 0.4 description."
3168                            + f"\nerror: {e}"
3169                            + "\nProceeding with model 0.5 validation without conversion."
3170                        )
3171                        updated = None
3172                else:
3173                    updated = _model_conv.convert_as_dict(m04)
3174
3175                if updated is not None:
3176                    data.clear()
3177                    data.update(updated)
3178
3179            elif fv_tuple[:2] == (0, 5):
3180                # bump patch version
3181                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.4']] = '0.5.4'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[Annotated[pathlib.Path, PathType(path_type='file'), Predicate(is_absolute), FieldInfo(annotation=NoneType, required=True, title='AbsoluteFilePath')], bioimageio.spec._internal.io.RelativeFilePath, bioimageio.spec._internal.url.HttpUrl], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function _validate_md_suffix at 0x7f26013f3e20>), PlainSerializer(func=<function _package at 0x7f2602535e40>, return_type=PydanticUndefined, when_used='unless-none'), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f25f3aae5c0>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper at 0x7f25f3aaea20>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights at 0x7f25f8f8df80>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: Config
def get_input_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
2923    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2924        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2925        assert all(isinstance(d, np.ndarray) for d in data)
2926        return data
def get_output_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
2928    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2929        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2930        assert all(isinstance(d, np.ndarray) for d in data)
2931        return data
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2933    @staticmethod
2934    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2935        batch_size = 1
2936        tensor_with_batchsize: Optional[TensorId] = None
2937        for tid in tensor_sizes:
2938            for aid, s in tensor_sizes[tid].items():
2939                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2940                    continue
2941
2942                if batch_size != 1:
2943                    assert tensor_with_batchsize is not None
2944                    raise ValueError(
2945                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2946                    )
2947
2948                batch_size = s
2949                tensor_with_batchsize = tid
2950
2951        return batch_size
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
2953    def get_output_tensor_sizes(
2954        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2955    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2956        """Returns the tensor output sizes for given **input_sizes**.
2957        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2958        Otherwise it might be larger than the actual (valid) output"""
2959        batch_size = self.get_batch_size(input_sizes)
2960        ns = self.get_ns(input_sizes)
2961
2962        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2963        return tensor_sizes.outputs

Returns the tensor output sizes for the given input_sizes. The tensor output size is exact only if input_sizes has a valid input shape; otherwise it might be larger than the actual (valid) output.

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2965    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2966        """get parameter `n` for each parameterized axis
2967        such that the valid input size is >= the given input size"""
2968        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2969        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2970        for tid in input_sizes:
2971            for aid, s in input_sizes[tid].items():
2972                size_descr = axes[tid][aid].size
2973                if isinstance(size_descr, ParameterizedSize):
2974                    ret[(tid, aid)] = size_descr.get_n(s)
2975                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2976                    pass
2977                else:
2978                    assert_never(size_descr)
2979
2980        return ret

get parameter n for each parameterized axis such that the valid input size is >= the given input size

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
2982    def get_tensor_sizes(
2983        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2984    ) -> _TensorSizes:
2985        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2986        return _TensorSizes(
2987            {
2988                t: {
2989                    aa: axis_sizes.inputs[(tt, aa)]
2990                    for tt, aa in axis_sizes.inputs
2991                    if tt == t
2992                }
2993                for t in {tt for tt, _ in axis_sizes.inputs}
2994            },
2995            {
2996                t: {
2997                    aa: axis_sizes.outputs[(tt, aa)]
2998                    for tt, aa in axis_sizes.outputs
2999                    if tt == t
3000                }
3001                for t in {tt for tt, _ in axis_sizes.outputs}
3002            },
3003        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
3005    def get_axis_sizes(
3006        self,
3007        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3008        batch_size: Optional[int] = None,
3009        *,
3010        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3011    ) -> _AxisSizes:
3012        """Determine input and output block shape for scale factors **ns**
3013        of parameterized input sizes.
3014
3015        Args:
3016            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3017                that is parameterized as `size = min + n * step`.
3018            batch_size: The desired size of the batch dimension.
3019                If given **batch_size** overwrites any batch size present in
3020                **max_input_shape**. Default 1.
3021            max_input_shape: Limits the derived block shapes.
3022                Each axis for which the input size, parameterized by `n`, is larger
3023                than **max_input_shape** is set to the minimal value `n_min` for which
3024                this is still true.
3025                Use this for small input samples or large values of **ns**.
3026                Or simply whenever you know the full input shape.
3027
3028        Returns:
3029            Resolved axis sizes for model inputs and outputs.
3030        """
3031        max_input_shape = max_input_shape or {}
3032        if batch_size is None:
3033            for (_t_id, a_id), s in max_input_shape.items():
3034                if a_id == BATCH_AXIS_ID:
3035                    batch_size = s
3036                    break
3037            else:
3038                batch_size = 1
3039
3040        all_axes = {
3041            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3042        }
3043
3044        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3045        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3046
3047        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3048            if isinstance(a, BatchAxis):
3049                if (t_descr.id, a.id) in ns:
3050                    logger.warning(
3051                        "Ignoring unexpected size increment factor (n) for batch axis"
3052                        + " of tensor '{}'.",
3053                        t_descr.id,
3054                    )
3055                return batch_size
3056            elif isinstance(a.size, int):
3057                if (t_descr.id, a.id) in ns:
3058                    logger.warning(
3059                        "Ignoring unexpected size increment factor (n) for fixed size"
3060                        + " axis '{}' of tensor '{}'.",
3061                        a.id,
3062                        t_descr.id,
3063                    )
3064                return a.size
3065            elif isinstance(a.size, ParameterizedSize):
3066                if (t_descr.id, a.id) not in ns:
3067                    raise ValueError(
3068                        "Size increment factor (n) missing for parametrized axis"
3069                        + f" '{a.id}' of tensor '{t_descr.id}'."
3070                    )
3071                n = ns[(t_descr.id, a.id)]
3072                s_max = max_input_shape.get((t_descr.id, a.id))
3073                if s_max is not None:
3074                    n = min(n, a.size.get_n(s_max))
3075
3076                return a.size.get_size(n)
3077
3078            elif isinstance(a.size, SizeReference):
3079                if (t_descr.id, a.id) in ns:
3080                    logger.warning(
3081                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3082                        + " of tensor '{}' with size reference.",
3083                        a.id,
3084                        t_descr.id,
3085                    )
3086                assert not isinstance(a, BatchAxis)
3087                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3088                assert not isinstance(ref_axis, BatchAxis)
3089                ref_key = (a.size.tensor_id, a.size.axis_id)
3090                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3091                assert ref_size is not None, ref_key
3092                assert not isinstance(ref_size, _DataDepSize), ref_key
3093                return a.size.get_size(
3094                    axis=a,
3095                    ref_axis=ref_axis,
3096                    ref_size=ref_size,
3097                )
3098            elif isinstance(a.size, DataDependentSize):
3099                if (t_descr.id, a.id) in ns:
3100                    logger.warning(
3101                        "Ignoring unexpected increment factor (n) for data dependent"
3102                        + " size axis '{}' of tensor '{}'.",
3103                        a.id,
3104                        t_descr.id,
3105                    )
3106                return _DataDepSize(a.size.min, a.size.max)
3107            else:
3108                assert_never(a.size)
3109
3110        # first resolve all , but the `SizeReference` input sizes
3111        for t_descr in self.inputs:
3112            for a in t_descr.axes:
3113                if not isinstance(a.size, SizeReference):
3114                    s = get_axis_size(a)
3115                    assert not isinstance(s, _DataDepSize)
3116                    inputs[t_descr.id, a.id] = s
3117
3118        # resolve all other input axis sizes
3119        for t_descr in self.inputs:
3120            for a in t_descr.axes:
3121                if isinstance(a.size, SizeReference):
3122                    s = get_axis_size(a)
3123                    assert not isinstance(s, _DataDepSize)
3124                    inputs[t_descr.id, a.id] = s
3125
3126        # resolve all output axis sizes
3127        for t_descr in self.outputs:
3128            for a in t_descr.axes:
3129                assert not isinstance(a.size, ParameterizedSize)
3130                s = get_axis_size(a)
3131                outputs[t_descr.id, a.id] = s
3132
3133        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3141    @classmethod
3142    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3143        """Convert metadata following an older format version to this classes' format
3144        without validating the result.
3145        """
3146        if (
3147            data.get("type") == "model"
3148            and isinstance(fv := data.get("format_version"), str)
3149            and fv.count(".") == 2
3150        ):
3151            fv_parts = fv.split(".")
3152            if any(not p.isdigit() for p in fv_parts):
3153                return
3154
3155            fv_tuple = tuple(map(int, fv_parts))
3156
3157            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3158            if fv_tuple[:2] in ((0, 3), (0, 4)):
3159                m04 = _ModelDescr_v0_4.load(data)
3160                if isinstance(m04, InvalidDescr):
3161                    try:
3162                        updated = _model_conv.convert_as_dict(
3163                            m04  # pyright: ignore[reportArgumentType]
3164                        )
3165                    except Exception as e:
3166                        logger.error(
3167                            "Failed to convert from invalid model 0.4 description."
3168                            + f"\nerror: {e}"
3169                            + "\nProceeding with model 0.5 validation without conversion."
3170                        )
3171                        updated = None
3172                else:
3173                    updated = _model_conv.convert_as_dict(m04)
3174
3175                if updated is not None:
3176                    data.clear()
3177                    data.update(updated)
3178
3179            elif fv_tuple[:2] == (0, 5):
3180                # bump patch version
3181                data["format_version"] = cls.implemented_format_version

Convert metadata following an older format version to this classes' format without validating the result.

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 4)
def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
124                    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
125                        """We need to both initialize private attributes and call the user-defined model_post_init
126                        method.
127                        """
128                        init_private_attributes(self, context)
129                        original_model_post_init(self, context)

We need to both initialize private attributes and call the user-defined model_post_init method.

def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[Any, numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[Any, numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3406def generate_covers(
3407    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3408    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3409) -> List[Path]:
3410    def squeeze(
3411        data: NDArray[Any], axes: Sequence[AnyAxis]
3412    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3413        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3414        if data.ndim != len(axes):
3415            raise ValueError(
3416                f"tensor shape {data.shape} does not match described axes"
3417                + f" {[a.id for a in axes]}"
3418            )
3419
3420        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3421        return data.squeeze(), axes
3422
3423    def normalize(
3424        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3425    ) -> NDArray[np.float32]:
3426        data = data.astype("float32")
3427        data -= data.min(axis=axis, keepdims=True)
3428        data /= data.max(axis=axis, keepdims=True) + eps
3429        return data
3430
3431    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3432        original_shape = data.shape
3433        data, axes = squeeze(data, axes)
3434
3435        # take slice fom any batch or index axis if needed
3436        # and convert the first channel axis and take a slice from any additional channel axes
3437        slices: Tuple[slice, ...] = ()
3438        ndim = data.ndim
3439        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3440        has_c_axis = False
3441        for i, a in enumerate(axes):
3442            s = data.shape[i]
3443            assert s > 1
3444            if (
3445                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3446                and ndim > ndim_need
3447            ):
3448                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3449                ndim -= 1
3450            elif isinstance(a, ChannelAxis):
3451                if has_c_axis:
3452                    # second channel axis
3453                    data = data[slices + (slice(0, 1),)]
3454                    ndim -= 1
3455                else:
3456                    has_c_axis = True
3457                    if s == 2:
3458                        # visualize two channels with cyan and magenta
3459                        data = np.concatenate(
3460                            [
3461                                data[slices + (slice(1, 2),)],
3462                                data[slices + (slice(0, 1),)],
3463                                (
3464                                    data[slices + (slice(0, 1),)]
3465                                    + data[slices + (slice(1, 2),)]
3466                                )
3467                                / 2,  # TODO: take maximum instead?
3468                            ],
3469                            axis=i,
3470                        )
3471                    elif data.shape[i] == 3:
3472                        pass  # visualize 3 channels as RGB
3473                    else:
3474                        # visualize first 3 channels as RGB
3475                        data = data[slices + (slice(3),)]
3476
3477                    assert data.shape[i] == 3
3478
3479            slices += (slice(None),)
3480
3481        data, axes = squeeze(data, axes)
3482        assert len(axes) == ndim
3483        # take slice from z axis if needed
3484        slices = ()
3485        if ndim > ndim_need:
3486            for i, a in enumerate(axes):
3487                s = data.shape[i]
3488                if a.id == AxisId("z"):
3489                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3490                    data, axes = squeeze(data, axes)
3491                    ndim -= 1
3492                    break
3493
3494            slices += (slice(None),)
3495
3496        # take slice from any space or time axis
3497        slices = ()
3498
3499        for i, a in enumerate(axes):
3500            if ndim <= ndim_need:
3501                break
3502
3503            s = data.shape[i]
3504            assert s > 1
3505            if isinstance(
3506                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3507            ):
3508                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3509                ndim -= 1
3510
3511            slices += (slice(None),)
3512
3513        del slices
3514        data, axes = squeeze(data, axes)
3515        assert len(axes) == ndim
3516
3517        if (has_c_axis and ndim != 3) or ndim != 2:
3518            raise ValueError(
3519                f"Failed to construct cover image from shape {original_shape}"
3520            )
3521
3522        if not has_c_axis:
3523            assert ndim == 2
3524            data = np.repeat(data[:, :, None], 3, axis=2)
3525            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3526            ndim += 1
3527
3528        assert ndim == 3
3529
3530        # transpose axis order such that longest axis comes first...
3531        axis_order = list(np.argsort(list(data.shape)))
3532        axis_order.reverse()
3533        # ... and channel axis is last
3534        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3535        axis_order.append(axis_order.pop(c))
3536        axes = [axes[ao] for ao in axis_order]
3537        data = data.transpose(axis_order)
3538
3539        # h, w = data.shape[:2]
3540        # if h / w  in (1.0 or 2.0):
3541        #     pass
3542        # elif h / w < 2:
3543        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3544
3545        norm_along = (
3546            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3547        )
3548        # normalize the data and map to 8 bit
3549        data = normalize(data, norm_along)
3550        data = (data * 255).astype("uint8")
3551
3552        return data
3553
3554    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3555        assert im0.dtype == im1.dtype == np.uint8
3556        assert im0.shape == im1.shape
3557        assert im0.ndim == 3
3558        N, M, C = im0.shape
3559        assert C == 3
3560        out = np.ones((N, M, C), dtype="uint8")
3561        for c in range(C):
3562            outc = np.tril(im0[..., c])
3563            mask = outc == 0
3564            outc[mask] = np.triu(im1[..., c])[mask]
3565            out[..., c] = outc
3566
3567        return out
3568
3569    ipt_descr, ipt = inputs[0]
3570    out_descr, out = outputs[0]
3571
3572    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3573    out_img = to_2d_image(out, out_descr.axes)
3574
3575    cover_folder = Path(mkdtemp())
3576    if ipt_img.shape == out_img.shape:
3577        covers = [cover_folder / "cover.png"]
3578        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3579    else:
3580        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3581        imwrite(covers[0], ipt_img)
3582        imwrite(covers[1], out_img)
3583
3584    return covers
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):