bioimageio.spec.model.v0_5

   1from __future__ import annotations
   2
   3import collections.abc
   4import re
   5import string
   6import warnings
   7from abc import ABC
   8from copy import deepcopy
   9from itertools import chain
  10from math import ceil
  11from pathlib import Path, PurePosixPath
  12from tempfile import mkdtemp
  13from typing import (
  14    TYPE_CHECKING,
  15    Any,
  16    Callable,
  17    ClassVar,
  18    Dict,
  19    Generic,
  20    List,
  21    Literal,
  22    Mapping,
  23    NamedTuple,
  24    Optional,
  25    Sequence,
  26    Set,
  27    Tuple,
  28    Type,
  29    TypeVar,
  30    Union,
  31    cast,
  32)
  33
  34import numpy as np
  35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
  36from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
  37from loguru import logger
  38from numpy.typing import NDArray
  39from pydantic import (
  40    AfterValidator,
  41    Discriminator,
  42    Field,
  43    RootModel,
  44    SerializationInfo,
  45    SerializerFunctionWrapHandler,
  46    Tag,
  47    ValidationInfo,
  48    WrapSerializer,
  49    field_validator,
  50    model_serializer,
  51    model_validator,
  52)
  53from typing_extensions import Annotated, Self, assert_never, get_args
  54
  55from .._internal.common_nodes import (
  56    InvalidDescr,
  57    Node,
  58    NodeWithExplicitlySetFields,
  59)
  60from .._internal.constants import DTYPE_LIMITS
  61from .._internal.field_warning import issue_warning, warn
  62from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
  63from .._internal.io import FileDescr as FileDescr
  64from .._internal.io import (
  65    FileSource,
  66    WithSuffix,
  67    YamlValue,
  68    get_reader,
  69    wo_special_file_name,
  70)
  71from .._internal.io_basics import Sha256 as Sha256
  72from .._internal.io_packaging import (
  73    FileDescr_,
  74    FileSource_,
  75    package_file_descr_serializer,
  76)
  77from .._internal.io_utils import load_array
  78from .._internal.node_converter import Converter
  79from .._internal.types import (
  80    AbsoluteTolerance,
  81    LowerCaseIdentifier,
  82    LowerCaseIdentifierAnno,
  83    MismatchedElementsPerMillion,
  84    RelativeTolerance,
  85)
  86from .._internal.types import Datetime as Datetime
  87from .._internal.types import Identifier as Identifier
  88from .._internal.types import NotEmpty as NotEmpty
  89from .._internal.types import SiUnit as SiUnit
  90from .._internal.url import HttpUrl as HttpUrl
  91from .._internal.validation_context import get_validation_context
  92from .._internal.validator_annotations import RestrictCharacters
  93from .._internal.version_type import Version as Version
  94from .._internal.warning_levels import INFO
  95from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
  96from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
  97from ..dataset.v0_3 import DatasetDescr as DatasetDescr
  98from ..dataset.v0_3 import DatasetId as DatasetId
  99from ..dataset.v0_3 import LinkedDataset as LinkedDataset
 100from ..dataset.v0_3 import Uploader as Uploader
 101from ..generic.v0_3 import (
 102    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
 103)
 104from ..generic.v0_3 import Author as Author
 105from ..generic.v0_3 import BadgeDescr as BadgeDescr
 106from ..generic.v0_3 import CiteEntry as CiteEntry
 107from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
 108from ..generic.v0_3 import Doi as Doi
 109from ..generic.v0_3 import (
 110    FileSource_documentation,
 111    GenericModelDescrBase,
 112    LinkedResourceBase,
 113    _author_conv,  # pyright: ignore[reportPrivateUsage]
 114    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
 115)
 116from ..generic.v0_3 import LicenseId as LicenseId
 117from ..generic.v0_3 import LinkedResource as LinkedResource
 118from ..generic.v0_3 import Maintainer as Maintainer
 119from ..generic.v0_3 import OrcidId as OrcidId
 120from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
 121from ..generic.v0_3 import ResourceId as ResourceId
 122from .v0_4 import Author as _Author_v0_4
 123from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
 124from .v0_4 import CallableFromDepencency as CallableFromDepencency
 125from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
 126from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
 127from .v0_4 import ClipDescr as _ClipDescr_v0_4
 128from .v0_4 import ClipKwargs as ClipKwargs
 129from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
 130from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
 131from .v0_4 import KnownRunMode as KnownRunMode
 132from .v0_4 import ModelDescr as _ModelDescr_v0_4
 133from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
 134from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
 135from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
 136from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
 137from .v0_4 import ProcessingKwargs as ProcessingKwargs
 138from .v0_4 import RunMode as RunMode
 139from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
 140from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
 141from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
 142from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
 143from .v0_4 import TensorName as _TensorName_v0_4
 144from .v0_4 import WeightsFormat as WeightsFormat
 145from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
 146from .v0_4 import package_weights
 147
 148SpaceUnit = Literal[
 149    "attometer",
 150    "angstrom",
 151    "centimeter",
 152    "decimeter",
 153    "exameter",
 154    "femtometer",
 155    "foot",
 156    "gigameter",
 157    "hectometer",
 158    "inch",
 159    "kilometer",
 160    "megameter",
 161    "meter",
 162    "micrometer",
 163    "mile",
 164    "millimeter",
 165    "nanometer",
 166    "parsec",
 167    "petameter",
 168    "picometer",
 169    "terameter",
 170    "yard",
 171    "yoctometer",
 172    "yottameter",
 173    "zeptometer",
 174    "zettameter",
 175]
 176"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 177
 178TimeUnit = Literal[
 179    "attosecond",
 180    "centisecond",
 181    "day",
 182    "decisecond",
 183    "exasecond",
 184    "femtosecond",
 185    "gigasecond",
 186    "hectosecond",
 187    "hour",
 188    "kilosecond",
 189    "megasecond",
 190    "microsecond",
 191    "millisecond",
 192    "minute",
 193    "nanosecond",
 194    "petasecond",
 195    "picosecond",
 196    "second",
 197    "terasecond",
 198    "yoctosecond",
 199    "yottasecond",
 200    "zeptosecond",
 201    "zettasecond",
 202]
 203"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
 204
 205AxisType = Literal["batch", "channel", "index", "time", "space"]
 206
 207_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
 208    "b": "batch",
 209    "t": "time",
 210    "i": "index",
 211    "c": "channel",
 212    "x": "space",
 213    "y": "space",
 214    "z": "space",
 215}
 216
 217_AXIS_ID_MAP = {
 218    "b": "batch",
 219    "t": "time",
 220    "i": "index",
 221    "c": "channel",
 222}
 223
 224
 225class TensorId(LowerCaseIdentifier):
 226    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 227        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
 228    ]
 229
 230
 231def _normalize_axis_id(a: str):
 232    a = str(a)
 233    normalized = _AXIS_ID_MAP.get(a, a)
 234    if a != normalized:
 235        logger.opt(depth=3).warning(
 236            "Normalized axis id from '{}' to '{}'.", a, normalized
 237        )
 238    return normalized
 239
 240
 241class AxisId(LowerCaseIdentifier):
 242    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
 243        Annotated[
 244            LowerCaseIdentifierAnno,
 245            MaxLen(16),
 246            AfterValidator(_normalize_axis_id),
 247        ]
 248    ]
 249
 250
 251def _is_batch(a: str) -> bool:
 252    return str(a) == "batch"
 253
 254
 255def _is_not_batch(a: str) -> bool:
 256    return not _is_batch(a)
 257
 258
 259NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
 260
 261PostprocessingId = Literal[
 262    "binarize",
 263    "clip",
 264    "ensure_dtype",
 265    "fixed_zero_mean_unit_variance",
 266    "scale_linear",
 267    "scale_mean_variance",
 268    "scale_range",
 269    "sigmoid",
 270    "zero_mean_unit_variance",
 271]
 272PreprocessingId = Literal[
 273    "binarize",
 274    "clip",
 275    "ensure_dtype",
 276    "scale_linear",
 277    "sigmoid",
 278    "zero_mean_unit_variance",
 279    "scale_range",
 280]
 281
 282
 283SAME_AS_TYPE = "<same as type>"
 284
 285
 286ParameterizedSize_N = int
 287"""
 288Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
 289"""
 290
 291
 292class ParameterizedSize(Node):
 293    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
 294
 295    - **min** and **step** are given by the model description.
 296    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
 297    - A greater blocksize parameter n results in a greater **size**.
 298      This allows adjusting the axis size more generically.
 299    """
 300
 301    N: ClassVar[Type[int]] = ParameterizedSize_N
 302    """Positive integer to parameterize this axis"""
 303
 304    min: Annotated[int, Gt(0)]
 305    step: Annotated[int, Gt(0)]
 306
 307    def validate_size(self, size: int) -> int:
 308        if size < self.min:
 309            raise ValueError(f"size {size} < {self.min}")
 310        if (size - self.min) % self.step != 0:
 311            raise ValueError(
 312                f"axis of size {size} is not parameterized by `min + n*step` ="
 313                + f" `{self.min} + n*{self.step}`"
 314            )
 315
 316        return size
 317
 318    def get_size(self, n: ParameterizedSize_N) -> int:
 319        return self.min + self.step * n
 320
 321    def get_n(self, s: int) -> ParameterizedSize_N:
 322        """return the smallest n parameterizing a size greater than or equal to `s`"""
 323        return ceil((s - self.min) / self.step)
 324
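    # Usage sketch (illustrative, values chosen freely): how the block parameter n
    # maps to a concrete axis size via `size = min + n*step`.
    # >>> p = ParameterizedSize(min=32, step=16)
    # >>> p.get_size(2)  # 32 + 2*16
    # 64
    # >>> p.get_n(70)  # smallest n with min + n*step >= 70
    # 3
    # >>> p.validate_size(48)  # 48 = 32 + 1*16 is a valid size
    # 48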
 325
 326class DataDependentSize(Node):
 327    min: Annotated[int, Gt(0)] = 1
 328    max: Annotated[Optional[int], Gt(1)] = None
 329
 330    @model_validator(mode="after")
 331    def _validate_max_gt_min(self):
 332        if self.max is not None and self.min >= self.max:
 333            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
 334
 335        return self
 336
 337    def validate_size(self, size: int) -> int:
 338        if size < self.min:
 339            raise ValueError(f"size {size} < {self.min}")
 340
 341        if self.max is not None and size > self.max:
 342            raise ValueError(f"size {size} > {self.max}")
 343
 344        return size
 345
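    # Usage sketch (illustrative, values chosen freely): a data-dependent size only
    # constrains the valid range; the concrete size is known after inference.
    # >>> dds = DataDependentSize(min=1, max=512)
    # >>> dds.validate_size(100)
    # 100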
 346
 347class SizeReference(Node):
 348    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
 349
 350    `axis.size = reference.size * reference.scale / axis.scale + offset`
 351
 352    Note:
 353    1. The axis and the referenced axis need to have the same unit (or no unit).
 354    2. Batch axes may not be referenced.
 355    3. Fractions are rounded down.
 356    4. If the reference axis is `concatenable` the referencing axis is assumed to be
 357        `concatenable` as well with the same block order.
 358
 359    Example:
 360    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².
 361    Let's assume that we want to express the image height h in relation to its width w
 362    instead of only accepting input images of exactly 100*49 pixels
 363    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
 364
 365    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
 366    >>> h = SpaceInputAxis(
 367    ...     id=AxisId("h"),
 368    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
 369    ...     unit="millimeter",
 370    ...     scale=4,
 371    ... )
 372    >>> print(h.size.get_size(h, w))
 373    49
 374
 375    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
 376    """
 377
 378    tensor_id: TensorId
 379    """tensor id of the reference axis"""
 380
 381    axis_id: AxisId
 382    """axis id of the reference axis"""
 383
 384    offset: int = 0
 385
 386    def get_size(
 387        self,
 388        axis: Union[
 389            ChannelAxis,
 390            IndexInputAxis,
 391            IndexOutputAxis,
 392            TimeInputAxis,
 393            SpaceInputAxis,
 394            TimeOutputAxis,
 395            TimeOutputAxisWithHalo,
 396            SpaceOutputAxis,
 397            SpaceOutputAxisWithHalo,
 398        ],
 399        ref_axis: Union[
 400            ChannelAxis,
 401            IndexInputAxis,
 402            IndexOutputAxis,
 403            TimeInputAxis,
 404            SpaceInputAxis,
 405            TimeOutputAxis,
 406            TimeOutputAxisWithHalo,
 407            SpaceOutputAxis,
 408            SpaceOutputAxisWithHalo,
 409        ],
 410        n: ParameterizedSize_N = 0,
 411        ref_size: Optional[int] = None,
 412    ):
 413        """Compute the concrete size for a given axis and its reference axis.
 414
 415        Args:
 416            axis: The axis this `SizeReference` is the size of.
 417            ref_axis: The reference axis to compute the size from.
 418            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
 419                and no fixed **ref_size** is given,
 420                **n** is used to compute the size of the parameterized **ref_axis**.
 421            ref_size: Overwrite the reference size instead of deriving it from
 422                **ref_axis**
 423                (**ref_axis.scale** is still used; any given **n** is ignored).
 424        """
 425        assert (
 426            axis.size == self
 427        ), "Given `axis.size` is not defined by this `SizeReference`"
 428
 429        assert (
 430            ref_axis.id == self.axis_id
 431        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
 432
 433        assert axis.unit == ref_axis.unit, (
 434            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
 435            f" but {axis.unit}!={ref_axis.unit}"
 436        )
 437        if ref_size is None:
 438            if isinstance(ref_axis.size, (int, float)):
 439                ref_size = ref_axis.size
 440            elif isinstance(ref_axis.size, ParameterizedSize):
 441                ref_size = ref_axis.size.get_size(n)
 442            elif isinstance(ref_axis.size, DataDependentSize):
 443                raise ValueError(
 444                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
 445                )
 446            elif isinstance(ref_axis.size, SizeReference):
 447                raise ValueError(
 448                    "Reference axis referenced in `SizeReference` may not be sized by a"
 449                    + " `SizeReference` itself."
 450                )
 451            else:
 452                assert_never(ref_axis.size)
 453
 454        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
 455
 456    @staticmethod
 457    def _get_unit(
 458        axis: Union[
 459            ChannelAxis,
 460            IndexInputAxis,
 461            IndexOutputAxis,
 462            TimeInputAxis,
 463            SpaceInputAxis,
 464            TimeOutputAxis,
 465            TimeOutputAxisWithHalo,
 466            SpaceOutputAxis,
 467            SpaceOutputAxisWithHalo,
 468        ],
 469    ):
 470        return axis.unit
 471
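    # Usage sketch (illustrative, ids "input"/"w"/"h" chosen freely): a `SizeReference`
    # combined with a parameterized reference axis; `n` selects the concrete reference
    # size via `ParameterizedSize.get_size(n)`.
    # >>> w = SpaceInputAxis(id=AxisId("w"), size=ParameterizedSize(min=32, step=16))
    # >>> h = SpaceInputAxis(
    # ...     id=AxisId("h"),
    # ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w")),
    # ...     scale=2,
    # ... )
    # >>> h.size.get_size(h, w, n=2)  # ref size 64 = 32 + 2*16; 64 * 1.0 / 2 + 0
    # 32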
 472
 473class AxisBase(NodeWithExplicitlySetFields):
 474    id: AxisId
 475    """An axis id unique across all axes of one tensor."""
 476
 477    description: Annotated[str, MaxLen(128)] = ""
 478
 479
 480class WithHalo(Node):
 481    halo: Annotated[int, Ge(1)]
 482    """The halo should be cropped from the output tensor to avoid boundary effects.
 483    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
 484    To document a halo that is already cropped by the model use `size.offset` instead."""
 485
 486    size: Annotated[
 487        SizeReference,
 488        Field(
 489            examples=[
 490                10,
 491                SizeReference(
 492                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 493                ).model_dump(mode="json"),
 494            ]
 495        ),
 496    ]
 497    """reference to another axis with an optional offset (see `SizeReference`)"""
 498
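    # Worked example (illustrative): with `halo: 16` on an output axis of size 256,
    # consumers crop 16 pixels from each side, i.e. 256 - 2*16 = 224 usable pixels.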
 499
 500BATCH_AXIS_ID = AxisId("batch")
 501
 502
 503class BatchAxis(AxisBase):
 504    implemented_type: ClassVar[Literal["batch"]] = "batch"
 505    if TYPE_CHECKING:
 506        type: Literal["batch"] = "batch"
 507    else:
 508        type: Literal["batch"]
 509
 510    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
 511    size: Optional[Literal[1]] = None
 512    """The batch size may be fixed to 1,
 513    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
 514
 515    @property
 516    def scale(self):
 517        return 1.0
 518
 519    @property
 520    def concatenable(self):
 521        return True
 522
 523    @property
 524    def unit(self):
 525        return None
 526
 527
 528class ChannelAxis(AxisBase):
 529    implemented_type: ClassVar[Literal["channel"]] = "channel"
 530    if TYPE_CHECKING:
 531        type: Literal["channel"] = "channel"
 532    else:
 533        type: Literal["channel"]
 534
 535    id: NonBatchAxisId = AxisId("channel")
 536    channel_names: NotEmpty[List[Identifier]]
 537
 538    @property
 539    def size(self) -> int:
 540        return len(self.channel_names)
 541
 542    @property
 543    def concatenable(self):
 544        return False
 545
 546    @property
 547    def scale(self) -> float:
 548        return 1.0
 549
 550    @property
 551    def unit(self):
 552        return None
 553
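    # Usage sketch (illustrative, channel names chosen freely): the channel axis size
    # is implied by the number of channel names.
    # >>> rgb = ChannelAxis(channel_names=[Identifier("r"), Identifier("g"), Identifier("b")])
    # >>> rgb.size
    # 3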
 554
 555class IndexAxisBase(AxisBase):
 556    implemented_type: ClassVar[Literal["index"]] = "index"
 557    if TYPE_CHECKING:
 558        type: Literal["index"] = "index"
 559    else:
 560        type: Literal["index"]
 561
 562    id: NonBatchAxisId = AxisId("index")
 563
 564    @property
 565    def scale(self) -> float:
 566        return 1.0
 567
 568    @property
 569    def unit(self):
 570        return None
 571
 572
 573class _WithInputAxisSize(Node):
 574    size: Annotated[
 575        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
 576        Field(
 577            examples=[
 578                10,
 579                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
 580                SizeReference(
 581                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 582                ).model_dump(mode="json"),
 583            ]
 584        ),
 585    ]
 586    """The size/length of this axis can be specified as
 587    - fixed integer
 588    - parameterized series of valid sizes (`ParameterizedSize`)
 589    - reference to another axis with an optional offset (`SizeReference`)
 590    """
 591
 592
 593class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
 594    concatenable: bool = False
 595    """If a model has a `concatenable` input axis, it can be processed blockwise,
 596    splitting a longer sample axis into blocks matching its input tensor description.
 597    Output axes are concatenable if they have a `SizeReference` to a concatenable
 598    input axis.
 599    """
 600
 601
 602class IndexOutputAxis(IndexAxisBase):
 603    size: Annotated[
 604        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
 605        Field(
 606            examples=[
 607                10,
 608                SizeReference(
 609                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 610                ).model_dump(mode="json"),
 611            ]
 612        ),
 613    ]
 614    """The size/length of this axis can be specified as
 615    - fixed integer
 616    - reference to another axis with an optional offset (`SizeReference`)
 617    - data dependent size using `DataDependentSize` (size is only known after model inference)
 618    """
 619
 620
 621class TimeAxisBase(AxisBase):
 622    implemented_type: ClassVar[Literal["time"]] = "time"
 623    if TYPE_CHECKING:
 624        type: Literal["time"] = "time"
 625    else:
 626        type: Literal["time"]
 627
 628    id: NonBatchAxisId = AxisId("time")
 629    unit: Optional[TimeUnit] = None
 630    scale: Annotated[float, Gt(0)] = 1.0
 631
 632
 633class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
 634    concatenable: bool = False
 635    """If a model has a `concatenable` input axis, it can be processed blockwise,
 636    splitting a longer sample axis into blocks matching its input tensor description.
 637    Output axes are concatenable if they have a `SizeReference` to a concatenable
 638    input axis.
 639    """
 640
 641
 642class SpaceAxisBase(AxisBase):
 643    implemented_type: ClassVar[Literal["space"]] = "space"
 644    if TYPE_CHECKING:
 645        type: Literal["space"] = "space"
 646    else:
 647        type: Literal["space"]
 648
 649    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
 650    unit: Optional[SpaceUnit] = None
 651    scale: Annotated[float, Gt(0)] = 1.0
 652
 653
 654class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
 655    concatenable: bool = False
 656    """If a model has a `concatenable` input axis, it can be processed blockwise,
 657    splitting a longer sample axis into blocks matching its input tensor description.
 658    Output axes are concatenable if they have a `SizeReference` to a concatenable
 659    input axis.
 660    """
 661
 662
 663INPUT_AXIS_TYPES = (
 664    BatchAxis,
 665    ChannelAxis,
 666    IndexInputAxis,
 667    TimeInputAxis,
 668    SpaceInputAxis,
 669)
 670"""intended for isinstance comparisons in py<3.10"""
 671
 672_InputAxisUnion = Union[
 673    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
 674]
 675InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
 676
 677
 678class _WithOutputAxisSize(Node):
 679    size: Annotated[
 680        Union[Annotated[int, Gt(0)], SizeReference],
 681        Field(
 682            examples=[
 683                10,
 684                SizeReference(
 685                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
 686                ).model_dump(mode="json"),
 687            ]
 688        ),
 689    ]
 690    """The size/length of this axis can be specified as
 691    - fixed integer
 692    - reference to another axis with an optional offset (see `SizeReference`)
 693    """
 694
 695
 696class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
 697    pass
 698
 699
 700class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
 701    pass
 702
 703
 704def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
 705    if isinstance(v, dict):
 706        return "with_halo" if "halo" in v else "wo_halo"
 707    else:
 708        return "with_halo" if hasattr(v, "halo") else "wo_halo"
 709
 710
 711_TimeOutputAxisUnion = Annotated[
 712    Union[
 713        Annotated[TimeOutputAxis, Tag("wo_halo")],
 714        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
 715    ],
 716    Discriminator(_get_halo_axis_discriminator_value),
 717]
 718
 719
 720class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
 721    pass
 722
 723
 724class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
 725    pass
 726
 727
 728_SpaceOutputAxisUnion = Annotated[
 729    Union[
 730        Annotated[SpaceOutputAxis, Tag("wo_halo")],
 731        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
 732    ],
 733    Discriminator(_get_halo_axis_discriminator_value),
 734]
 735
 736
 737_OutputAxisUnion = Union[
 738    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
 739]
 740OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
 741
 742OUTPUT_AXIS_TYPES = (
 743    BatchAxis,
 744    ChannelAxis,
 745    IndexOutputAxis,
 746    TimeOutputAxis,
 747    TimeOutputAxisWithHalo,
 748    SpaceOutputAxis,
 749    SpaceOutputAxisWithHalo,
 750)
 751"""intended for isinstance comparisons in py<3.10"""
 752
 753
 754AnyAxis = Union[InputAxis, OutputAxis]
 755
 756ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
 757"""intended for isinstance comparisons in py<3.10"""
 758
 759TVs = Union[
 760    NotEmpty[List[int]],
 761    NotEmpty[List[float]],
 762    NotEmpty[List[bool]],
 763    NotEmpty[List[str]],
 764]
 765
 766
 767NominalOrOrdinalDType = Literal[
 768    "float32",
 769    "float64",
 770    "uint8",
 771    "int8",
 772    "uint16",
 773    "int16",
 774    "uint32",
 775    "int32",
 776    "uint64",
 777    "int64",
 778    "bool",
 779]
 780
 781
 782class NominalOrOrdinalDataDescr(Node):
 783    values: TVs
 784    """A fixed set of nominal or an ascending sequence of ordinal values.
 785    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
 786    String `values` are interpreted as labels for tensor values 0, ..., N.
 787    Note: as YAML 1.2 does not natively support a "set" datatype,
 788    nominal values should be given as a sequence (aka list/array) as well.
 789    """
 790
 791    type: Annotated[
 792        NominalOrOrdinalDType,
 793        Field(
 794            examples=[
 795                "float32",
 796                "uint8",
 797                "uint16",
 798                "int64",
 799                "bool",
 800            ],
 801        ),
 802    ] = "uint8"
 803
 804    @model_validator(mode="after")
 805    def _validate_values_match_type(
 806        self,
 807    ) -> Self:
 808        incompatible: List[Any] = []
 809        for v in self.values:
 810            if self.type == "bool":
 811                if not isinstance(v, bool):
 812                    incompatible.append(v)
 813            elif self.type in DTYPE_LIMITS:
 814                if (
 815                    isinstance(v, (int, float))
 816                    and (
 817                        v < DTYPE_LIMITS[self.type].min
 818                        or v > DTYPE_LIMITS[self.type].max
 819                    )
 820                    or (isinstance(v, str) and "uint" not in self.type)
 821                    or (isinstance(v, float) and "int" in self.type)
 822                ):
 823                    incompatible.append(v)
 824            else:
 825                incompatible.append(v)
 826
 827            if len(incompatible) == 5:
 828                incompatible.append("...")
 829                break
 830
 831        if incompatible:
 832            raise ValueError(
 833                f"data type '{self.type}' incompatible with values {incompatible}"
 834            )
 835
 836        return self
 837
 838    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
 839
 840    @property
 841    def range(self):
 842        if isinstance(self.values[0], str):
 843            return 0, len(self.values) - 1
 844        else:
 845            return min(self.values), max(self.values)
 846
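    # Usage sketch (illustrative, labels chosen freely): string values act as labels
    # for tensor values 0, ..., N and define the `range` property accordingly.
    # >>> labels = NominalOrOrdinalDataDescr(values=["background", "cell"], type="uint8")
    # >>> labels.range
    # (0, 1)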
 847
 848IntervalOrRatioDType = Literal[
 849    "float32",
 850    "float64",
 851    "uint8",
 852    "int8",
 853    "uint16",
 854    "int16",
 855    "uint32",
 856    "int32",
 857    "uint64",
 858    "int64",
 859]
 860
 861
 862class IntervalOrRatioDataDescr(Node):
 863    type: Annotated[  # todo: rename to dtype
 864        IntervalOrRatioDType,
 865        Field(
 866            examples=["float32", "float64", "uint8", "uint16"],
 867        ),
 868    ] = "float32"
 869    range: Tuple[Optional[float], Optional[float]] = (
 870        None,
 871        None,
 872    )
 873    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
 874    `None` corresponds to min/max of what can be expressed by **type**."""
 875    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
 876    scale: float = 1.0
 877    """Scale for data on an interval (or ratio) scale."""
 878    offset: Optional[float] = None
 879    """Offset for data on a ratio scale."""
 880
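    # Usage sketch (illustrative): an 8-bit intensity tensor restricted to [0, 255]
    # (range entries are coerced to floats).
    # >>> data = IntervalOrRatioDataDescr(type="uint8", range=(0, 255))
    # >>> data.range
    # (0.0, 255.0)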
 881
 882TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
 883
 884
 885class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
 886    """processing base class"""
 887
 888
 889class BinarizeKwargs(ProcessingKwargs):
 890    """keyword arguments for `BinarizeDescr`"""
 891
 892    threshold: float
 893    """The fixed threshold"""
 894
 895
 896class BinarizeAlongAxisKwargs(ProcessingKwargs):
 897    """keyword arguments for `BinarizeDescr`"""
 898
 899    threshold: NotEmpty[List[float]]
 900    """The fixed threshold values along `axis`"""
 901
 902    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
 903    """The `threshold` axis"""
 904
 905
 906class BinarizeDescr(ProcessingDescrBase):
 907    """Binarize the tensor with a fixed threshold.
 908
 909    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
 910    will be set to one, values below the threshold to zero.
 911
 912    Examples:
 913    - in YAML
 914        ```yaml
 915        postprocessing:
 916          - id: binarize
 917            kwargs:
 918              axis: 'channel'
 919              threshold: [0.25, 0.5, 0.75]
 920        ```
 921    - in Python:
 922        >>> postprocessing = [BinarizeDescr(
 923        ...   kwargs=BinarizeAlongAxisKwargs(
 924        ...       axis=AxisId('channel'),
 925        ...       threshold=[0.25, 0.5, 0.75],
 926        ...   )
 927        ... )]
 928    """
 929
 930    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
 931    if TYPE_CHECKING:
 932        id: Literal["binarize"] = "binarize"
 933    else:
 934        id: Literal["binarize"]
 935    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
 936
 937
 938class ClipDescr(ProcessingDescrBase):
 939    """Set tensor values below min to min and above max to max.
 940
 941    See `ScaleRangeDescr` for examples.
 942    """
 943
 944    implemented_id: ClassVar[Literal["clip"]] = "clip"
 945    if TYPE_CHECKING:
 946        id: Literal["clip"] = "clip"
 947    else:
 948        id: Literal["clip"]
 949
 950    kwargs: ClipKwargs
 951
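    # Usage sketch (illustrative): clip normalized values to the unit interval.
    # >>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]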
 952
 953class EnsureDtypeKwargs(ProcessingKwargs):
 954    """keyword arguments for `EnsureDtypeDescr`"""
 955
 956    dtype: Literal[
 957        "float32",
 958        "float64",
 959        "uint8",
 960        "int8",
 961        "uint16",
 962        "int16",
 963        "uint32",
 964        "int32",
 965        "uint64",
 966        "int64",
 967        "bool",
 968    ]
 969
 970
 971class EnsureDtypeDescr(ProcessingDescrBase):
 972    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 973
 974    This can for example be used to ensure the inner neural network model gets a
 975    different input tensor data type than the fully described bioimage.io model does.
 976
 977    Examples:
 978        The described bioimage.io model (incl. preprocessing) accepts any
 979        float32-compatible tensor, normalizes it with percentiles and clipping and then
 980        casts it to uint8, which is what the neural network in this example expects.
 981        - in YAML
 982            ```yaml
 983            inputs:
 984            - data:
 985                type: float32  # described bioimage.io model is compatible with any float32 input tensor
 986              preprocessing:
 987              - id: scale_range
 988                kwargs:
 989                  axes: ['y', 'x']
 990                  max_percentile: 99.8
 991                  min_percentile: 5.0
 992              - id: clip
 993                kwargs:
 994                  min: 0.0
 995                  max: 1.0
 996              - id: ensure_dtype  # the neural network of the model requires uint8
 997                kwargs:
 998                  dtype: uint8
 999            ```
1000        - in Python:
1001            >>> preprocessing = [
1002            ...     ScaleRangeDescr(
1003            ...         kwargs=ScaleRangeKwargs(
1004            ...           axes= (AxisId('y'), AxisId('x')),
1005            ...           max_percentile= 99.8,
1006            ...           min_percentile= 5.0,
1007            ...         )
1008            ...     ),
1009            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1010            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1011            ... ]
1012    """
1013
1014    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1015    if TYPE_CHECKING:
1016        id: Literal["ensure_dtype"] = "ensure_dtype"
1017    else:
1018        id: Literal["ensure_dtype"]
1019
1020    kwargs: EnsureDtypeKwargs
1021
1022
1023class ScaleLinearKwargs(ProcessingKwargs):
1024    """Keyword arguments for `ScaleLinearDescr`"""
1025
1026    gain: float = 1.0
1027    """multiplicative factor"""
1028
1029    offset: float = 0.0
1030    """additive term"""
1031
1032    @model_validator(mode="after")
1033    def _validate(self) -> Self:
1034        if self.gain == 1.0 and self.offset == 0.0:
1035            raise ValueError(
1036                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1037                + " != 0.0."
1038            )
1039
1040        return self
1041
1042
1043class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1044    """Keyword arguments for `ScaleLinearDescr`"""
1045
1046    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1047    """The axis of gain and offset values."""
1048
1049    gain: Union[float, NotEmpty[List[float]]] = 1.0
1050    """multiplicative factor"""
1051
1052    offset: Union[float, NotEmpty[List[float]]] = 0.0
1053    """additive term"""
1054
1055    @model_validator(mode="after")
1056    def _validate(self) -> Self:
1057
1058        if isinstance(self.gain, list):
1059            if isinstance(self.offset, list):
1060                if len(self.gain) != len(self.offset):
1061                    raise ValueError(
1062                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1063                    )
1064            else:
1065                self.offset = [float(self.offset)] * len(self.gain)
1066        elif isinstance(self.offset, list):
1067            self.gain = [float(self.gain)] * len(self.offset)
1068        else:
1069            raise ValueError(
1070                "Do not specify an `axis` for scalar gain and offset values."
1071            )
1072
1073        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1074            raise ValueError(
1075                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1076                + " != 0.0."
1077            )
1078
1079        return self
1080
1081
1082class ScaleLinearDescr(ProcessingDescrBase):
1083    """Fixed linear scaling.
1084
1085    Examples:
1086      1. Scale with scalar gain and offset
1087        - in YAML
1088        ```yaml
1089        preprocessing:
1090          - id: scale_linear
1091            kwargs:
1092              gain: 2.0
1093              offset: 3.0
1094        ```
1095        - in Python:
1096        >>> preprocessing = [
1097        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1098        ... ]
1099
1100      2. Independent scaling along an axis
1101        - in YAML
1102        ```yaml
1103        preprocessing:
1104          - id: scale_linear
1105            kwargs:
1106              axis: 'channel'
1107              gain: [1.0, 2.0, 3.0]
1108        ```
1109        - in Python:
1110        >>> preprocessing = [
1111        ...     ScaleLinearDescr(
1112        ...         kwargs=ScaleLinearAlongAxisKwargs(
1113        ...             axis=AxisId("channel"),
1114        ...             gain=[1.0, 2.0, 3.0],
1115        ...         )
1116        ...     )
1117        ... ]
1118
1119    """
1120
1121    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1122    if TYPE_CHECKING:
1123        id: Literal["scale_linear"] = "scale_linear"
1124    else:
1125        id: Literal["scale_linear"]
1126    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1127
1128
1129class SigmoidDescr(ProcessingDescrBase):
1130    """The logistic sigmoid function, a.k.a. expit function.
1131
1132    Examples:
1133    - in YAML
1134        ```yaml
1135        postprocessing:
1136          - id: sigmoid
1137        ```
1138    - in Python:
1139        >>> postprocessing = [SigmoidDescr()]
1140    """
1141
1142    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1143    if TYPE_CHECKING:
1144        id: Literal["sigmoid"] = "sigmoid"
1145    else:
1146        id: Literal["sigmoid"]
1147
1148    @property
1149    def kwargs(self) -> ProcessingKwargs:
1150        """empty kwargs"""
1151        return ProcessingKwargs()
1152
1153
1154class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1155    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1156
1157    mean: float
1158    """The mean value to normalize with."""
1159
1160    std: Annotated[float, Ge(1e-6)]
1161    """The standard deviation value to normalize with."""
1162
1163
1164class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1165    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1166
1167    mean: NotEmpty[List[float]]
1168    """The mean value(s) to normalize with."""
1169
1170    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1171    """The standard deviation value(s) to normalize with.
1172    Size must match `mean` values."""
1173
1174    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1175    """The axis of the mean/std values to normalize each entry along that dimension
1176    separately."""
1177
1178    @model_validator(mode="after")
1179    def _mean_and_std_match(self) -> Self:
1180        if len(self.mean) != len(self.std):
1181            raise ValueError(
1182                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1183                + " must match."
1184            )
1185
1186        return self
1187
1188
1189class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1190    """Subtract a given mean and divide by the standard deviation.
1191
1192    Normalize with fixed, precomputed values for
1193    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1194    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1195    axes.
1196
1197    Examples:
1198    1. scalar value for whole tensor
1199        - in YAML
1200        ```yaml
1201        preprocessing:
1202          - id: fixed_zero_mean_unit_variance
1203            kwargs:
1204              mean: 103.5
1205              std: 13.7
1206        ```
1207        - in Python
1208        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1209        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1210        ... )]
1211
1212    2. independently along an axis
1213        - in YAML
1214        ```yaml
1215        preprocessing:
1216          - id: fixed_zero_mean_unit_variance
1217            kwargs:
1218              axis: channel
1219              mean: [101.5, 102.5, 103.5]
1220              std: [11.7, 12.7, 13.7]
1221        ```
1222        - in Python
1223        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1224        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1225        ...     axis=AxisId("channel"),
1226        ...     mean=[101.5, 102.5, 103.5],
1227        ...     std=[11.7, 12.7, 13.7],
1228        ...   )
1229        ... )]
1230    """
1231
1232    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1233        "fixed_zero_mean_unit_variance"
1234    )
1235    if TYPE_CHECKING:
1236        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1237    else:
1238        id: Literal["fixed_zero_mean_unit_variance"]
1239
1240    kwargs: Union[
1241        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1242    ]
1243
1244
1245class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1246    """keyword arguments for `ZeroMeanUnitVarianceDescr`"""
1247
1248    axes: Annotated[
1249        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1250    ] = None
1251    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1252    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1253    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1254    To normalize each sample independently leave out the 'batch' axis.
1255    Default: Scale all axes jointly."""
1256
1257    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1258    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1259
1260
1261class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1262    """Subtract mean and divide by the standard deviation.
1263
1264    Examples:
1265        Subtract the tensor mean and divide by its standard deviation
1266        - in YAML
1267        ```yaml
1268        preprocessing:
1269          - id: zero_mean_unit_variance
1270        ```
1271        - in Python
1272        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1273    """
1274
1275    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1276        "zero_mean_unit_variance"
1277    )
1278    if TYPE_CHECKING:
1279        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1280    else:
1281        id: Literal["zero_mean_unit_variance"]
1282
1283    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1284        default_factory=ZeroMeanUnitVarianceKwargs
1285    )
1286
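    # Usage sketch (illustrative): normalize jointly over 'x' and 'y' only, i.e.
    # independently per batch sample and per channel.
    # >>> preprocessing = [ZeroMeanUnitVarianceDescr(
    # ...     kwargs=ZeroMeanUnitVarianceKwargs(axes=(AxisId("x"), AxisId("y")))
    # ... )]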
1287
1288class ScaleRangeKwargs(ProcessingKwargs):
1289    """keyword arguments for `ScaleRangeDescr`
1290
1291    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1292    this processing step normalizes data to the [0, 1] interval.
1293    For other percentiles the normalized values will partially be outside the [0, 1]
1294    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1295    normalized values to a range.
1296    """
1297
1298    axes: Annotated[
1299        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1300    ] = None
1301    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1302    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1303    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1304    To normalize samples independently, leave out the "batch" axis.
1305    Default: Scale all axes jointly."""
1306
1307    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1308    """The lower percentile used to determine the value to align with zero."""
1309
1310    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1311    """The upper percentile used to determine the value to align with one.
1312    Has to be bigger than `min_percentile`.
1313    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1314    accepting percentiles specified in the range 0.0 to 1.0."""
1315
1316    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1317    """Epsilon for numeric stability.
1318    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1319    with `v_lower,v_upper` values at the respective percentiles."""
1320
1321    reference_tensor: Optional[TensorId] = None
1322    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1323    For any tensor in `inputs` only input tensor references are allowed."""
1324
1325    @field_validator("max_percentile", mode="after")
1326    @classmethod
1327    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1328        if (min_p := info.data["min_percentile"]) >= value:
1329            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1330
1331        return value
1332
1333
1334class ScaleRangeDescr(ProcessingDescrBase):
1335    """Scale with percentiles.
1336
1337    Examples:
1338    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1339        - in YAML
1340        ```yaml
1341        preprocessing:
1342          - id: scale_range
1343            kwargs:
1344              axes: ['y', 'x']
1345              max_percentile: 99.8
1346              min_percentile: 5.0
1347        ```
1348        - in Python
1349        >>> preprocessing = [
1350        ...     ScaleRangeDescr(
1351        ...         kwargs=ScaleRangeKwargs(
1352        ...           axes= (AxisId('y'), AxisId('x')),
1353        ...           max_percentile= 99.8,
1354        ...           min_percentile= 5.0,
1355        ...         )
1356        ...     ),
1363        ... ]
1364
1365      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1366        - in YAML
1367        ```yaml
1368        preprocessing:
1369          - id: scale_range
1370            kwargs:
1371              axes: ['y', 'x']
1372              max_percentile: 99.8
1373              min_percentile: 5.0
1374          - id: clip
1375            kwargs:
1376              min: 0.0
1377              max: 1.0
1379        ```
1380        - in Python
1381        >>> preprocessing = [ScaleRangeDescr(
1382        ...   kwargs=ScaleRangeKwargs(
1383        ...       axes= (AxisId('y'), AxisId('x')),
1384        ...       max_percentile= 99.8,
1385        ...       min_percentile= 5.0,
1386        ...   )
1387        ... ), ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
1388
1389    """
1390
1391    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1392    if TYPE_CHECKING:
1393        id: Literal["scale_range"] = "scale_range"
1394    else:
1395        id: Literal["scale_range"]
1396    kwargs: ScaleRangeKwargs
1397
1398
1399class ScaleMeanVarianceKwargs(ProcessingKwargs):
1400    """keyword arguments for `ScaleMeanVarianceDescr`"""
1401
1402    reference_tensor: TensorId
1403    """Name of tensor to match."""
1404
1405    axes: Annotated[
1406        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1407    ] = None
1408    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1409    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1410    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1411    To normalize samples independently, leave out the 'batch' axis.
1412    Default: Scale all axes jointly."""
1413
1414    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1415    """Epsilon for numeric stability:
1416    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1417
1418
1419class ScaleMeanVarianceDescr(ProcessingDescrBase):
1420    """Scale a tensor's data distribution to match another tensor's mean/std.
1421    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1422    """
1423
1424    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1425    if TYPE_CHECKING:
1426        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1427    else:
1428        id: Literal["scale_mean_variance"]
1429    kwargs: ScaleMeanVarianceKwargs
1430
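    # Usage sketch (illustrative, tensor id "input" chosen freely): match an output
    # tensor's mean/std to those of a reference tensor.
    # >>> postprocessing = [ScaleMeanVarianceDescr(
    # ...     kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("input"))
    # ... )]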
1431
1432PreprocessingDescr = Annotated[
1433    Union[
1434        BinarizeDescr,
1435        ClipDescr,
1436        EnsureDtypeDescr,
1437        ScaleLinearDescr,
1438        SigmoidDescr,
1439        FixedZeroMeanUnitVarianceDescr,
1440        ZeroMeanUnitVarianceDescr,
1441        ScaleRangeDescr,
1442    ],
1443    Discriminator("id"),
1444]
1445PostprocessingDescr = Annotated[
1446    Union[
1447        BinarizeDescr,
1448        ClipDescr,
1449        EnsureDtypeDescr,
1450        ScaleLinearDescr,
1451        SigmoidDescr,
1452        FixedZeroMeanUnitVarianceDescr,
1453        ZeroMeanUnitVarianceDescr,
1454        ScaleRangeDescr,
1455        ScaleMeanVarianceDescr,
1456    ],
1457    Discriminator("id"),
1458]
1459
1460IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1461
1462
1463class TensorDescrBase(Node, Generic[IO_AxisT]):
1464    id: TensorId
1465    """Tensor id. No duplicates are allowed."""
1466
1467    description: Annotated[str, MaxLen(128)] = ""
1468    """free text description"""
1469
1470    axes: NotEmpty[Sequence[IO_AxisT]]
1471    """tensor axes"""
1472
1473    @property
1474    def shape(self):
1475        return tuple(a.size for a in self.axes)
1476
1477    @field_validator("axes", mode="after", check_fields=False)
1478    @classmethod
1479    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1480        batch_axes = [a for a in axes if a.type == "batch"]
1481        if len(batch_axes) > 1:
1482            raise ValueError(
1483                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1484            )
1485
1486        seen_ids: Set[AxisId] = set()
1487        duplicate_axes_ids: Set[AxisId] = set()
1488        for a in axes:
1489            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1490
1491        if duplicate_axes_ids:
1492            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1493
1494        return axes
1495
1496    test_tensor: FileDescr_
1497    """An example tensor to use for testing.
1498    Using the model with the test input tensors is expected to yield the test output tensors.
1499    Each test tensor has to be an ndarray in the
1500    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1501    The file extension must be '.npy'."""
1502
1503    sample_tensor: Optional[FileDescr_] = None
1504    """A sample tensor to illustrate a possible input/output for the model.
1505    The sample image primarily serves to inform a human user about an example use case
1506    and is typically stored as .hdf5, .png or .tiff.
1507    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1508    (numpy's `.npy` format is not supported).
1509    The image dimensionality has to match the number of axes specified in this tensor description.
1510    """
1511
1512    @model_validator(mode="after")
1513    def _validate_sample_tensor(self) -> Self:
1514        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1515            return self
1516
1517        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1518        tensor: NDArray[Any] = imread(
1519            reader.read(),
1520            extension=PurePosixPath(reader.original_file_name).suffix,
1521        )
1522        n_dims = len(tensor.squeeze().shape)
1523        n_dims_min = n_dims_max = len(self.axes)
1524
1525        for a in self.axes:
1526            if isinstance(a, BatchAxis):
1527                n_dims_min -= 1
1528            elif isinstance(a.size, int):
1529                if a.size == 1:
1530                    n_dims_min -= 1
1531            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1532                if a.size.min == 1:
1533                    n_dims_min -= 1
1534            elif isinstance(a.size, SizeReference):
1535                if a.size.offset < 2:
1536                    # size reference may result in singleton axis
1537                    n_dims_min -= 1
1538            else:
1539                assert_never(a.size)
1540
1541        n_dims_min = max(0, n_dims_min)
1542        if n_dims < n_dims_min or n_dims > n_dims_max:
1543            raise ValueError(
1544                f"Expected sample tensor to have {n_dims_min} to"
1545                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1546            )
1547
1548        return self
1549
1550    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1551        IntervalOrRatioDataDescr()
1552    )
1553    """Description of the tensor's data values, optionally per channel.
1554    If specified per channel, the data `type` needs to match across channels."""
1555
1556    @property
1557    def dtype(
1558        self,
1559    ) -> Literal[
1560        "float32",
1561        "float64",
1562        "uint8",
1563        "int8",
1564        "uint16",
1565        "int16",
1566        "uint32",
1567        "int32",
1568        "uint64",
1569        "int64",
1570        "bool",
1571    ]:
1572        """dtype as specified under `data.type` or `data[i].type`"""
1573        if isinstance(self.data, collections.abc.Sequence):
1574            return self.data[0].type
1575        else:
1576            return self.data.type
1577
1578    @field_validator("data", mode="after")
1579    @classmethod
1580    def _check_data_type_across_channels(
1581        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1582    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1583        if not isinstance(value, list):
1584            return value
1585
1586        dtypes = {t.type for t in value}
1587        if len(dtypes) > 1:
1588            raise ValueError(
1589                "Tensor data descriptions per channel need to agree in their data"
1590                + f" `type`, but found {dtypes}."
1591            )
1592
1593        return value
1594
1595    @model_validator(mode="after")
1596    def _check_data_matches_channelaxis(self) -> Self:
1597        if not isinstance(self.data, (list, tuple)):
1598            return self
1599
1600        for a in self.axes:
1601            if isinstance(a, ChannelAxis):
1602                size = a.size
1603                assert isinstance(size, int)
1604                break
1605        else:
1606            return self
1607
1608        if len(self.data) != size:
1609            raise ValueError(
1610                f"Got tensor data descriptions for {len(self.data)} channels, but"
1611                + f" '{a.id}' axis has size {size}."
1612            )
1613
1614        return self
1615
1616    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1617        if len(array.shape) != len(self.axes):
1618            raise ValueError(
1619                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1620                + f" incompatible with {len(self.axes)} axes."
1621            )
1622        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1623
1624
1625class InputTensorDescr(TensorDescrBase[InputAxis]):
1626    id: TensorId = TensorId("input")
1627    """Input tensor id.
1628    No duplicates are allowed across all inputs and outputs."""
1629
1630    optional: bool = False
1631    """indicates that this tensor may be `None`"""
1632
1633    preprocessing: List[PreprocessingDescr] = Field(
1634        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1635    )
1636
1637    """Description of how this input should be preprocessed.
1638
1639    notes:
1640    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1641      to ensure an input tensor's data type matches the input tensor's data description.
1642    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1643      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1644      changing the data type.
1645    """
1646
1647    @model_validator(mode="after")
1648    def _validate_preprocessing_kwargs(self) -> Self:
1649        axes_ids = [a.id for a in self.axes]
1650        for p in self.preprocessing:
1651            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1652            if kwargs_axes is None:
1653                continue
1654
1655            if not isinstance(kwargs_axes, collections.abc.Sequence):
1656                raise ValueError(
1657                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1658                )
1659
1660            if any(a not in axes_ids for a in kwargs_axes):
1661                raise ValueError(
1662                    "`preprocessing.i.kwargs.axes` needs to be a subset of the axes ids"
1663                )
1664
1665        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1666            dtype = self.data.type
1667        else:
1668            dtype = self.data[0].type
1669
1670        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1671        if not self.preprocessing or not isinstance(
1672            self.preprocessing[0], EnsureDtypeDescr
1673        ):
1674            self.preprocessing.insert(
1675                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1676            )
1677
1678        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1679        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1680            self.preprocessing.append(
1681                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1682            )
1683
1684        return self
1685
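# Illustrative sketch of the validator above: for an input described with
# `data.type == "uint8"` and `preprocessing=[ScaleRangeDescr(...)]`, the
# resulting preprocessing list is roughly
#
#     [EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
#      ScaleRangeDescr(...),
#      EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8"))]
#
# i.e. dtype casts are added at the start and end unless equivalent steps are
# already present.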
1686
1687def convert_axes(
1688    axes: str,
1689    *,
1690    shape: Union[
1691        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1692    ],
1693    tensor_type: Literal["input", "output"],
1694    halo: Optional[Sequence[int]],
1695    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1696):
1697    ret: List[AnyAxis] = []
1698    for i, a in enumerate(axes):
1699        axis_type = _AXIS_TYPE_MAP.get(a, a)
1700        if axis_type == "batch":
1701            ret.append(BatchAxis())
1702            continue
1703
1704        scale = 1.0
1705        if isinstance(shape, _ParameterizedInputShape_v0_4):
1706            if shape.step[i] == 0:
1707                size = shape.min[i]
1708            else:
1709                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1710        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1711            ref_t = str(shape.reference_tensor)
1712            if ref_t.count(".") == 1:
1713                t_id, orig_a_id = ref_t.split(".")
1714            else:
1715                t_id = ref_t
1716                orig_a_id = a
1717
1718            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1719            if not (orig_scale := shape.scale[i]):
1720                # old way to insert a new axis dimension
1721                size = int(2 * shape.offset[i])
1722            else:
1723                scale = 1 / orig_scale
1724                if axis_type in ("channel", "index"):
1725                    # these axes no longer have a scale
1726                    offset_from_scale = orig_scale * size_refs.get(
1727                        _TensorName_v0_4(t_id), {}
1728                    ).get(orig_a_id, 0)
1729                else:
1730                    offset_from_scale = 0
1731                size = SizeReference(
1732                    tensor_id=TensorId(t_id),
1733                    axis_id=AxisId(a_id),
1734                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1735                )
1736        else:
1737            size = shape[i]
1738
1739        if axis_type == "time":
1740            if tensor_type == "input":
1741                ret.append(TimeInputAxis(size=size, scale=scale))
1742            else:
1743                assert not isinstance(size, ParameterizedSize)
1744                if halo is None:
1745                    ret.append(TimeOutputAxis(size=size, scale=scale))
1746                else:
1747                    assert not isinstance(size, int)
1748                    ret.append(
1749                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1750                    )
1751
1752        elif axis_type == "index":
1753            if tensor_type == "input":
1754                ret.append(IndexInputAxis(size=size))
1755            else:
1756                if isinstance(size, ParameterizedSize):
1757                    size = DataDependentSize(min=size.min)
1758
1759                ret.append(IndexOutputAxis(size=size))
1760        elif axis_type == "channel":
1761            assert not isinstance(size, ParameterizedSize)
1762            if isinstance(size, SizeReference):
1763                warnings.warn(
1764                    "Conversion of channel size from an implicit output shape may be"
1765                    + " wrong"
1766                )
1767                ret.append(
1768                    ChannelAxis(
1769                        channel_names=[
1770                            Identifier(f"channel{i}") for i in range(size.offset)
1771                        ]
1772                    )
1773                )
1774            else:
1775                ret.append(
1776                    ChannelAxis(
1777                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1778                    )
1779                )
1780        elif axis_type == "space":
1781            if tensor_type == "input":
1782                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1783            else:
1784                assert not isinstance(size, ParameterizedSize)
1785                if halo is None or halo[i] == 0:
1786                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1787                elif isinstance(size, int):
1788                    raise NotImplementedError(
1789                        f"output axis with halo and fixed size (here {size}) not allowed"
1790                    )
1791                else:
1792                    ret.append(
1793                        SpaceOutputAxisWithHalo(
1794                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1795                        )
1796                    )
1797
1798    return ret
1799
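# Example (illustrative; the expected result of converting a v0.4 axes string
# with a fixed input shape):
#
#     convert_axes(
#         "bcyx", shape=[1, 3, 128, 128], tensor_type="input", halo=None, size_refs={}
#     )
#     # -> [BatchAxis(),
#     #     ChannelAxis(channel_names=[Identifier("channel0"), Identifier("channel1"), Identifier("channel2")]),
#     #     SpaceInputAxis(id=AxisId("y"), size=128, scale=1.0),
#     #     SpaceInputAxis(id=AxisId("x"), size=128, scale=1.0)]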
1800
1801def _axes_letters_to_ids(
1802    axes: Optional[str],
1803) -> Optional[List[AxisId]]:
1804    if axes is None:
1805        return None
1806
1807    return [AxisId(a) for a in axes]
1808
1809
1810def _get_complement_v04_axis(
1811    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1812) -> Optional[AxisId]:
1813    if axes is None:
1814        return None
1815
1816    non_complement_axes = set(axes) | {"b"}
1817    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1818    if len(complement_axes) > 1:
1819        raise ValueError(
1820            f"Expected none or a single complement axis, but axes '{axes}' "
1821            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1822        )
1823
1824    return None if not complement_axes else AxisId(complement_axes[0])
1825
1826
1827def _convert_proc(
1828    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1829    tensor_axes: Sequence[str],
1830) -> Union[PreprocessingDescr, PostprocessingDescr]:
1831    if isinstance(p, _BinarizeDescr_v0_4):
1832        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1833    elif isinstance(p, _ClipDescr_v0_4):
1834        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1835    elif isinstance(p, _SigmoidDescr_v0_4):
1836        return SigmoidDescr()
1837    elif isinstance(p, _ScaleLinearDescr_v0_4):
1838        axes = _axes_letters_to_ids(p.kwargs.axes)
1839        if p.kwargs.axes is None:
1840            axis = None
1841        else:
1842            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1843
1844        if axis is None:
1845            assert not isinstance(p.kwargs.gain, list)
1846            assert not isinstance(p.kwargs.offset, list)
1847            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1848        else:
1849            kwargs = ScaleLinearAlongAxisKwargs(
1850                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1851            )
1852        return ScaleLinearDescr(kwargs=kwargs)
1853    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1854        return ScaleMeanVarianceDescr(
1855            kwargs=ScaleMeanVarianceKwargs(
1856                axes=_axes_letters_to_ids(p.kwargs.axes),
1857                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1858                eps=p.kwargs.eps,
1859            )
1860        )
1861    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1862        if p.kwargs.mode == "fixed":
1863            mean = p.kwargs.mean
1864            std = p.kwargs.std
1865            assert mean is not None
1866            assert std is not None
1867
1868            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1869
1870            if axis is None:
1871                return FixedZeroMeanUnitVarianceDescr(
1872                    kwargs=FixedZeroMeanUnitVarianceKwargs(
1873                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
1874                    )
1875                )
1876            else:
1877                if not isinstance(mean, list):
1878                    mean = [float(mean)]
1879                if not isinstance(std, list):
1880                    std = [float(std)]
1881
1882                return FixedZeroMeanUnitVarianceDescr(
1883                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1884                        axis=axis, mean=mean, std=std
1885                    )
1886                )
1887
1888        else:
1889            axes = _axes_letters_to_ids(p.kwargs.axes) or []
1890            if p.kwargs.mode == "per_dataset":
1891                axes = [AxisId("batch")] + axes
1892            if not axes:
1893                axes = None
1894            return ZeroMeanUnitVarianceDescr(
1895                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1896            )
1897
1898    elif isinstance(p, _ScaleRangeDescr_v0_4):
1899        return ScaleRangeDescr(
1900            kwargs=ScaleRangeKwargs(
1901                axes=_axes_letters_to_ids(p.kwargs.axes),
1902                min_percentile=p.kwargs.min_percentile,
1903                max_percentile=p.kwargs.max_percentile,
1904                eps=p.kwargs.eps,
1905            )
1906        )
1907    else:
1908        assert_never(p)
1909
1910
1911class _InputTensorConv(
1912    Converter[
1913        _InputTensorDescr_v0_4,
1914        InputTensorDescr,
1915        FileSource_,
1916        Optional[FileSource_],
1917        Mapping[_TensorName_v0_4, Mapping[str, int]],
1918    ]
1919):
1920    def _convert(
1921        self,
1922        src: _InputTensorDescr_v0_4,
1923        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1924        test_tensor: FileSource_,
1925        sample_tensor: Optional[FileSource_],
1926        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1927    ) -> "InputTensorDescr | dict[str, Any]":
1928        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
1929            src.axes,
1930            shape=src.shape,
1931            tensor_type="input",
1932            halo=None,
1933            size_refs=size_refs,
1934        )
1935        prep: List[PreprocessingDescr] = []
1936        for p in src.preprocessing:
1937            cp = _convert_proc(p, src.axes)
1938            assert not isinstance(cp, ScaleMeanVarianceDescr)
1939            prep.append(cp)
1940
1941        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
1942
1943        return tgt(
1944            axes=axes,
1945            id=TensorId(str(src.name)),
1946            test_tensor=FileDescr(source=test_tensor),
1947            sample_tensor=(
1948                None if sample_tensor is None else FileDescr(source=sample_tensor)
1949            ),
1950            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
1951            preprocessing=prep,
1952        )
1953
1954
1955_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1956
1957
1958class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1959    id: TensorId = TensorId("output")
1960    """Output tensor id.
1961    No duplicates are allowed across all inputs and outputs."""
1962
1963    postprocessing: List[PostprocessingDescr] = Field(
1964        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
1965    )
1966    """Description of how this output should be postprocessed.
1967
1968    note: `postprocessing` always ends with an 'ensure_dtype' (or 'binarize') operation.
1969          If not given, an 'ensure_dtype' step casting to this tensor's `data.type` is appended.
1970    """
1971
1972    @model_validator(mode="after")
1973    def _validate_postprocessing_kwargs(self) -> Self:
1974        axes_ids = [a.id for a in self.axes]
1975        for p in self.postprocessing:
1976            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1977            if kwargs_axes is None:
1978                continue
1979
1980            if not isinstance(kwargs_axes, collections.abc.Sequence):
1981                raise ValueError(
1982                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1983                )
1984
1985            if any(a not in axes_ids for a in kwargs_axes):
1986                raise ValueError("`kwargs.axes` needs to be a subset of the axes ids")
1987
1988        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1989            dtype = self.data.type
1990        else:
1991            dtype = self.data[0].type
1992
1993        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1994        if not self.postprocessing or not isinstance(
1995            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1996        ):
1997            self.postprocessing.append(
1998                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1999            )
2000        return self
2001
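# Illustrative sketch of the validator above: an output described with
# `data.type == "float32"` and `postprocessing=[SigmoidDescr()]` ends up with
#
#     [SigmoidDescr(),
#      EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))]
#
# because a trailing dtype cast is appended unless the list already ends with an
# 'ensure_dtype' or 'binarize' step.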
2002
2003class _OutputTensorConv(
2004    Converter[
2005        _OutputTensorDescr_v0_4,
2006        OutputTensorDescr,
2007        FileSource_,
2008        Optional[FileSource_],
2009        Mapping[_TensorName_v0_4, Mapping[str, int]],
2010    ]
2011):
2012    def _convert(
2013        self,
2014        src: _OutputTensorDescr_v0_4,
2015        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2016        test_tensor: FileSource_,
2017        sample_tensor: Optional[FileSource_],
2018        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2019    ) -> "OutputTensorDescr | dict[str, Any]":
2020        # TODO: split convert_axes into convert_output_axes and convert_input_axes
2021        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
2022            src.axes,
2023            shape=src.shape,
2024            tensor_type="output",
2025            halo=src.halo,
2026            size_refs=size_refs,
2027        )
2028        data_descr: Dict[str, Any] = dict(type=src.data_type)
2029        if data_descr["type"] == "bool":
2030            data_descr["values"] = [False, True]
2031
2032        return tgt(
2033            axes=axes,
2034            id=TensorId(str(src.name)),
2035            test_tensor=FileDescr(source=test_tensor),
2036            sample_tensor=(
2037                None if sample_tensor is None else FileDescr(source=sample_tensor)
2038            ),
2039            data=data_descr,  # pyright: ignore[reportArgumentType]
2040            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2041        )
2042
2043
2044_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2045
2046
2047TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2048
2049
2050def validate_tensors(
2051    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2052    tensor_origin: Literal[
2053        "test_tensor"
2054    ],  # for more precise error messages, e.g. 'test_tensor'
2055):
2056    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2057
2058    def e_msg(d: TensorDescr):
2059        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2060
2061    for descr, array in tensors.values():
2062        try:
2063            axis_sizes = descr.get_axis_sizes_for_array(array)
2064        except ValueError as e:
2065            raise ValueError(f"{e_msg(descr)} {e}")
2066        else:
2067            all_tensor_axes[descr.id] = {
2068                a.id: (a, axis_sizes[a.id]) for a in descr.axes
2069            }
2070
2071    for descr, array in tensors.values():
2072        if descr.dtype in ("float32", "float64"):
2073            invalid_test_tensor_dtype = array.dtype.name not in (
2074                "float32",
2075                "float64",
2076                "uint8",
2077                "int8",
2078                "uint16",
2079                "int16",
2080                "uint32",
2081                "int32",
2082                "uint64",
2083                "int64",
2084            )
2085        else:
2086            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2087
2088        if invalid_test_tensor_dtype:
2089            raise ValueError(
2090                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2091                + f" match described dtype '{descr.dtype}'"
2092            )
2093
2094        if array.min() > -1e-4 and array.max() < 1e-4:
2095            raise ValueError(
2096                "Tensor values are too small for reliable testing."
2097                + f" Values <= -1e-4 or >= 1e-4 must be present in {tensor_origin}"
2098            )
2099
2100        for a in descr.axes:
2101            actual_size = all_tensor_axes[descr.id][a.id][1]
2102            if a.size is None:
2103                continue
2104
2105            if isinstance(a.size, int):
2106                if actual_size != a.size:
2107                    raise ValueError(
2108                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2109                        + f"has incompatible size {actual_size}, expected {a.size}"
2110                    )
2111            elif isinstance(a.size, ParameterizedSize):
2112                _ = a.size.validate_size(actual_size)
2113            elif isinstance(a.size, DataDependentSize):
2114                _ = a.size.validate_size(actual_size)
2115            elif isinstance(a.size, SizeReference):
2116                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2117                if ref_tensor_axes is None:
2118                    raise ValueError(
2119                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2120                        + f" reference '{a.size.tensor_id}'"
2121                    )
2122
2123                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2124                if ref_axis is None or ref_size is None:
2125                    raise ValueError(
2126                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2127                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2128                    )
2129
2130                if a.unit != ref_axis.unit:
2131                    raise ValueError(
2132                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2133                        + " axis and reference axis to have the same `unit`, but"
2134                        + f" {a.unit}!={ref_axis.unit}"
2135                    )
2136
2137                if actual_size != (
2138                    expected_size := (
2139                        ref_size * ref_axis.scale / a.scale + a.size.offset
2140                    )
2141                ):
2142                    raise ValueError(
2143                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2144                        + f" {actual_size} invalid for referenced size {ref_size};"
2145                        + f" expected {expected_size}"
2146                    )
2147            else:
2148                assert_never(a.size)
2149
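# Example usage (sketch; `descrs` is a hypothetical list of already validated
# input/output tensor descriptions whose test tensors are loaded with
# `load_array`):
#
#     validate_tensors(
#         {d.id: (d, load_array(d.test_tensor)) for d in descrs},
#         tensor_origin="test_tensor",
#     )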
2150
2151FileDescr_dependencies = Annotated[
2152    FileDescr_,
2153    WithSuffix((".yaml", ".yml"), case_sensitive=True),
2154    Field(examples=[dict(source="environment.yaml")]),
2155]
2156
2157
2158class _ArchitectureCallableDescr(Node):
2159    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2160    """Identifier of the callable that returns a torch.nn.Module instance."""
2161
2162    kwargs: Dict[str, YamlValue] = Field(
2163        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2164    )
2165    """Keyword arguments for the `callable`."""
2166
2167
2168class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2169    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2170    """Architecture source file"""
2171
2172    @model_serializer(mode="wrap", when_used="unless-none")
2173    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2174        return package_file_descr_serializer(self, nxt, info)
2175
2176
2177class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2178    import_from: str
2179    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2180
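# Example (sketch; `my_package.my_module` and `MyNetworkClass` are placeholder
# names): an architecture importable from an installed library is described as
#
#     ArchitectureFromLibraryDescr(
#         import_from="my_package.my_module",
#         callable=Identifier("MyNetworkClass"),
#         kwargs={"depth": 5},
#     )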
2181
2182class _ArchFileConv(
2183    Converter[
2184        _CallableFromFile_v0_4,
2185        ArchitectureFromFileDescr,
2186        Optional[Sha256],
2187        Dict[str, Any],
2188    ]
2189):
2190    def _convert(
2191        self,
2192        src: _CallableFromFile_v0_4,
2193        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2194        sha256: Optional[Sha256],
2195        kwargs: Dict[str, Any],
2196    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2197        if src.startswith("http") and src.count(":") == 2:
2198            http, source, callable_ = src.split(":")
2199            source = ":".join((http, source))
2200        elif not src.startswith("http") and src.count(":") == 1:
2201            source, callable_ = src.split(":")
2202        else:
2203            source = str(src)
2204            callable_ = str(src)
2205        return tgt(
2206            callable=Identifier(callable_),
2207            source=cast(FileSource_, source),
2208            sha256=sha256,
2209            kwargs=kwargs,
2210        )
2211
2212
2213_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2214
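# Illustrative conversions performed by `_arch_file_conv` (placeholder names):
#   "model.py:MyNetworkClass"                     -> source="model.py", callable="MyNetworkClass"
#   "https://example.com/model.py:MyNetworkClass" -> source="https://example.com/model.py", callable="MyNetworkClass"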
2215
2216class _ArchLibConv(
2217    Converter[
2218        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2219    ]
2220):
2221    def _convert(
2222        self,
2223        src: _CallableFromDepencency_v0_4,
2224        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2225        kwargs: Dict[str, Any],
2226    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2227        *mods, callable_ = src.split(".")
2228        import_from = ".".join(mods)
2229        return tgt(
2230            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2231        )
2232
2233
2234_arch_lib_conv = _ArchLibConv(
2235    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2236)
2237
2238
2239class WeightsEntryDescrBase(FileDescr):
2240    type: ClassVar[WeightsFormat]
2241    weights_format_name: ClassVar[str]  # human readable
2242
2243    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2244    """Source of the weights file."""
2245
2246    authors: Optional[List[Author]] = None
2247    """Authors:
2248    Either the person(s) that trained this model, resulting in the original weights file
2249        (if this is the initial weights entry, i.e. it does not have a `parent`),
2250    or the person(s) who converted the weights to this weights format
2251        (if this is a child weights entry, i.e. it has a `parent` field).
2252    """
2253
2254    parent: Annotated[
2255        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2256    ] = None
2257    """The source weights these weights were converted from.
2258    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2259    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2260    All weights entries except one (the initial set of weights resulting from training the model)
2261    need to specify this field."""
2262
2263    comment: str = ""
2264    """A comment about this weights entry, for example how these weights were created."""
2265
2266    @model_validator(mode="after")
2267    def _validate(self) -> Self:
2268        if self.type == self.parent:
2269            raise ValueError("Weights entry can't be its own parent.")
2270
2271        return self
2272
2273    @model_serializer(mode="wrap", when_used="unless-none")
2274    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2275        return package_file_descr_serializer(self, nxt, info)
2276
2277
2278class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2279    type = "keras_hdf5"
2280    weights_format_name: ClassVar[str] = "Keras HDF5"
2281    tensorflow_version: Version
2282    """TensorFlow version used to create these weights."""
2283
2284
2285class OnnxWeightsDescr(WeightsEntryDescrBase):
2286    type = "onnx"
2287    weights_format_name: ClassVar[str] = "ONNX"
2288    opset_version: Annotated[int, Ge(7)]
2289    """ONNX opset version"""
2290
2291
2292class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2293    type = "pytorch_state_dict"
2294    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2295    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2296    pytorch_version: Version
2297    """Version of the PyTorch library used.
2298    If `dependencies` is specified, it has to include pytorch and any version pinning has to be compatible with **pytorch_version**.
2299    """
2300    dependencies: Optional[FileDescr_dependencies] = None
2301    """Custom dependencies beyond pytorch, described in a Conda environment file.
2302    Allows specifying custom dependencies, see conda docs:
2303    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2304    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2305
2306    The conda environment file should include pytorch and any version pinning has to be compatible with
2307    **pytorch_version**.
2308    """
2309
2310
2311class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2312    type = "tensorflow_js"
2313    weights_format_name: ClassVar[str] = "Tensorflow.js"
2314    tensorflow_version: Version
2315    """Version of the TensorFlow library used."""
2316
2317    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2318    """The multi-file weights.
2319    All required files/folders should be packaged in a zip archive."""
2320
2321
2322class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2323    type = "tensorflow_saved_model_bundle"
2324    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2325    tensorflow_version: Version
2326    """Version of the TensorFlow library used."""
2327
2328    dependencies: Optional[FileDescr_dependencies] = None
2329    """Custom dependencies beyond tensorflow.
2330    It should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**."""
2331
2332    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2333    """The multi-file weights.
2334    All required files/folders should be packaged in a zip archive."""
2335
2336
2337class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2338    type = "torchscript"
2339    weights_format_name: ClassVar[str] = "TorchScript"
2340    pytorch_version: Version
2341    """Version of the PyTorch library used."""
2342
2343
2344class WeightsDescr(Node):
2345    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2346    onnx: Optional[OnnxWeightsDescr] = None
2347    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2348    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2349    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2350        None
2351    )
2352    torchscript: Optional[TorchscriptWeightsDescr] = None
2353
2354    @model_validator(mode="after")
2355    def check_entries(self) -> Self:
2356        entries = {wtype for wtype, entry in self if entry is not None}
2357
2358        if not entries:
2359            raise ValueError("Missing weights entry")
2360
2361        entries_wo_parent = {
2362            wtype
2363            for wtype, entry in self
2364            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2365        }
2366        if len(entries_wo_parent) != 1:
2367            issue_warning(
2368                "Exactly one weights entry must not specify the `parent` field (got"
2369                + " {value}). That entry is considered the original set of model weights."
2370                + " Other weights formats are created through conversion of the original or"
2371                + " already converted weights. They have to reference the weights format"
2372                + " they were converted from as their `parent`.",
2373                value=len(entries_wo_parent),
2374                field="weights",
2375            )
2376
2377        for wtype, entry in self:
2378            if entry is None:
2379                continue
2380
2381            assert hasattr(entry, "type")
2382            assert hasattr(entry, "parent")
2383            assert wtype == entry.type
2384            if (
2385                entry.parent is not None and entry.parent not in entries
2386            ):  # self reference checked for `parent` field
2387                raise ValueError(
2388                    f"`weights.{wtype}.parent={entry.parent}` not in specified weight"
2389                    + f" formats: {entries}"
2390                )
2391
2392        return self
2393
2394    def __getitem__(
2395        self,
2396        key: Literal[
2397            "keras_hdf5",
2398            "onnx",
2399            "pytorch_state_dict",
2400            "tensorflow_js",
2401            "tensorflow_saved_model_bundle",
2402            "torchscript",
2403        ],
2404    ):
2405        if key == "keras_hdf5":
2406            ret = self.keras_hdf5
2407        elif key == "onnx":
2408            ret = self.onnx
2409        elif key == "pytorch_state_dict":
2410            ret = self.pytorch_state_dict
2411        elif key == "tensorflow_js":
2412            ret = self.tensorflow_js
2413        elif key == "tensorflow_saved_model_bundle":
2414            ret = self.tensorflow_saved_model_bundle
2415        elif key == "torchscript":
2416            ret = self.torchscript
2417        else:
2418            raise KeyError(key)
2419
2420        if ret is None:
2421            raise KeyError(key)
2422
2423        return ret
2424
2425    @property
2426    def available_formats(self):
2427        return {
2428            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2429            **({} if self.onnx is None else {"onnx": self.onnx}),
2430            **(
2431                {}
2432                if self.pytorch_state_dict is None
2433                else {"pytorch_state_dict": self.pytorch_state_dict}
2434            ),
2435            **(
2436                {}
2437                if self.tensorflow_js is None
2438                else {"tensorflow_js": self.tensorflow_js}
2439            ),
2440            **(
2441                {}
2442                if self.tensorflow_saved_model_bundle is None
2443                else {
2444                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2445                }
2446            ),
2447            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2448        }
2449
2450    @property
2451    def missing_formats(self):
2452        return {
2453            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2454        }
2455
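# Example (sketch; `weights` is a hypothetical, validated `WeightsDescr`):
#
#     weights["torchscript"]       # the TorchscriptWeightsDescr, or KeyError if absent
#     weights.available_formats    # e.g. {"pytorch_state_dict": ..., "torchscript": ...}
#     weights.missing_formats      # weights formats without an entry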
2456
2457class ModelId(ResourceId):
2458    pass
2459
2460
2461class LinkedModel(LinkedResourceBase):
2462    """Reference to a bioimage.io model."""
2463
2464    id: ModelId
2465    """A valid model `id` from the bioimage.io collection."""
2466
2467
2468class _DataDepSize(NamedTuple):
2469    min: int
2470    max: Optional[int]
2471
2472
2473class _AxisSizes(NamedTuple):
2474    """The lengths of all axes of model inputs and outputs."""
2475
2476    inputs: Dict[Tuple[TensorId, AxisId], int]
2477    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2478
2479
2480class _TensorSizes(NamedTuple):
2481    """_AxisSizes as nested dicts"""
2482
2483    inputs: Dict[TensorId, Dict[AxisId, int]]
2484    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2485
2486
2487class ReproducibilityTolerance(Node, extra="allow"):
2488    """Describes what small numerical differences -- if any -- may be tolerated
2489    in the generated output when executing in different environments.
2490
2491    A tensor element *output* is considered mismatched to the **test_tensor** if
2492    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2493    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2494
2495    Motivation:
2496        For testing we can request the respective deep learning frameworks to be as
2497        reproducible as possible by setting seeds and choosing deterministic algorithms,
2498        but differences in operating systems, available hardware and installed drivers
2499        may still lead to numerical differences.
2500    """
2501
2502    relative_tolerance: RelativeTolerance = 1e-3
2503    """Maximum relative tolerance of reproduced test tensor."""
2504
2505    absolute_tolerance: AbsoluteTolerance = 1e-4
2506    """Maximum absolute tolerance of reproduced test tensor."""
2507
2508    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2509    """Maximum number of mismatched elements/pixels per million to tolerate."""
2510
2511    output_ids: Sequence[TensorId] = ()
2512    """Limits the output tensor IDs these reproducibility details apply to."""
2513
2514    weights_formats: Sequence[WeightsFormat] = ()
2515    """Limits the weights formats these details apply to."""
2516
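# Worked example of the mismatch criterion above (using the default tolerances):
# an output element is counted as mismatched if
#
#     abs(output - test_tensor) > 1e-4 + 1e-3 * abs(test_tensor)
#
# and reproduction is only considered failed if more than
# `mismatched_elements_per_million` (default: 100) out of every million elements
# mismatch.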
2517
2518class BioimageioConfig(Node, extra="allow"):
2519    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2520    """Tolerances to allow when reproducing the model's test outputs
2521    from the model's test inputs.
2522    Only the first entry matching tensor id and weights format is considered.
2523    """
2524
2525
2526class Config(Node, extra="allow"):
2527    bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2528
2529
2530class ModelDescr(GenericModelDescrBase):
2531    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2532    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2533    """
2534
2535    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2536    if TYPE_CHECKING:
2537        format_version: Literal["0.5.4"] = "0.5.4"
2538    else:
2539        format_version: Literal["0.5.4"]
2540        """Version of the bioimage.io model description specification used.
2541        When creating a new model always use the latest micro/patch version described here.
2542        The `format_version` is important for any consumer software to understand how to parse the fields.
2543        """
2544
2545    implemented_type: ClassVar[Literal["model"]] = "model"
2546    if TYPE_CHECKING:
2547        type: Literal["model"] = "model"
2548    else:
2549        type: Literal["model"]
2550        """Specialized resource type 'model'"""
2551
2552    id: Optional[ModelId] = None
2553    """bioimage.io-wide unique resource identifier
2554    assigned by bioimage.io; version **un**specific."""
2555
2556    authors: NotEmpty[List[Author]]
2557    """The authors are the creators of the model RDF and the primary points of contact."""
2558
2559    documentation: FileSource_documentation
2560    """URL or relative path to a markdown file with additional documentation.
2561    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2562    The documentation should include a '#[#] Validation' (sub)section
2563    with details on how to quantitatively validate the model on unseen data."""
2564
2565    @field_validator("documentation", mode="after")
2566    @classmethod
2567    def _validate_documentation(
2568        cls, value: FileSource_documentation
2569    ) -> FileSource_documentation:
2570        if not get_validation_context().perform_io_checks:
2571            return value
2572
2573        doc_reader = get_reader(value)
2574        doc_content = doc_reader.read().decode(encoding="utf-8")
2575        if not re.search("#.*[vV]alidation", doc_content):
2576            issue_warning(
2577                "No '# Validation' (sub)section found in {value}.",
2578                value=value,
2579                field="documentation",
2580            )
2581
2582        return value
2583
2584    inputs: NotEmpty[Sequence[InputTensorDescr]]
2585    """Describes the input tensors expected by this model."""
2586
2587    @field_validator("inputs", mode="after")
2588    @classmethod
2589    def _validate_input_axes(
2590        cls, inputs: Sequence[InputTensorDescr]
2591    ) -> Sequence[InputTensorDescr]:
2592        input_size_refs = cls._get_axes_with_independent_size(inputs)
2593
2594        for i, ipt in enumerate(inputs):
2595            valid_independent_refs: Dict[
2596                Tuple[TensorId, AxisId],
2597                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2598            ] = {
2599                **{
2600                    (ipt.id, a.id): (ipt, a, a.size)
2601                    for a in ipt.axes
2602                    if not isinstance(a, BatchAxis)
2603                    and isinstance(a.size, (int, ParameterizedSize))
2604                },
2605                **input_size_refs,
2606            }
2607            for a, ax in enumerate(ipt.axes):
2608                cls._validate_axis(
2609                    "inputs",
2610                    i=i,
2611                    tensor_id=ipt.id,
2612                    a=a,
2613                    axis=ax,
2614                    valid_independent_refs=valid_independent_refs,
2615                )
2616        return inputs
2617
2618    @staticmethod
2619    def _validate_axis(
2620        field_name: str,
2621        i: int,
2622        tensor_id: TensorId,
2623        a: int,
2624        axis: AnyAxis,
2625        valid_independent_refs: Dict[
2626            Tuple[TensorId, AxisId],
2627            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2628        ],
2629    ):
2630        if isinstance(axis, BatchAxis) or isinstance(
2631            axis.size, (int, ParameterizedSize, DataDependentSize)
2632        ):
2633            return
2634        elif not isinstance(axis.size, SizeReference):
2635            assert_never(axis.size)
2636
2637        # validate axis.size SizeReference
2638        ref = (axis.size.tensor_id, axis.size.axis_id)
2639        if ref not in valid_independent_refs:
2640            raise ValueError(
2641                "Invalid tensor axis reference at"
2642                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2643            )
2644        if ref == (tensor_id, axis.id):
2645            raise ValueError(
2646                "Self-referencing not allowed for"
2647                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2648            )
2649        if axis.type == "channel":
2650            if valid_independent_refs[ref][1].type != "channel":
2651                raise ValueError(
2652                    "A channel axis' size may only reference another fixed size"
2653                    + " channel axis."
2654                )
2655            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2656                ref_size = valid_independent_refs[ref][2]
2657                assert isinstance(ref_size, int), (
2658                    "channel axis ref (another channel axis) has to specify fixed"
2659                    + " size"
2660                )
2661                generated_channel_names = [
2662                    Identifier(axis.channel_names.format(i=i))
2663                    for i in range(1, ref_size + 1)
2664                ]
2665                axis.channel_names = generated_channel_names
2666
2667        if (ax_unit := getattr(axis, "unit", None)) != (
2668            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2669        ):
2670            raise ValueError(
2671                "The units of an axis and its reference axis need to match, but"
2672                + f" '{ax_unit}' != '{ref_unit}'."
2673            )
2674        ref_axis = valid_independent_refs[ref][1]
2675        if isinstance(ref_axis, BatchAxis):
2676            raise ValueError(
2677                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2678                + " (a batch axis is not allowed as reference)."
2679            )
2680
2681        if isinstance(axis, WithHalo):
2682            min_size = axis.size.get_size(axis, ref_axis, n=0)
2683            if (min_size - 2 * axis.halo) < 1:
2684                raise ValueError(
2685                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2686                    + f" {axis.halo}."
2687                )
2688
2689            input_halo = axis.halo * axis.scale / ref_axis.scale
2690            if input_halo != int(input_halo) or input_halo % 2 == 1:
2691                raise ValueError(
2692                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2693                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2694                    + f" has to be an even integer for {tensor_id}.{axis.id}."
2695                )
2696
2697    @model_validator(mode="after")
2698    def _validate_test_tensors(self) -> Self:
2699        if not get_validation_context().perform_io_checks:
2700            return self
2701
2702        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2703        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2704
2705        tensors = {
2706            descr.id: (descr, array)
2707            for descr, array in zip(
2708                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2709            )
2710        }
2711        validate_tensors(tensors, tensor_origin="test_tensor")
2712
2713        output_arrays = {
2714            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2715        }
2716        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2717            if not rep_tol.absolute_tolerance:
2718                continue
2719
2720            if rep_tol.output_ids:
2721                out_arrays = {
2722                    oid: a
2723                    for oid, a in output_arrays.items()
2724                    if oid in rep_tol.output_ids
2725                }
2726            else:
2727                out_arrays = output_arrays
2728
2729            for out_id, array in out_arrays.items():
2730                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2731                    raise ValueError(
2732                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2733                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2734                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2735                    )
2736
2737        return self
2738
2739    @model_validator(mode="after")
2740    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2741        ipt_refs = {t.id for t in self.inputs}
2742        out_refs = {t.id for t in self.outputs}
2743        for ipt in self.inputs:
2744            for p in ipt.preprocessing:
2745                ref = p.kwargs.get("reference_tensor")
2746                if ref is None:
2747                    continue
2748                if ref not in ipt_refs:
2749                    raise ValueError(
2750                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2751                        + f" references are: {ipt_refs}."
2752                    )
2753
2754        for out in self.outputs:
2755            for p in out.postprocessing:
2756                ref = p.kwargs.get("reference_tensor")
2757                if ref is None:
2758                    continue
2759
2760                if ref not in ipt_refs and ref not in out_refs:
2761                    raise ValueError(
2762                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2763                        + f" are: {ipt_refs | out_refs}."
2764                    )
2765
2766        return self
2767
2768    # TODO: use validate funcs in validate_test_tensors
2769    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2770
2771    name: Annotated[
2772        Annotated[
2773            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2774        ],
2775        MinLen(5),
2776        MaxLen(128),
2777        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2778    ]
2779    """A human-readable name of this model.
2780    It should be no longer than 64 characters
2781    and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces.
2782    We recommend choosing a name that refers to the model's task and image modality.
2783    """
2784
2785    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2786    """Describes the output tensors."""
2787
2788    @field_validator("outputs", mode="after")
2789    @classmethod
2790    def _validate_tensor_ids(
2791        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2792    ) -> Sequence[OutputTensorDescr]:
2793        tensor_ids = [
2794            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2795        ]
2796        duplicate_tensor_ids: List[str] = []
2797        seen: Set[str] = set()
2798        for t in tensor_ids:
2799            if t in seen:
2800                duplicate_tensor_ids.append(t)
2801
2802            seen.add(t)
2803
2804        if duplicate_tensor_ids:
2805            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2806
2807        return outputs
2808
2809    @staticmethod
2810    def _get_axes_with_parameterized_size(
2811        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2812    ):
2813        return {
2814            f"{t.id}.{a.id}": (t, a, a.size)
2815            for t in io
2816            for a in t.axes
2817            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2818        }
2819
2820    @staticmethod
2821    def _get_axes_with_independent_size(
2822        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2823    ):
2824        return {
2825            (t.id, a.id): (t, a, a.size)
2826            for t in io
2827            for a in t.axes
2828            if not isinstance(a, BatchAxis)
2829            and isinstance(a.size, (int, ParameterizedSize))
2830        }
2831
2832    @field_validator("outputs", mode="after")
2833    @classmethod
2834    def _validate_output_axes(
2835        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2836    ) -> List[OutputTensorDescr]:
2837        input_size_refs = cls._get_axes_with_independent_size(
2838            info.data.get("inputs", [])
2839        )
2840        output_size_refs = cls._get_axes_with_independent_size(outputs)
2841
2842        for i, out in enumerate(outputs):
2843            valid_independent_refs: Dict[
2844                Tuple[TensorId, AxisId],
2845                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2846            ] = {
2847                **{
2848                    (out.id, a.id): (out, a, a.size)
2849                    for a in out.axes
2850                    if not isinstance(a, BatchAxis)
2851                    and isinstance(a.size, (int, ParameterizedSize))
2852                },
2853                **input_size_refs,
2854                **output_size_refs,
2855            }
2856            for a, ax in enumerate(out.axes):
2857                cls._validate_axis(
2858                    "outputs",
2859                    i,
2860                    out.id,
2861                    a,
2862                    ax,
2863                    valid_independent_refs=valid_independent_refs,
2864                )
2865
2866        return outputs
2867
2868    packaged_by: List[Author] = Field(
2869        default_factory=cast(Callable[[], List[Author]], list)
2870    )
2871    """The persons that have packaged and uploaded this model.
2872    Only required if those persons differ from the `authors`."""
2873
2874    parent: Optional[LinkedModel] = None
2875    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2876
2877    @model_validator(mode="after")
2878    def _validate_parent_is_not_self(self) -> Self:
2879        if self.parent is not None and self.parent.id == self.id:
2880            raise ValueError("A model description may not reference itself as parent.")
2881
2882        return self
2883
2884    run_mode: Annotated[
2885        Optional[RunMode],
2886        warn(None, "Run mode '{value}' has limited support across consumer software."),
2887    ] = None
2888    """Custom run mode for this model: for more complex prediction procedures like test time
2889    data augmentation that currently cannot be expressed in the specification.
2890    No standard run modes are defined yet."""
2891
2892    timestamp: Datetime = Field(default_factory=Datetime.now)
2893    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2894    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2895    (In Python a datetime object is valid, too)."""
2896
2897    training_data: Annotated[
2898        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2899        Field(union_mode="left_to_right"),
2900    ] = None
2901    """The dataset used to train this model"""
2902
2903    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2904    """The weights for this model.
2905    Weights can be given for different formats, but should otherwise be equivalent.
2906    The available weight formats determine which consumers can use this model."""
2907
2908    config: Config = Field(default_factory=Config)
2909
2910    @model_validator(mode="after")
2911    def _add_default_cover(self) -> Self:
2912        if not get_validation_context().perform_io_checks or self.covers:
2913            return self
2914
2915        try:
2916            generated_covers = generate_covers(
2917                [(t, load_array(t.test_tensor)) for t in self.inputs],
2918                [(t, load_array(t.test_tensor)) for t in self.outputs],
2919            )
2920        except Exception as e:
2921            issue_warning(
2922                "Failed to generate cover image(s): {e}",
2923                value=self.covers,
2924                msg_context=dict(e=e),
2925                field="covers",
2926            )
2927        else:
2928            self.covers.extend(generated_covers)
2929
2930        return self
2931
2932    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2933        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2934        assert all(isinstance(d, np.ndarray) for d in data)
2935        return data
2936
2937    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2938        data = [load_array(out.test_tensor) for out in self.outputs]
2939        assert all(isinstance(d, np.ndarray) for d in data)
2940        return data
2941
2942    @staticmethod
2943    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2944        batch_size = 1
2945        tensor_with_batchsize: Optional[TensorId] = None
2946        for tid in tensor_sizes:
2947            for aid, s in tensor_sizes[tid].items():
2948                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2949                    continue
2950
2951                if batch_size != 1:
2952                    assert tensor_with_batchsize is not None
2953                    raise ValueError(
2954                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2955                    )
2956
2957                batch_size = s
2958                tensor_with_batchsize = tid
2959
2960        return batch_size
2961
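    # Example (sketch): for
    #     {TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 512}}
    # `get_batch_size` returns 2; differing batch sizes (other than 1) across
    # tensors raise a ValueError.
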
2962    def get_output_tensor_sizes(
2963        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2964    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2965        """Returns the tensor output sizes for given **input_sizes**.
2966        The returned output sizes are exact only if **input_sizes** specifies a valid input shape.
2967        Otherwise they might be larger than the actual (valid) output sizes."""
2968        batch_size = self.get_batch_size(input_sizes)
2969        ns = self.get_ns(input_sizes)
2970
2971        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2972        return tensor_sizes.outputs
2973
2974    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2975        """Get the parameter `n` for each parameterized axis
2976        such that the valid input size is >= the given input size."""
2977        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2978        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2979        for tid in input_sizes:
2980            for aid, s in input_sizes[tid].items():
2981                size_descr = axes[tid][aid].size
2982                if isinstance(size_descr, ParameterizedSize):
2983                    ret[(tid, aid)] = size_descr.get_n(s)
2984                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2985                    pass
2986                else:
2987                    assert_never(size_descr)
2988
2989        return ret
2990
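    # Example (sketch; assumes a model whose single input "raw" has an `x` axis
    # parameterized as ParameterizedSize(min=64, step=16)):
    #
    #     sizes = {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 100}}
    #     ns = model.get_ns(sizes)     # {(TensorId("raw"), AxisId("x")): 3}
    #     model.get_tensor_sizes(ns, batch_size=1).inputs
    #     # -> {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 112}}  (64 + 3 * 16)
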
2991    def get_tensor_sizes(
2992        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2993    ) -> _TensorSizes:
2994        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2995        return _TensorSizes(
2996            {
2997                t: {
2998                    aa: axis_sizes.inputs[(tt, aa)]
2999                    for tt, aa in axis_sizes.inputs
3000                    if tt == t
3001                }
3002                for t in {tt for tt, _ in axis_sizes.inputs}
3003            },
3004            {
3005                t: {
3006                    aa: axis_sizes.outputs[(tt, aa)]
3007                    for tt, aa in axis_sizes.outputs
3008                    if tt == t
3009                }
3010                for t in {tt for tt, _ in axis_sizes.outputs}
3011            },
3012        )
3013
3014    def get_axis_sizes(
3015        self,
3016        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3017        batch_size: Optional[int] = None,
3018        *,
3019        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3020    ) -> _AxisSizes:
3021        """Determine input and output block shape for scale factors **ns**
3022        of parameterized input sizes.
3023
3024        Args:
3025            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3026                that is parameterized as `size = min + n * step`.
3027            batch_size: The desired size of the batch dimension.
3028                If given, **batch_size** overrides any batch size present in
3029                **max_input_shape**. Default 1.
3030            max_input_shape: Limits the derived block shapes.
3031                For each axis whose input size, parameterized by `n`, exceeds
3032                **max_input_shape**, `n` is reduced to the smallest value for which
3033                the parameterized size still covers **max_input_shape**.
3034                Use this for small input samples, for large values of **ns**,
3035                or simply whenever the full input shape is known.
3036
3037        Returns:
3038            Resolved axis sizes for model inputs and outputs.
3039        """
3040        max_input_shape = max_input_shape or {}
3041        if batch_size is None:
3042            for (_t_id, a_id), s in max_input_shape.items():
3043                if a_id == BATCH_AXIS_ID:
3044                    batch_size = s
3045                    break
3046            else:
3047                batch_size = 1
3048
3049        all_axes = {
3050            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3051        }
3052
3053        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3054        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3055
3056        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3057            if isinstance(a, BatchAxis):
3058                if (t_descr.id, a.id) in ns:
3059                    logger.warning(
3060                        "Ignoring unexpected size increment factor (n) for batch axis"
3061                        + " of tensor '{}'.",
3062                        t_descr.id,
3063                    )
3064                return batch_size
3065            elif isinstance(a.size, int):
3066                if (t_descr.id, a.id) in ns:
3067                    logger.warning(
3068                        "Ignoring unexpected size increment factor (n) for fixed size"
3069                        + " axis '{}' of tensor '{}'.",
3070                        a.id,
3071                        t_descr.id,
3072                    )
3073                return a.size
3074            elif isinstance(a.size, ParameterizedSize):
3075                if (t_descr.id, a.id) not in ns:
3076                    raise ValueError(
3077                        "Size increment factor (n) missing for parametrized axis"
3078                        + f" '{a.id}' of tensor '{t_descr.id}'."
3079                    )
3080                n = ns[(t_descr.id, a.id)]
3081                s_max = max_input_shape.get((t_descr.id, a.id))
3082                if s_max is not None:
3083                    n = min(n, a.size.get_n(s_max))
3084
3085                return a.size.get_size(n)
3086
3087            elif isinstance(a.size, SizeReference):
3088                if (t_descr.id, a.id) in ns:
3089                    logger.warning(
3090                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3091                        + " of tensor '{}' with size reference.",
3092                        a.id,
3093                        t_descr.id,
3094                    )
3095                assert not isinstance(a, BatchAxis)
3096                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3097                assert not isinstance(ref_axis, BatchAxis)
3098                ref_key = (a.size.tensor_id, a.size.axis_id)
3099                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3100                assert ref_size is not None, ref_key
3101                assert not isinstance(ref_size, _DataDepSize), ref_key
3102                return a.size.get_size(
3103                    axis=a,
3104                    ref_axis=ref_axis,
3105                    ref_size=ref_size,
3106                )
3107            elif isinstance(a.size, DataDependentSize):
3108                if (t_descr.id, a.id) in ns:
3109                    logger.warning(
3110                        "Ignoring unexpected increment factor (n) for data dependent"
3111                        + " size axis '{}' of tensor '{}'.",
3112                        a.id,
3113                        t_descr.id,
3114                    )
3115                return _DataDepSize(a.size.min, a.size.max)
3116            else:
3117                assert_never(a.size)
3118
3119        # first resolve all input axis sizes except those given by a `SizeReference`
3120        for t_descr in self.inputs:
3121            for a in t_descr.axes:
3122                if not isinstance(a.size, SizeReference):
3123                    s = get_axis_size(a)
3124                    assert not isinstance(s, _DataDepSize)
3125                    inputs[t_descr.id, a.id] = s
3126
3127        # resolve all other input axis sizes
3128        for t_descr in self.inputs:
3129            for a in t_descr.axes:
3130                if isinstance(a.size, SizeReference):
3131                    s = get_axis_size(a)
3132                    assert not isinstance(s, _DataDepSize)
3133                    inputs[t_descr.id, a.id] = s
3134
3135        # resolve all output axis sizes
3136        for t_descr in self.outputs:
3137            for a in t_descr.axes:
3138                assert not isinstance(a.size, ParameterizedSize)
3139                s = get_axis_size(a)
3140                outputs[t_descr.id, a.id] = s
3141
3142        return _AxisSizes(inputs=inputs, outputs=outputs)
3143
3144    @model_validator(mode="before")
3145    @classmethod
3146    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3147        cls.convert_from_old_format_wo_validation(data)
3148        return data
3149
3150    @classmethod
3151    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3152        """Convert metadata following an older format version to this classes' format
3153        without validating the result.
3154        """
3155        if (
3156            data.get("type") == "model"
3157            and isinstance(fv := data.get("format_version"), str)
3158            and fv.count(".") == 2
3159        ):
3160            fv_parts = fv.split(".")
3161            if any(not p.isdigit() for p in fv_parts):
3162                return
3163
3164            fv_tuple = tuple(map(int, fv_parts))
3165
3166            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3167            if fv_tuple[:2] in ((0, 3), (0, 4)):
3168                m04 = _ModelDescr_v0_4.load(data)
3169                if isinstance(m04, InvalidDescr):
3170                    try:
3171                        updated = _model_conv.convert_as_dict(
3172                            m04  # pyright: ignore[reportArgumentType]
3173                        )
3174                    except Exception as e:
3175                        logger.error(
3176                            "Failed to convert from invalid model 0.4 description."
3177                            + f"\nerror: {e}"
3178                            + "\nProceeding with model 0.5 validation without conversion."
3179                        )
3180                        updated = None
3181                else:
3182                    updated = _model_conv.convert_as_dict(m04)
3183
3184                if updated is not None:
3185                    data.clear()
3186                    data.update(updated)
3187
3188            elif fv_tuple[:2] == (0, 5):
3189                # bump patch version
3190                data["format_version"] = cls.implemented_format_version
3191
3192
3193class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3194    def _convert(
3195        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3196    ) -> "ModelDescr | dict[str, Any]":
3197        name = "".join(
3198            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3199            for c in src.name
3200        )
3201
3202        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3203            conv = (
3204                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3205            )
3206            return None if auths is None else [conv(a) for a in auths]
3207
3208        if TYPE_CHECKING:
3209            arch_file_conv = _arch_file_conv.convert
3210            arch_lib_conv = _arch_lib_conv.convert
3211        else:
3212            arch_file_conv = _arch_file_conv.convert_as_dict
3213            arch_lib_conv = _arch_lib_conv.convert_as_dict
3214
3215        input_size_refs = {
3216            ipt.name: {
3217                a: s
3218                for a, s in zip(
3219                    ipt.axes,
3220                    (
3221                        ipt.shape.min
3222                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3223                        else ipt.shape
3224                    ),
3225                )
3226            }
3227            for ipt in src.inputs
3228            if ipt.shape
3229        }
3230        output_size_refs = {
3231            **{
3232                out.name: {a: s for a, s in zip(out.axes, out.shape)}
3233                for out in src.outputs
3234                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3235            },
3236            **input_size_refs,
3237        }
3238
3239        return tgt(
3240            attachments=(
3241                []
3242                if src.attachments is None
3243                else [FileDescr(source=f) for f in src.attachments.files]
3244            ),
3245            authors=[
3246                _author_conv.convert_as_dict(a) for a in src.authors
3247            ],  # pyright: ignore[reportArgumentType]
3248            cite=[
3249                {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3250            ],  # pyright: ignore[reportArgumentType]
3251            config=src.config,  # pyright: ignore[reportArgumentType]
3252            covers=src.covers,
3253            description=src.description,
3254            documentation=src.documentation,
3255            format_version="0.5.4",
3256            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
3257            icon=src.icon,
3258            id=None if src.id is None else ModelId(src.id),
3259            id_emoji=src.id_emoji,
3260            license=src.license,  # type: ignore
3261            links=src.links,
3262            maintainers=[
3263                _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3264            ],  # pyright: ignore[reportArgumentType]
3265            name=name,
3266            tags=src.tags,
3267            type=src.type,
3268            uploader=src.uploader,
3269            version=src.version,
3270            inputs=[  # pyright: ignore[reportArgumentType]
3271                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3272                for ipt, tt, st, in zip(
3273                    src.inputs,
3274                    src.test_inputs,
3275                    src.sample_inputs or [None] * len(src.test_inputs),
3276                )
3277            ],
3278            outputs=[  # pyright: ignore[reportArgumentType]
3279                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3280                for out, tt, st, in zip(
3281                    src.outputs,
3282                    src.test_outputs,
3283                    src.sample_outputs or [None] * len(src.test_outputs),
3284                )
3285            ],
3286            parent=(
3287                None
3288                if src.parent is None
3289                else LinkedModel(
3290                    id=ModelId(
3291                        str(src.parent.id)
3292                        + (
3293                            ""
3294                            if src.parent.version_number is None
3295                            else f"/{src.parent.version_number}"
3296                        )
3297                    )
3298                )
3299            ),
3300            training_data=(
3301                None
3302                if src.training_data is None
3303                else (
3304                    LinkedDataset(
3305                        id=DatasetId(
3306                            str(src.training_data.id)
3307                            + (
3308                                ""
3309                                if src.training_data.version_number is None
3310                                else f"/{src.training_data.version_number}"
3311                            )
3312                        )
3313                    )
3314                    if isinstance(src.training_data, LinkedDataset02)
3315                    else src.training_data
3316                )
3317            ),
3318            packaged_by=[
3319                _author_conv.convert_as_dict(a) for a in src.packaged_by
3320            ],  # pyright: ignore[reportArgumentType]
3321            run_mode=src.run_mode,
3322            timestamp=src.timestamp,
3323            weights=(WeightsDescr if TYPE_CHECKING else dict)(
3324                keras_hdf5=(w := src.weights.keras_hdf5)
3325                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3326                    authors=conv_authors(w.authors),
3327                    source=w.source,
3328                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3329                    parent=w.parent,
3330                ),
3331                onnx=(w := src.weights.onnx)
3332                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3333                    source=w.source,
3334                    authors=conv_authors(w.authors),
3335                    parent=w.parent,
3336                    opset_version=w.opset_version or 15,
3337                ),
3338                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3339                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3340                    source=w.source,
3341                    authors=conv_authors(w.authors),
3342                    parent=w.parent,
3343                    architecture=(
3344                        arch_file_conv(
3345                            w.architecture,
3346                            w.architecture_sha256,
3347                            w.kwargs,
3348                        )
3349                        if isinstance(w.architecture, _CallableFromFile_v0_4)
3350                        else arch_lib_conv(w.architecture, w.kwargs)
3351                    ),
3352                    pytorch_version=w.pytorch_version or Version("1.10"),
3353                    dependencies=(
3354                        None
3355                        if w.dependencies is None
3356                        else (FileDescr if TYPE_CHECKING else dict)(
3357                            source=cast(
3358                                FileSource,
3359                                str(deps := w.dependencies)[
3360                                    (
3361                                        len("conda:")
3362                                        if str(deps).startswith("conda:")
3363                                        else 0
3364                                    ) :
3365                                ],
3366                            )
3367                        )
3368                    ),
3369                ),
3370                tensorflow_js=(w := src.weights.tensorflow_js)
3371                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3372                    source=w.source,
3373                    authors=conv_authors(w.authors),
3374                    parent=w.parent,
3375                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3376                ),
3377                tensorflow_saved_model_bundle=(
3378                    w := src.weights.tensorflow_saved_model_bundle
3379                )
3380                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3381                    authors=conv_authors(w.authors),
3382                    parent=w.parent,
3383                    source=w.source,
3384                    tensorflow_version=w.tensorflow_version or Version("1.15"),
3385                    dependencies=(
3386                        None
3387                        if w.dependencies is None
3388                        else (FileDescr if TYPE_CHECKING else dict)(
3389                            source=cast(
3390                                FileSource,
3391                                (
3392                                    str(w.dependencies)[len("conda:") :]
3393                                    if str(w.dependencies).startswith("conda:")
3394                                    else str(w.dependencies)
3395                                ),
3396                            )
3397                        )
3398                    ),
3399                ),
3400                torchscript=(w := src.weights.torchscript)
3401                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3402                    source=w.source,
3403                    authors=conv_authors(w.authors),
3404                    parent=w.parent,
3405                    pytorch_version=w.pytorch_version or Version("1.10"),
3406                ),
3407            ),
3408        )
3409
3410
3411_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3412
3413
3414# create better cover images for 3d data and non-image outputs
3415def generate_covers(
3416    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3417    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3418) -> List[Path]:
3419    def squeeze(
3420        data: NDArray[Any], axes: Sequence[AnyAxis]
3421    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3422        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3423        if data.ndim != len(axes):
3424            raise ValueError(
3425                f"tensor shape {data.shape} does not match described axes"
3426                + f" {[a.id for a in axes]}"
3427            )
3428
3429        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3430        return data.squeeze(), axes
3431
3432    def normalize(
3433        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3434    ) -> NDArray[np.float32]:
3435        data = data.astype("float32")
3436        data -= data.min(axis=axis, keepdims=True)
3437        data /= data.max(axis=axis, keepdims=True) + eps
3438        return data
3439
3440    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3441        original_shape = data.shape
3442        data, axes = squeeze(data, axes)
3443
3444        # take a slice from any batch or index axis if needed
3445        # and map the first channel axis to three (RGB) channels, taking a slice from any additional channel axes
3446        slices: Tuple[slice, ...] = ()
3447        ndim = data.ndim
3448        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3449        has_c_axis = False
3450        for i, a in enumerate(axes):
3451            s = data.shape[i]
3452            assert s > 1
3453            if (
3454                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3455                and ndim > ndim_need
3456            ):
3457                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3458                ndim -= 1
3459            elif isinstance(a, ChannelAxis):
3460                if has_c_axis:
3461                    # second channel axis
3462                    data = data[slices + (slice(0, 1),)]
3463                    ndim -= 1
3464                else:
3465                    has_c_axis = True
3466                    if s == 2:
3467                        # visualize two channels with cyan and magenta
3468                        data = np.concatenate(
3469                            [
3470                                data[slices + (slice(1, 2),)],
3471                                data[slices + (slice(0, 1),)],
3472                                (
3473                                    data[slices + (slice(0, 1),)]
3474                                    + data[slices + (slice(1, 2),)]
3475                                )
3476                                / 2,  # TODO: take maximum instead?
3477                            ],
3478                            axis=i,
3479                        )
3480                    elif data.shape[i] == 3:
3481                        pass  # visualize 3 channels as RGB
3482                    else:
3483                        # visualize first 3 channels as RGB
3484                        data = data[slices + (slice(3),)]
3485
3486                    assert data.shape[i] == 3
3487
3488            slices += (slice(None),)
3489
3490        data, axes = squeeze(data, axes)
3491        assert len(axes) == ndim
3492        # take slice from z axis if needed
3493        slices = ()
3494        if ndim > ndim_need:
3495            for i, a in enumerate(axes):
3496                s = data.shape[i]
3497                if a.id == AxisId("z"):
3498                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3499                    data, axes = squeeze(data, axes)
3500                    ndim -= 1
3501                    break
3502
3503            slices += (slice(None),)
3504
3505        # take slice from any space or time axis
3506        slices = ()
3507
3508        for i, a in enumerate(axes):
3509            if ndim <= ndim_need:
3510                break
3511
3512            s = data.shape[i]
3513            assert s > 1
3514            if isinstance(
3515                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3516            ):
3517                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3518                ndim -= 1
3519
3520            slices += (slice(None),)
3521
3522        del slices
3523        data, axes = squeeze(data, axes)
3524        assert len(axes) == ndim
3525
3526        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3527            raise ValueError(
3528                f"Failed to construct cover image from shape {original_shape}"
3529            )
3530
3531        if not has_c_axis:
3532            assert ndim == 2
3533            data = np.repeat(data[:, :, None], 3, axis=2)
3534            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3535            ndim += 1
3536
3537        assert ndim == 3
3538
3539        # transpose axis order such that longest axis comes first...
3540        axis_order: List[int] = list(np.argsort(list(data.shape)))
3541        axis_order.reverse()
3542        # ... and channel axis is last
3543        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3544        axis_order.append(axis_order.pop(c))
3545        axes = [axes[ao] for ao in axis_order]
3546        data = data.transpose(axis_order)
3547
3548        # h, w = data.shape[:2]
3549        # if h / w  in (1.0 or 2.0):
3550        #     pass
3551        # elif h / w < 2:
3552        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3553
3554        norm_along = (
3555            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3556        )
3557        # normalize the data and map to 8 bit
3558        data = normalize(data, norm_along)
3559        data = (data * 255).astype("uint8")
3560
3561        return data
3562
3563    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3564        assert im0.dtype == im1.dtype == np.uint8
3565        assert im0.shape == im1.shape
3566        assert im0.ndim == 3
3567        N, M, C = im0.shape
3568        assert C == 3
3569        out = np.ones((N, M, C), dtype="uint8")
3570        for c in range(C):
3571            outc = np.tril(im0[..., c])
3572            mask = outc == 0
3573            outc[mask] = np.triu(im1[..., c])[mask]
3574            out[..., c] = outc
3575
3576        return out
3577
3578    ipt_descr, ipt = inputs[0]
3579    out_descr, out = outputs[0]
3580
3581    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3582    out_img = to_2d_image(out, out_descr.axes)
3583
3584    cover_folder = Path(mkdtemp())
3585    if ipt_img.shape == out_img.shape:
3586        covers = [cover_folder / "cover.png"]
3587        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3588    else:
3589        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3590        imwrite(covers[0], ipt_img)
3591        imwrite(covers[1], out_img)
3592
3593    return covers
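
A minimal, illustrative sketch of how the size helpers defined above fit together, assuming `model` is a `ModelDescr` (e.g. obtained via `bioimageio.spec.load_description`) with a single input tensor `raw` that has parameterized `x`/`y` axes; the tensor and axis ids as well as the sizes are placeholders, not taken from any particular model:

from bioimageio.spec.model.v0_5 import AxisId, TensorId

# model = ...  # assumed: a ModelDescr loaded elsewhere, e.g. via bioimageio.spec.load_description
input_sizes = {TensorId("raw"): {AxisId("x"): 512, AxisId("y"): 512}}  # measured sample sizes
batch_size = model.get_batch_size(input_sizes)  # common batch size across all inputs (1 if none is given)
ns = model.get_ns(input_sizes)  # smallest parameter n per parameterized axis covering the given sizes
output_sizes = model.get_output_tensor_sizes(input_sizes)  # per output tensor; data dependent axes yield _DataDepSize
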
SpaceUnit = typing.Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']

Space unit compatible with the OME-Zarr axes specification 0.5

TimeUnit = typing.Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']

Time unit compatible with the OME-Zarr axes specification 0.5

AxisType = typing.Literal['batch', 'channel', 'index', 'time', 'space']
226class TensorId(LowerCaseIdentifier):
227    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
228        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
229    ]

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen]]'>

the pydantic root model to validate the string

242class AxisId(LowerCaseIdentifier):
243    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
244        Annotated[
245            LowerCaseIdentifierAnno,
246            MaxLen(16),
247            AfterValidator(_normalize_axis_id),
248        ]
249    ]

root_model: ClassVar[Type[pydantic.root_model.RootModel[Any]]] = <class 'pydantic.root_model.RootModel[Annotated[str, MinLen, AfterValidator, AfterValidator, Annotated[TypeVar, Predicate], MaxLen, AfterValidator]]'>

the pydantic root model to validate the string
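
Both TensorId and AxisId are validated, lower-case string types. A brief illustrative sketch (the concrete ids are arbitrary):

>>> from bioimageio.spec.model.v0_5 import AxisId, TensorId
>>> tid = TensorId("input")  # validated lower-case identifier, at most 32 characters
>>> aid = AxisId("x")        # validated lower-case identifier, at most 16 characters
>>> isinstance(tid, str) and isinstance(aid, str)  # both behave like plain strings
True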

NonBatchAxisId = typing.Annotated[AxisId, Predicate(_is_not_batch)]
PostprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'fixed_zero_mean_unit_variance', 'scale_linear', 'scale_mean_variance', 'scale_range', 'sigmoid', 'zero_mean_unit_variance']
PreprocessingId = typing.Literal['binarize', 'clip', 'ensure_dtype', 'scale_linear', 'sigmoid', 'zero_mean_unit_variance', 'scale_range']
SAME_AS_TYPE = '<same as type>'
ParameterizedSize_N = <class 'int'>

Annotates an integer to calculate a concrete axis size from a ParameterizedSize.

class ParameterizedSize(bioimageio.spec._internal.node.Node):
293class ParameterizedSize(Node):
294    """Describes a range of valid tensor axis sizes as `size = min + n*step`.
295
296    - **min** and **step** are given by the model description.
297    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
298    - A greater blocksize parameter n results in a greater **size**.
299      This allows adjusting the axis size more generically.
300    """
301
302    N: ClassVar[Type[int]] = ParameterizedSize_N
303    """Positive integer to parameterize this axis"""
304
305    min: Annotated[int, Gt(0)]
306    step: Annotated[int, Gt(0)]
307
308    def validate_size(self, size: int) -> int:
309        if size < self.min:
310            raise ValueError(f"size {size} < {self.min}")
311        if (size - self.min) % self.step != 0:
312            raise ValueError(
313                f"axis of size {size} is not parameterized by `min + n*step` ="
314                + f" `{self.min} + n*{self.step}`"
315            )
316
317        return size
318
319    def get_size(self, n: ParameterizedSize_N) -> int:
320        return self.min + self.step * n
321
322    def get_n(self, s: int) -> ParameterizedSize_N:
323        """return smallest n parameterizing a size greater or equal than `s`"""
324        return ceil((s - self.min) / self.step)

Describes a range of valid tensor axis sizes as size = min + n*step.

  • min and step are given by the model description.
  • All blocksize parameters n = 0,1,2,... yield a valid size.
  • A greater blocksize parameter n results in a greater size. This allows adjusting the axis size more generically.
N: ClassVar[Type[int]] = <class 'int'>

Positive integer to parameterize this axis

min: Annotated[int, Gt(gt=0)]
step: Annotated[int, Gt(gt=0)]
def validate_size(self, size: int) -> int:
308    def validate_size(self, size: int) -> int:
309        if size < self.min:
310            raise ValueError(f"size {size} < {self.min}")
311        if (size - self.min) % self.step != 0:
312            raise ValueError(
313                f"axis of size {size} is not parameterized by `min + n*step` ="
314                + f" `{self.min} + n*{self.step}`"
315            )
316
317        return size
def get_size(self, n: int) -> int:
319    def get_size(self, n: ParameterizedSize_N) -> int:
320        return self.min + self.step * n
def get_n(self, s: int) -> int:
322    def get_n(self, s: int) -> ParameterizedSize_N:
323        """return smallest n parameterizing a size greater or equal than `s`"""
324        return ceil((s - self.min) / self.step)

return smallest n parameterizing a size greater than or equal to s

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
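
A short usage sketch of the `size = min + n*step` parameterization (the values are arbitrary):

>>> from bioimageio.spec.model.v0_5 import ParameterizedSize
>>> size = ParameterizedSize(min=64, step=16)
>>> size.get_size(3)    # 64 + 3*16
112
>>> size.get_n(100)     # smallest n with 64 + n*16 >= 100
3
>>> size.validate_size(112)
112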

class DataDependentSize(bioimageio.spec._internal.node.Node):
327class DataDependentSize(Node):
328    min: Annotated[int, Gt(0)] = 1
329    max: Annotated[Optional[int], Gt(1)] = None
330
331    @model_validator(mode="after")
332    def _validate_max_gt_min(self):
333        if self.max is not None and self.min >= self.max:
334            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
335
336        return self
337
338    def validate_size(self, size: int) -> int:
339        if size < self.min:
340            raise ValueError(f"size {size} < {self.min}")
341
342        if self.max is not None and size > self.max:
343            raise ValueError(f"size {size} > {self.max}")
344
345        return size
min: Annotated[int, Gt(gt=0)]
max: Annotated[Optional[int], Gt(gt=1)]
def validate_size(self, size: int) -> int:
338    def validate_size(self, size: int) -> int:
339        if size < self.min:
340            raise ValueError(f"size {size} < {self.min}")
341
342        if self.max is not None and size > self.max:
343            raise ValueError(f"size {size} > {self.max}")
344
345        return size
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
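
A minimal sketch of the size bounds check (the bounds are arbitrary); sizes below min or above max raise a ValueError:

>>> from bioimageio.spec.model.v0_5 import DataDependentSize
>>> size = DataDependentSize(min=1, max=1000)
>>> size.validate_size(42)
42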

class SizeReference(bioimageio.spec._internal.node.Node):
348class SizeReference(Node):
349    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
350
351    `axis.size = reference.size * reference.scale / axis.scale + offset`
352
353    Note:
354    1. The axis and the referenced axis need to have the same unit (or no unit).
355    2. Batch axes may not be referenced.
356    3. Fractions are rounded down.
357    4. If the reference axis is `concatenable` the referencing axis is assumed to be
358        `concatenable` as well with the same block order.
359
360    Example:
361    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
362    Let's assume that we want to express the image height h in relation to its width w
363    instead of only accepting input images of exactly 100*49 pixels
364    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
365
366    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
367    >>> h = SpaceInputAxis(
368    ...     id=AxisId("h"),
369    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
370    ...     unit="millimeter",
371    ...     scale=4,
372    ... )
373    >>> print(h.size.get_size(h, w))
374    49
375
376    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
377    """
378
379    tensor_id: TensorId
380    """tensor id of the reference axis"""
381
382    axis_id: AxisId
383    """axis id of the reference axis"""
384
385    offset: int = 0
386
387    def get_size(
388        self,
389        axis: Union[
390            ChannelAxis,
391            IndexInputAxis,
392            IndexOutputAxis,
393            TimeInputAxis,
394            SpaceInputAxis,
395            TimeOutputAxis,
396            TimeOutputAxisWithHalo,
397            SpaceOutputAxis,
398            SpaceOutputAxisWithHalo,
399        ],
400        ref_axis: Union[
401            ChannelAxis,
402            IndexInputAxis,
403            IndexOutputAxis,
404            TimeInputAxis,
405            SpaceInputAxis,
406            TimeOutputAxis,
407            TimeOutputAxisWithHalo,
408            SpaceOutputAxis,
409            SpaceOutputAxisWithHalo,
410        ],
411        n: ParameterizedSize_N = 0,
412        ref_size: Optional[int] = None,
413    ):
414        """Compute the concrete size for a given axis and its reference axis.
415
416        Args:
417            axis: The axis this `SizeReference` is the size of.
418            ref_axis: The reference axis to compute the size from.
419            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
420                and no fixed **ref_size** is given,
421                **n** is used to compute the size of the parameterized **ref_axis**.
422            ref_size: Overwrite the reference size instead of deriving it from
423                **ref_axis**
424                (**ref_axis.scale** is still used; any given **n** is ignored).
425        """
426        assert (
427            axis.size == self
428        ), "Given `axis.size` is not defined by this `SizeReference`"
429
430        assert (
431            ref_axis.id == self.axis_id
432        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
433
434        assert axis.unit == ref_axis.unit, (
435            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
436            f" but {axis.unit}!={ref_axis.unit}"
437        )
438        if ref_size is None:
439            if isinstance(ref_axis.size, (int, float)):
440                ref_size = ref_axis.size
441            elif isinstance(ref_axis.size, ParameterizedSize):
442                ref_size = ref_axis.size.get_size(n)
443            elif isinstance(ref_axis.size, DataDependentSize):
444                raise ValueError(
445                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
446                )
447            elif isinstance(ref_axis.size, SizeReference):
448                raise ValueError(
449                    "Reference axis referenced in `SizeReference` may not be sized by a"
450                    + " `SizeReference` itself."
451                )
452            else:
453                assert_never(ref_axis.size)
454
455        return int(ref_size * ref_axis.scale / axis.scale + self.offset)
456
457    @staticmethod
458    def _get_unit(
459        axis: Union[
460            ChannelAxis,
461            IndexInputAxis,
462            IndexOutputAxis,
463            TimeInputAxis,
464            SpaceInputAxis,
465            TimeOutputAxis,
466            TimeOutputAxisWithHalo,
467            SpaceOutputAxis,
468            SpaceOutputAxisWithHalo,
469        ],
470    ):
471        return axis.unit

A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

axis.size = reference.size * reference.scale / axis.scale + offset

Note:

  1. The axis and the referenced axis need to have the same unit (or no unit).
  2. Batch axes may not be referenced.
  3. Fractions are rounded down.
  4. If the reference axis is concatenable the referencing axis is assumed to be concatenable as well with the same block order.

Example: An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm². Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 100*49 pixels (for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).

>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
...     unit="millimeter",
...     scale=4,
... )
>>> print(h.size.get_size(h, w))
49

⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49

tensor_id: TensorId

tensor id of the reference axis

axis_id: AxisId

axis id of the reference axis

offset: int
387    def get_size(
388        self,
389        axis: Union[
390            ChannelAxis,
391            IndexInputAxis,
392            IndexOutputAxis,
393            TimeInputAxis,
394            SpaceInputAxis,
395            TimeOutputAxis,
396            TimeOutputAxisWithHalo,
397            SpaceOutputAxis,
398            SpaceOutputAxisWithHalo,
399        ],
400        ref_axis: Union[
401            ChannelAxis,
402            IndexInputAxis,
403            IndexOutputAxis,
404            TimeInputAxis,
405            SpaceInputAxis,
406            TimeOutputAxis,
407            TimeOutputAxisWithHalo,
408            SpaceOutputAxis,
409            SpaceOutputAxisWithHalo,
410        ],
411        n: ParameterizedSize_N = 0,
412        ref_size: Optional[int] = None,
413    ):
414        """Compute the concrete size for a given axis and its reference axis.
415
416        Args:
417            axis: The axis this `SizeReference` is the size of.
418            ref_axis: The reference axis to compute the size from.
419            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
420                and no fixed **ref_size** is given,
421                **n** is used to compute the size of the parameterized **ref_axis**.
422            ref_size: Overwrite the reference size instead of deriving it from
423                **ref_axis**
424                (**ref_axis.scale** is still used; any given **n** is ignored).
425        """
426        assert (
427            axis.size == self
428        ), "Given `axis.size` is not defined by this `SizeReference`"
429
430        assert (
431            ref_axis.id == self.axis_id
432        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
433
434        assert axis.unit == ref_axis.unit, (
435            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
436            f" but {axis.unit}!={ref_axis.unit}"
437        )
438        if ref_size is None:
439            if isinstance(ref_axis.size, (int, float)):
440                ref_size = ref_axis.size
441            elif isinstance(ref_axis.size, ParameterizedSize):
442                ref_size = ref_axis.size.get_size(n)
443            elif isinstance(ref_axis.size, DataDependentSize):
444                raise ValueError(
445                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
446                )
447            elif isinstance(ref_axis.size, SizeReference):
448                raise ValueError(
449                    "Reference axis referenced in `SizeReference` may not be sized by a"
450                    + " `SizeReference` itself."
451                )
452            else:
453                assert_never(ref_axis.size)
454
455        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Compute the concrete size for a given axis and its reference axis.

Arguments:
  • axis: The axis this SizeReference is the size of.
  • ref_axis: The reference axis to compute the size from.
  • n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
  • ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

474class AxisBase(NodeWithExplicitlySetFields):
475    id: AxisId
476    """An axis id unique across all axes of one tensor."""
477
478    description: Annotated[str, MaxLen(128)] = ""
id: AxisId

An axis id unique across all axes of one tensor.

description: Annotated[str, MaxLen(max_length=128)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WithHalo(bioimageio.spec._internal.node.Node):
481class WithHalo(Node):
482    halo: Annotated[int, Ge(1)]
483    """The halo should be cropped from the output tensor to avoid boundary effects.
484    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
485    To document a halo that is already cropped by the model use `size.offset` instead."""
486
487    size: Annotated[
488        SizeReference,
489        Field(
490            examples=[
491                10,
492                SizeReference(
493                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
494                ).model_dump(mode="json"),
495            ]
496        ),
497    ]
498    """reference to another axis with an optional offset (see `SizeReference`)"""
halo: Annotated[int, Ge(ge=1)]

The halo should be cropped from the output tensor to avoid boundary effects. It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo. To document a halo that is already cropped by the model use size.offset instead.

size: Annotated[SizeReference, FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

reference to another axis with an optional offset (see SizeReference)

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
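
For example, an output axis with halo=16 whose SizeReference resolves to a size of 256 is to be cropped to 256 - 2*16 = 224. A hedged construction sketch (the tensor and axis ids are placeholders):

>>> from bioimageio.spec.model.v0_5 import (
...     AxisId, SizeReference, SpaceOutputAxisWithHalo, TensorId
... )
>>> y_out = SpaceOutputAxisWithHalo(
...     id=AxisId("y"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
...     halo=16,  # cropped from both sides of the output
... )
>>> y_out.halo
16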

BATCH_AXIS_ID = 'batch'
class BatchAxis(AxisBase):
504class BatchAxis(AxisBase):
505    implemented_type: ClassVar[Literal["batch"]] = "batch"
506    if TYPE_CHECKING:
507        type: Literal["batch"] = "batch"
508    else:
509        type: Literal["batch"]
510
511    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
512    size: Optional[Literal[1]] = None
513    """The batch size may be fixed to 1,
514    otherwise (the default) it may be chosen arbitrarily depending on available memory"""
515
516    @property
517    def scale(self):
518        return 1.0
519
520    @property
521    def concatenable(self):
522        return True
523
524    @property
525    def unit(self):
526        return None
implemented_type: ClassVar[Literal['batch']] = 'batch'
id: Annotated[AxisId, Predicate(_is_batch)]

An axis id unique across all axes of one tensor.

size: Optional[Literal[1]]

The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory

scale
516    @property
517    def scale(self):
518        return 1.0
concatenable
520    @property
521    def concatenable(self):
522        return True
unit
524    @property
525    def unit(self):
526        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['batch']
class ChannelAxis(AxisBase):
529class ChannelAxis(AxisBase):
530    implemented_type: ClassVar[Literal["channel"]] = "channel"
531    if TYPE_CHECKING:
532        type: Literal["channel"] = "channel"
533    else:
534        type: Literal["channel"]
535
536    id: NonBatchAxisId = AxisId("channel")
537    channel_names: NotEmpty[List[Identifier]]
538
539    @property
540    def size(self) -> int:
541        return len(self.channel_names)
542
543    @property
544    def concatenable(self):
545        return False
546
547    @property
548    def scale(self) -> float:
549        return 1.0
550
551    @property
552    def unit(self):
553        return None
implemented_type: ClassVar[Literal['channel']] = 'channel'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

channel_names: Annotated[List[bioimageio.spec._internal.types.Identifier], MinLen(min_length=1)]
size: int
539    @property
540    def size(self) -> int:
541        return len(self.channel_names)
concatenable
543    @property
544    def concatenable(self):
545        return False
scale: float
547    @property
548    def scale(self) -> float:
549        return 1.0
unit
551    @property
552    def unit(self):
553        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['channel']
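
The axis size is implied by the number of channel names; a brief sketch (the channel names are arbitrary):

>>> from bioimageio.spec.model.v0_5 import ChannelAxis, Identifier
>>> channel = ChannelAxis(channel_names=[Identifier("nucleus"), Identifier("membrane")])
>>> channel.size
2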
class IndexAxisBase(AxisBase):
556class IndexAxisBase(AxisBase):
557    implemented_type: ClassVar[Literal["index"]] = "index"
558    if TYPE_CHECKING:
559        type: Literal["index"] = "index"
560    else:
561        type: Literal["index"]
562
563    id: NonBatchAxisId = AxisId("index")
564
565    @property
566    def scale(self) -> float:
567        return 1.0
568
569    @property
570    def unit(self):
571        return None
implemented_type: ClassVar[Literal['index']] = 'index'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

scale: float
565    @property
566    def scale(self) -> float:
567        return 1.0
unit
569    @property
570    def unit(self):
571        return None
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
594class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
595    concatenable: bool = False
596    """If a model has a `concatenable` input axis, it can be processed blockwise,
597    splitting a longer sample axis into blocks matching its input tensor description.
598    Output axes are concatenable if they have a `SizeReference` to a concatenable
599    input axis.
600    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
class IndexOutputAxis(IndexAxisBase):
603class IndexOutputAxis(IndexAxisBase):
604    size: Annotated[
605        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
606        Field(
607            examples=[
608                10,
609                SizeReference(
610                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
611                ).model_dump(mode="json"),
612            ]
613        ),
614    ]
615    """The size/length of this axis can be specified as
616    - fixed integer
617    - reference to another axis with an optional offset (`SizeReference`)
618    - data dependent size using `DataDependentSize` (size is only known after model inference)
619    """
size: Annotated[Union[Annotated[int, Gt(gt=0)], SizeReference, DataDependentSize], FieldInfo(annotation=NoneType, required=True, examples=[10, {'tensor_id': 't', 'axis_id': 'a', 'offset': 5}])]

The size/length of this axis can be specified as

  • fixed integer
  • reference to another axis with an optional offset (SizeReference)
  • data dependent size using DataDependentSize (size is only known after model inference)
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['index']
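
For instance, an output axis whose length is only known after inference (e.g. a variable number of detected objects) can be described with a DataDependentSize; an illustrative sketch (the axis id is arbitrary):

>>> from bioimageio.spec.model.v0_5 import AxisId, DataDependentSize, IndexOutputAxis
>>> detections = IndexOutputAxis(
...     id=AxisId("object"),
...     size=DataDependentSize(min=1),  # concrete length only known after running the model
... )
>>> detections.size.min
1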
class TimeAxisBase(AxisBase):
622class TimeAxisBase(AxisBase):
623    implemented_type: ClassVar[Literal["time"]] = "time"
624    if TYPE_CHECKING:
625        type: Literal["time"] = "time"
626    else:
627        type: Literal["time"]
628
629    id: NonBatchAxisId = AxisId("time")
630    unit: Optional[TimeUnit] = None
631    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['time']] = 'time'
id: Annotated[AxisId, Predicate(_is_not_batch)]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attosecond', 'centisecond', 'day', 'decisecond', 'exasecond', 'femtosecond', 'gigasecond', 'hectosecond', 'hour', 'kilosecond', 'megasecond', 'microsecond', 'millisecond', 'minute', 'nanosecond', 'petasecond', 'picosecond', 'second', 'terasecond', 'yoctosecond', 'yottasecond', 'zeptosecond', 'zettasecond']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
634class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
635    concatenable: bool = False
636    """If a model has a `concatenable` input axis, it can be processed blockwise,
637    splitting a longer sample axis into blocks matching its input tensor description.
638    Output axes are concatenable if they have a `SizeReference` to a concatenable
639    input axis.
640    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceAxisBase(AxisBase):
643class SpaceAxisBase(AxisBase):
644    implemented_type: ClassVar[Literal["space"]] = "space"
645    if TYPE_CHECKING:
646        type: Literal["space"] = "space"
647    else:
648        type: Literal["space"]
649
650    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
651    unit: Optional[SpaceUnit] = None
652    scale: Annotated[float, Gt(0)] = 1.0
implemented_type: ClassVar[Literal['space']] = 'space'
id: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['x', 'y', 'z'])]

An axis id unique across all axes of one tensor.

unit: Optional[Literal['attometer', 'angstrom', 'centimeter', 'decimeter', 'exameter', 'femtometer', 'foot', 'gigameter', 'hectometer', 'inch', 'kilometer', 'megameter', 'meter', 'micrometer', 'mile', 'millimeter', 'nanometer', 'parsec', 'petameter', 'picometer', 'terameter', 'yard', 'yoctometer', 'yottameter', 'zeptometer', 'zettameter']]
scale: Annotated[float, Gt(gt=0)]
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
655class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
656    concatenable: bool = False
657    """If a model has a `concatenable` input axis, it can be processed blockwise,
658    splitting a longer sample axis into blocks matching its input tensor description.
659    Output axes are concatenable if they have a `SizeReference` to a concatenable
660    input axis.
661    """
concatenable: bool

If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description. Output axes are concatenable if they have a SizeReference to a concatenable input axis.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
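
A minimal sketch of a pair of 2D spatial input axes, assuming size accepts a plain integer as the fixed axis length:

    from bioimageio.spec.model.v0_5 import AxisId, SpaceInputAxis

    # hypothetical y/x axes with a pixel size of 0.65 micrometer
    space_axes = [
        SpaceInputAxis(id=AxisId("y"), size=512, scale=0.65, unit="micrometer"),
        SpaceInputAxis(id=AxisId("x"), size=512, scale=0.65, unit="micrometer"),
    ]
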
INPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>)

intended for isinstance comparisons in py<3.10

InputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
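
On Python < 3.10 the annotated InputAxis union cannot be passed to isinstance directly; the INPUT_AXIS_TYPES tuple is meant for that purpose. A small sketch, assuming BatchAxis can be constructed without arguments:

    from bioimageio.spec.model.v0_5 import INPUT_AXIS_TYPES, BatchAxis

    axis = BatchAxis()
    # the tuple of concrete classes replaces isinstance(axis, InputAxis),
    # which is not supported for annotated unions on older Python versions
    assert isinstance(axis, INPUT_AXIS_TYPES)
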
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
697class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
698    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
701class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
702    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['time']
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
721class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
722    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
725class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
726    pass
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

type: Literal['space']
OutputAxis = typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
OUTPUT_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

AnyAxis = typing.Union[typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[BatchAxis, ChannelAxis, IndexOutputAxis, typing.Annotated[typing.Union[typing.Annotated[TimeOutputAxis, Tag(tag='wo_halo')], typing.Annotated[TimeOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)], typing.Annotated[typing.Union[typing.Annotated[SpaceOutputAxis, Tag(tag='wo_halo')], typing.Annotated[SpaceOutputAxisWithHalo, Tag(tag='with_halo')]], Discriminator(discriminator=<function _get_halo_axis_discriminator_value>, custom_error_type=None, custom_error_message=None, custom_error_context=None)]], Discriminator(discriminator='type', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]
ANY_AXIS_TYPES = (<class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexInputAxis'>, <class 'TimeInputAxis'>, <class 'SpaceInputAxis'>, <class 'BatchAxis'>, <class 'ChannelAxis'>, <class 'IndexOutputAxis'>, <class 'TimeOutputAxis'>, <class 'TimeOutputAxisWithHalo'>, <class 'SpaceOutputAxis'>, <class 'SpaceOutputAxisWithHalo'>)

intended for isinstance comparisons in py<3.10

TVs = typing.Union[typing.Annotated[typing.List[int], MinLen(min_length=1)], typing.Annotated[typing.List[float], MinLen(min_length=1)], typing.Annotated[typing.List[bool], MinLen(min_length=1)], typing.Annotated[typing.List[str], MinLen(min_length=1)]]
NominalOrOrdinalDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
class NominalOrOrdinalDataDescr(bioimageio.spec._internal.node.Node):
783class NominalOrOrdinalDataDescr(Node):
784    values: TVs
785    """A fixed set of nominal or an ascending sequence of ordinal values.
786    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
787    String `values` are interpreted as labels for tensor values 0, ..., N.
788    Note: as YAML 1.2 does not natively support a "set" datatype,
789    nominal values should be given as a sequence (aka list/array) as well.
790    """
791
792    type: Annotated[
793        NominalOrOrdinalDType,
794        Field(
795            examples=[
796                "float32",
797                "uint8",
798                "uint16",
799                "int64",
800                "bool",
801            ],
802        ),
803    ] = "uint8"
804
805    @model_validator(mode="after")
806    def _validate_values_match_type(
807        self,
808    ) -> Self:
809        incompatible: List[Any] = []
810        for v in self.values:
811            if self.type == "bool":
812                if not isinstance(v, bool):
813                    incompatible.append(v)
814            elif self.type in DTYPE_LIMITS:
815                if (
816                    isinstance(v, (int, float))
817                    and (
818                        v < DTYPE_LIMITS[self.type].min
819                        or v > DTYPE_LIMITS[self.type].max
820                    )
821                    or (isinstance(v, str) and "uint" not in self.type)
822                    or (isinstance(v, float) and "int" in self.type)
823                ):
824                    incompatible.append(v)
825            else:
826                incompatible.append(v)
827
828            if len(incompatible) == 5:
829                incompatible.append("...")
830                break
831
832        if incompatible:
833            raise ValueError(
834                f"data type '{self.type}' incompatible with values {incompatible}"
835            )
836
837        return self
838
839    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
840
841    @property
842    def range(self):
843        if isinstance(self.values[0], str):
844            return 0, len(self.values) - 1
845        else:
846            return min(self.values), max(self.values)
values: Union[Annotated[List[int], MinLen(min_length=1)], Annotated[List[float], MinLen(min_length=1)], Annotated[List[bool], MinLen(min_length=1)], Annotated[List[str], MinLen(min_length=1)]]

A fixed set of nominal or an ascending sequence of ordinal values. In this case data.type is required to be an unsigned integer type, e.g. 'uint8'. String values are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.

type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'uint8', 'uint16', 'int64', 'bool'])]
unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit, NoneType]
range
841    @property
842    def range(self):
843        if isinstance(self.values[0], str):
844            return 0, len(self.values) - 1
845        else:
846            return min(self.values), max(self.values)
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
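
A minimal sketch of categorical data described by string labels; as stated above, string values label the tensor values 0, ..., N, so the range property spans the label indices:

    from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr

    # hypothetical segmentation labels stored as uint8 values 0, 1, 2
    labels = NominalOrOrdinalDataDescr(
        values=["background", "cell", "membrane"],
        type="uint8",
    )
    print(labels.range)  # (0, 2)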

IntervalOrRatioDType = typing.Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']
class IntervalOrRatioDataDescr(bioimageio.spec._internal.node.Node):
863class IntervalOrRatioDataDescr(Node):
864    type: Annotated[  # todo: rename to dtype
865        IntervalOrRatioDType,
866        Field(
867            examples=["float32", "float64", "uint8", "uint16"],
868        ),
869    ] = "float32"
870    range: Tuple[Optional[float], Optional[float]] = (
871        None,
872        None,
873    )
874    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
875    `None` corresponds to min/max of what can be expressed by **type**."""
876    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
877    scale: float = 1.0
878    """Scale for data on an interval (or ratio) scale."""
879    offset: Optional[float] = None
880    """Offset for data on a ratio scale."""
type: Annotated[Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64'], FieldInfo(annotation=NoneType, required=True, examples=['float32', 'float64', 'uint8', 'uint16'])]
range: Tuple[Optional[float], Optional[float]]

Tuple (minimum, maximum) specifying the allowed range of the data in this tensor. None corresponds to min/max of what can be expressed by type.

unit: Union[Literal['arbitrary unit'], bioimageio.spec._internal.types.SiUnit]
scale: float

Scale for data on an interval (or ratio) scale.

offset: Optional[float]

Offset for data on a ratio scale.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
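
A minimal sketch of a float32 tensor restricted to the unit interval (range entries of None fall back to the limits of the data type):

    from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr

    # hypothetical data description for normalized intensities
    data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))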

TensorDataDescr = typing.Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class ProcessingDescrBase(bioimageio.spec._internal.common_nodes.NodeWithExplicitlySetFields, abc.ABC):
886class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
887    """processing base class"""

processing base class

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
890class BinarizeKwargs(ProcessingKwargs):
891    """key word arguments for `BinarizeDescr`"""
892
893    threshold: float
894    """The fixed threshold"""

keyword arguments for BinarizeDescr

threshold: float

The fixed threshold

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
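
A minimal sketch using the scalar threshold variant together with BinarizeDescr (described below):

    from bioimageio.spec.model.v0_5 import BinarizeDescr, BinarizeKwargs

    # binarize the whole tensor with a single threshold
    postprocessing = [BinarizeDescr(kwargs=BinarizeKwargs(threshold=0.5))]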

class BinarizeAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
897class BinarizeAlongAxisKwargs(ProcessingKwargs):
898    """key word arguments for `BinarizeDescr`"""
899
900    threshold: NotEmpty[List[float]]
901    """The fixed threshold values along `axis`"""
902
903    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
904    """The `threshold` axis"""

keyword arguments for BinarizeDescr

threshold: Annotated[List[float], MinLen(min_length=1)]

The fixed threshold values along axis

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The threshold axis

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BinarizeDescr(ProcessingDescrBase):
907class BinarizeDescr(ProcessingDescrBase):
908    """Binarize the tensor with a fixed threshold.
909
910    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
911    will be set to one, values below the threshold to zero.
912
913    Examples:
914    - in YAML
915        ```yaml
916        postprocessing:
917          - id: binarize
918            kwargs:
919              axis: 'channel'
920              threshold: [0.25, 0.5, 0.75]
921        ```
922    - in Python:
923        >>> postprocessing = [BinarizeDescr(
924        ...   kwargs=BinarizeAlongAxisKwargs(
925        ...       axis=AxisId('channel'),
926        ...       threshold=[0.25, 0.5, 0.75],
927        ...   )
928        ... )]
929    """
930
931    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
932    if TYPE_CHECKING:
933        id: Literal["binarize"] = "binarize"
934    else:
935        id: Literal["binarize"]
936    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

Binarize the tensor with a fixed threshold.

Values above BinarizeKwargs.threshold/BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.

Examples:

  • in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  • in Python:
    >>> postprocessing = [BinarizeDescr(
    ...   kwargs=BinarizeAlongAxisKwargs(
    ...       axis=AxisId('channel'),
    ...       threshold=[0.25, 0.5, 0.75],
    ...   )
    ... )]
    
implemented_id: ClassVar[Literal['binarize']] = 'binarize'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['binarize']
class ClipDescr(ProcessingDescrBase):
939class ClipDescr(ProcessingDescrBase):
940    """Set tensor values below min to min and above max to max.
941
942    See `ScaleRangeDescr` for examples.
943    """
944
945    implemented_id: ClassVar[Literal["clip"]] = "clip"
946    if TYPE_CHECKING:
947        id: Literal["clip"] = "clip"
948    else:
949        id: Literal["clip"]
950
951    kwargs: ClipKwargs

Set tensor values below min to min and above max to max.

See ScaleRangeDescr for examples.

implemented_id: ClassVar[Literal['clip']] = 'clip'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['clip']
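
As the docstring only refers to ScaleRangeDescr for examples, here is a minimal standalone sketch, assuming ClipKwargs (reused from the v0_4 model description) is importable from this module:

    from bioimageio.spec.model.v0_5 import ClipDescr, ClipKwargs

    # clamp postprocessed values to the unit interval
    postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
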
class EnsureDtypeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
954class EnsureDtypeKwargs(ProcessingKwargs):
955    """key word arguments for `EnsureDtypeDescr`"""
956
957    dtype: Literal[
958        "float32",
959        "float64",
960        "uint8",
961        "int8",
962        "uint16",
963        "int16",
964        "uint32",
965        "int32",
966        "uint64",
967        "int64",
968        "bool",
969    ]

keyword arguments for EnsureDtypeDescr

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class EnsureDtypeDescr(ProcessingDescrBase):
 972class EnsureDtypeDescr(ProcessingDescrBase):
 973    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 974
 975    This can for example be used to ensure the inner neural network model gets a
 976    different input tensor data type than the fully described bioimage.io model does.
 977
 978    Examples:
 979        The described bioimage.io model (incl. preprocessing) accepts any
 980        float32-compatible tensor, normalizes it with percentiles and clipping and then
 981        casts it to uint8, which is what the neural network in this example expects.
 982        - in YAML
 983            ```yaml
 984            inputs:
 985            - data:
 986                type: float32  # described bioimage.io model is compatible with any float32 input tensor
 987              preprocessing:
 988              - id: scale_range
 989                  kwargs:
 990                  axes: ['y', 'x']
 991                  max_percentile: 99.8
 992                  min_percentile: 5.0
 993              - id: clip
 994                  kwargs:
 995                  min: 0.0
 996                  max: 1.0
 997              - id: ensure_dtype  # the neural network of the model requires uint8
 998                  kwargs:
 999                  dtype: uint8
1000            ```
1001        - in Python:
1002            >>> preprocessing = [
1003            ...     ScaleRangeDescr(
1004            ...         kwargs=ScaleRangeKwargs(
1005            ...           axes= (AxisId('y'), AxisId('x')),
1006            ...           max_percentile= 99.8,
1007            ...           min_percentile= 5.0,
1008            ...         )
1009            ...     ),
1010            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1011            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1012            ... ]
1013    """
1014
1015    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1016    if TYPE_CHECKING:
1017        id: Literal["ensure_dtype"] = "ensure_dtype"
1018    else:
1019        id: Literal["ensure_dtype"]
1020
1021    kwargs: EnsureDtypeKwargs

Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).

This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.

Examples:

The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.

  • in YAML

inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
    kwargs:
      dtype: uint8

  • in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...           axes= (AxisId('y'), AxisId('x')),
    ...           max_percentile= 99.8,
    ...           min_percentile= 5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
    
implemented_id: ClassVar[Literal['ensure_dtype']] = 'ensure_dtype'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['ensure_dtype']
class ScaleLinearKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1024class ScaleLinearKwargs(ProcessingKwargs):
1025    """Key word arguments for `ScaleLinearDescr`"""
1026
1027    gain: float = 1.0
1028    """multiplicative factor"""
1029
1030    offset: float = 0.0
1031    """additive term"""
1032
1033    @model_validator(mode="after")
1034    def _validate(self) -> Self:
1035        if self.gain == 1.0 and self.offset == 0.0:
1036            raise ValueError(
1037                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1038                + " != 0.0."
1039            )
1040
1041        return self

Keyword arguments for ScaleLinearDescr

gain: float

multiplicative factor

offset: float

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleLinearAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1044class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1045    """Key word arguments for `ScaleLinearDescr`"""
1046
1047    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1048    """The axis of gain and offset values."""
1049
1050    gain: Union[float, NotEmpty[List[float]]] = 1.0
1051    """multiplicative factor"""
1052
1053    offset: Union[float, NotEmpty[List[float]]] = 0.0
1054    """additive term"""
1055
1056    @model_validator(mode="after")
1057    def _validate(self) -> Self:
1058
1059        if isinstance(self.gain, list):
1060            if isinstance(self.offset, list):
1061                if len(self.gain) != len(self.offset):
1062                    raise ValueError(
1063                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1064                    )
1065            else:
1066                self.offset = [float(self.offset)] * len(self.gain)
1067        elif isinstance(self.offset, list):
1068            self.gain = [float(self.gain)] * len(self.offset)
1069        else:
1070            raise ValueError(
1071                "Do not specify an `axis` for scalar gain and offset values."
1072            )
1073
1074        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1075            raise ValueError(
1076                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1077                + " != 0.0."
1078            )
1079
1080        return self

Keyword arguments for ScaleLinearDescr

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel'])]

The axis of gain and offset values.

gain: Union[float, Annotated[List[float], MinLen(min_length=1)]]

multiplicative factor

offset: Union[float, Annotated[List[float], MinLen(min_length=1)]]

additive term

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
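
A minimal sketch of the broadcasting performed by the validator above: a scalar offset is expanded to match a per-channel gain list (specifying both gain and offset as scalars together with an axis is rejected):

    from bioimageio.spec.model.v0_5 import AxisId, ScaleLinearAlongAxisKwargs

    # per-channel gains with a single scalar offset
    kwargs = ScaleLinearAlongAxisKwargs(
        axis=AxisId("channel"),
        gain=[1.0, 2.0, 3.0],
        offset=0.5,
    )
    print(kwargs.offset)  # [0.5, 0.5, 0.5]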

class ScaleLinearDescr(ProcessingDescrBase):
1083class ScaleLinearDescr(ProcessingDescrBase):
1084    """Fixed linear scaling.
1085
1086    Examples:
1087      1. Scale with scalar gain and offset
1088        - in YAML
1089        ```yaml
1090        preprocessing:
1091          - id: scale_linear
1092            kwargs:
1093              gain: 2.0
1094              offset: 3.0
1095        ```
1096        - in Python:
1097        >>> preprocessing = [
1098        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1099        ... ]
1100
1101      2. Independent scaling along an axis
1102        - in YAML
1103        ```yaml
1104        preprocessing:
1105          - id: scale_linear
1106            kwargs:
1107              axis: 'channel'
1108              gain: [1.0, 2.0, 3.0]
1109        ```
1110        - in Python:
1111        >>> preprocessing = [
1112        ...     ScaleLinearDescr(
1113        ...         kwargs=ScaleLinearAlongAxisKwargs(
1114        ...             axis=AxisId("channel"),
1115        ...             gain=[1.0, 2.0, 3.0],
1116        ...         )
1117        ...     )
1118        ... ]
1119
1120    """
1121
1122    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1123    if TYPE_CHECKING:
1124        id: Literal["scale_linear"] = "scale_linear"
1125    else:
1126        id: Literal["scale_linear"]
1127    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

Fixed linear scaling.

Examples:
  1. Scale with scalar gain and offset

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            gain: 2.0
            offset: 3.0
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
      ... ]
      
  2. Independent scaling along an axis

    • in YAML

      preprocessing:
        - id: scale_linear
          kwargs:
            axis: 'channel'
            gain: [1.0, 2.0, 3.0]
      
    • in Python:

      >>> preprocessing = [
      ...     ScaleLinearDescr(
      ...         kwargs=ScaleLinearAlongAxisKwargs(
      ...             axis=AxisId("channel"),
      ...             gain=[1.0, 2.0, 3.0],
      ...         )
      ...     )
      ... ]
      
implemented_id: ClassVar[Literal['scale_linear']] = 'scale_linear'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_linear']
class SigmoidDescr(ProcessingDescrBase):
1130class SigmoidDescr(ProcessingDescrBase):
1131    """The logistic sigmoid function, a.k.a. expit function.
1132
1133    Examples:
1134    - in YAML
1135        ```yaml
1136        postprocessing:
1137          - id: sigmoid
1138        ```
1139    - in Python:
1140        >>> postprocessing = [SigmoidDescr()]
1141    """
1142
1143    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1144    if TYPE_CHECKING:
1145        id: Literal["sigmoid"] = "sigmoid"
1146    else:
1147        id: Literal["sigmoid"]
1148
1149    @property
1150    def kwargs(self) -> ProcessingKwargs:
1151        """empty kwargs"""
1152        return ProcessingKwargs()

The logistic sigmoid function, a.k.a. expit function.

Examples:

  • in YAML
postprocessing:
  - id: sigmoid
  • in Python:
    >>> postprocessing = [SigmoidDescr()]
    
implemented_id: ClassVar[Literal['sigmoid']] = 'sigmoid'
1149    @property
1150    def kwargs(self) -> ProcessingKwargs:
1151        """empty kwargs"""
1152        return ProcessingKwargs()

empty kwargs

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['sigmoid']
class FixedZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1155class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1156    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1157
1158    mean: float
1159    """The mean value to normalize with."""
1160
1161    std: Annotated[float, Ge(1e-6)]
1162    """The standard deviation value to normalize with."""

keyword arguments for FixedZeroMeanUnitVarianceDescr

mean: float

The mean value to normalize with.

std: Annotated[float, Ge(ge=1e-06)]

The standard deviation value to normalize with.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class FixedZeroMeanUnitVarianceAlongAxisKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1165class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1166    """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1167
1168    mean: NotEmpty[List[float]]
1169    """The mean value(s) to normalize with."""
1170
1171    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1172    """The standard deviation value(s) to normalize with.
1173    Size must match `mean` values."""
1174
1175    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1176    """The axis of the mean/std values to normalize each entry along that dimension
1177    separately."""
1178
1179    @model_validator(mode="after")
1180    def _mean_and_std_match(self) -> Self:
1181        if len(self.mean) != len(self.std):
1182            raise ValueError(
1183                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1184                + " must match."
1185            )
1186
1187        return self

keyword arguments for FixedZeroMeanUnitVarianceDescr

mean: Annotated[List[float], MinLen(min_length=1)]

The mean value(s) to normalize with.

std: Annotated[List[Annotated[float, Ge(ge=1e-06)]], MinLen(min_length=1)]

The standard deviation value(s) to normalize with. Size must match mean values.

axis: Annotated[AxisId, Predicate(_is_not_batch), FieldInfo(annotation=NoneType, required=True, examples=['channel', 'index'])]

The axis of the mean/std values to normalize each entry along that dimension separately.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
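
A minimal sketch of the length check performed by the validator above; mismatched mean/std lengths are rejected with a pydantic ValidationError:

    from pydantic import ValidationError

    from bioimageio.spec.model.v0_5 import (
        AxisId,
        FixedZeroMeanUnitVarianceAlongAxisKwargs,
    )

    try:
        FixedZeroMeanUnitVarianceAlongAxisKwargs(
            axis=AxisId("channel"),
            mean=[101.5, 102.5],     # two mean values ...
            std=[11.7, 12.7, 13.7],  # ... but three std values
        )
    except ValidationError as error:
        print(error)  # reports that the sizes of `mean` and `std` must match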

class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1190class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1191    """Subtract a given mean and divide by the standard deviation.
1192
1193    Normalize with fixed, precomputed values for
1194    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1195    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1196    axes.
1197
1198    Examples:
1199    1. scalar value for whole tensor
1200        - in YAML
1201        ```yaml
1202        preprocessing:
1203          - id: fixed_zero_mean_unit_variance
1204            kwargs:
1205              mean: 103.5
1206              std: 13.7
1207        ```
1208        - in Python
1209        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1210        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1211        ... )]
1212
1213    2. independently along an axis
1214        - in YAML
1215        ```yaml
1216        preprocessing:
1217          - id: fixed_zero_mean_unit_variance
1218            kwargs:
1219              axis: channel
1220              mean: [101.5, 102.5, 103.5]
1221              std: [11.7, 12.7, 13.7]
1222        ```
1223        - in Python
1224        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1225        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1226        ...     axis=AxisId("channel"),
1227        ...     mean=[101.5, 102.5, 103.5],
1228        ...     std=[11.7, 12.7, 13.7],
1229        ...   )
1230        ... )]
1231    """
1232
1233    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1234        "fixed_zero_mean_unit_variance"
1235    )
1236    if TYPE_CHECKING:
1237        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1238    else:
1239        id: Literal["fixed_zero_mean_unit_variance"]
1240
1241    kwargs: Union[
1242        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1243    ]

Subtract a given mean and divide by the standard deviation.

Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std. Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.

Examples:

  1. scalar value for whole tensor

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          mean: 103.5
          std: 13.7
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
      ... )]
      
  2. independently along an axis

    • in YAML
    preprocessing:
      - id: fixed_zero_mean_unit_variance
        kwargs:
          axis: channel
          mean: [101.5, 102.5, 103.5]
          std: [11.7, 12.7, 13.7]
    
    • in Python
      >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
      ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
      ...     axis=AxisId("channel"),
      ...     mean=[101.5, 102.5, 103.5],
      ...     std=[11.7, 12.7, 13.7],
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['fixed_zero_mean_unit_variance']] = 'fixed_zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['fixed_zero_mean_unit_variance']
class ZeroMeanUnitVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1246class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1247    """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1248
1249    axes: Annotated[
1250        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1251    ] = None
1252    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1253    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1254    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1255    To normalize each sample independently leave out the 'batch' axis.
1256    Default: Scale all axes jointly."""
1257
1258    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1259    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

keyword arguments for ZeroMeanUnitVarianceDescr

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

epsilon for numeric stability: out = (tensor - mean) / (std + eps).

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
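
A minimal sketch of per-channel normalization of a ('batch', 'channel', 'y', 'x') tensor: reducing over the batch and spatial axes only, so each channel keeps its own mean and standard deviation:

    from bioimageio.spec.model.v0_5 import (
        AxisId,
        ZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceKwargs,
    )

    # normalize each channel independently by leaving 'channel' out of `axes`
    preprocessing = [
        ZeroMeanUnitVarianceDescr(
            kwargs=ZeroMeanUnitVarianceKwargs(
                axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
            )
        )
    ]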

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1262class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1263    """Subtract mean and divide by variance.
1264
1265    Examples:
1266        Subtract tensor mean and variance
1267        - in YAML
1268        ```yaml
1269        preprocessing:
1270          - id: zero_mean_unit_variance
1271        ```
1272        - in Python
1273        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1274    """
1275
1276    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1277        "zero_mean_unit_variance"
1278    )
1279    if TYPE_CHECKING:
1280        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1281    else:
1282        id: Literal["zero_mean_unit_variance"]
1283
1284    kwargs: ZeroMeanUnitVarianceKwargs = Field(
1285        default_factory=ZeroMeanUnitVarianceKwargs
1286    )

Subtract mean and divide by variance.

Examples:

Subtract tensor mean and variance

  • in YAML
preprocessing:
  - id: zero_mean_unit_variance
  • in Python
    >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    
implemented_id: ClassVar[Literal['zero_mean_unit_variance']] = 'zero_mean_unit_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['zero_mean_unit_variance']
class ScaleRangeKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1289class ScaleRangeKwargs(ProcessingKwargs):
1290    """key word arguments for `ScaleRangeDescr`
1291
1292    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1293    this processing step normalizes data to the [0, 1] interval.
1294    For other percentiles the normalized values will partially be outside the [0, 1]
1295    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1296    normalized values to a range.
1297    """
1298
1299    axes: Annotated[
1300        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1301    ] = None
1302    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1303    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1304    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1305    To normalize samples independently, leave out the "batch" axis.
1306    Default: Scale all axes jointly."""
1307
1308    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1309    """The lower percentile used to determine the value to align with zero."""
1310
1311    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1312    """The upper percentile used to determine the value to align with one.
1313    Has to be bigger than `min_percentile`.
1314    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1315    accepting percentiles specified in the range 0.0 to 1.0."""
1316
1317    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1318    """Epsilon for numeric stability.
1319    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1320    with `v_lower,v_upper` values at the respective percentiles."""
1321
1322    reference_tensor: Optional[TensorId] = None
1323    """Tensor ID to compute the percentiles from. Default: The tensor itself.
1324    For any tensor in `inputs` only input tensor references are allowed."""
1325
1326    @field_validator("max_percentile", mode="after")
1327    @classmethod
1328    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1329        if (min_p := info.data["min_percentile"]) >= value:
1330            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1331
1332        return value

keyword arguments for ScaleRangeDescr

For min_percentile=0.0 (the default) and max_percentile=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use ScaleRange followed by ClipDescr if you want to limit the normalized values to a range.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.

min_percentile: Annotated[float, Interval(gt=None, ge=0, lt=100, le=None)]

The lower percentile used to determine the value to align with zero.

max_percentile: Annotated[float, Interval(gt=1, ge=None, lt=None, le=100)]

The upper percentile used to determine the value to align with one. Has to be bigger than min_percentile. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability. out = (tensor - v_lower) / (v_upper - v_lower + eps); with v_lower,v_upper values at the respective percentiles.

reference_tensor: Optional[TensorId]

Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in inputs only input tensor references are allowed.

@field_validator('max_percentile', mode='after')
@classmethod
def min_smaller_max( cls, value: float, info: pydantic_core.core_schema.ValidationInfo) -> float:
1326    @field_validator("max_percentile", mode="after")
1327    @classmethod
1328    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1329        if (min_p := info.data["min_percentile"]) >= value:
1330            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1331
1332        return value
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
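
A minimal sketch of the min_smaller_max check above; equal or inverted percentiles are rejected with a pydantic ValidationError:

    from pydantic import ValidationError

    from bioimageio.spec.model.v0_5 import ScaleRangeKwargs

    try:
        ScaleRangeKwargs(min_percentile=50.0, max_percentile=50.0)
    except ValidationError as error:
        print(error)  # min_percentile 50.0 >= max_percentile 50.0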

class ScaleRangeDescr(ProcessingDescrBase):
1335class ScaleRangeDescr(ProcessingDescrBase):
1336    """Scale with percentiles.
1337
1338    Examples:
1339    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1340        - in YAML
1341        ```yaml
1342        preprocessing:
1343          - id: scale_range
1344            kwargs:
1345              axes: ['y', 'x']
1346              max_percentile: 99.8
1347              min_percentile: 5.0
1348        ```
1349        - in Python
1350        >>> preprocessing = [
1351        ...     ScaleRangeDescr(
1352        ...         kwargs=ScaleRangeKwargs(
1353        ...           axes= (AxisId('y'), AxisId('x')),
1354        ...           max_percentile= 99.8,
1355        ...           min_percentile= 5.0,
1356        ...         )
1357        ...     ),
1358        ...     ClipDescr(
1359        ...         kwargs=ClipKwargs(
1360        ...             min=0.0,
1361        ...             max=1.0,
1362        ...         )
1363        ...     ),
1364        ... ]
1365
1366      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1367        - in YAML
1368        ```yaml
1369        preprocessing:
1370          - id: scale_range
1371            kwargs:
1372              axes: ['y', 'x']
1373              max_percentile: 99.8
1374              min_percentile: 5.0
1375                  - id: scale_range
1376           - id: clip
1377             kwargs:
1378              min: 0.0
1379              max: 1.0
1380        ```
1381        - in Python
1382        >>> preprocessing = [ScaleRangeDescr(
1383        ...   kwargs=ScaleRangeKwargs(
1384        ...       axes= (AxisId('y'), AxisId('x')),
1385        ...       max_percentile= 99.8,
1386        ...       min_percentile= 5.0,
1387        ...   )
1388        ... )]
1389
1390    """
1391
1392    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1393    if TYPE_CHECKING:
1394        id: Literal["scale_range"] = "scale_range"
1395    else:
1396        id: Literal["scale_range"]
1397    kwargs: ScaleRangeKwargs

Scale with percentiles.

Examples:

  1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
    
    • in Python
      >>> preprocessing = [
      ...     ScaleRangeDescr(
      ...         kwargs=ScaleRangeKwargs(
      ...           axes= (AxisId('y'), AxisId('x')),
      ...           max_percentile= 99.8,
      ...           min_percentile= 5.0,
      ...         )
      ...     ),
      ...     ClipDescr(
      ...         kwargs=ClipKwargs(
      ...             min=0.0,
      ...             max=1.0,
      ...         )
      ...     ),
      ... ]
      
  2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.

    • in YAML
    preprocessing:
      - id: scale_range
        kwargs:
          axes: ['y', 'x']
          max_percentile: 99.8
          min_percentile: 5.0
      - id: clip
        kwargs:
          min: 0.0
          max: 1.0
    
    • in Python
      >>> preprocessing = [ScaleRangeDescr(
      ...   kwargs=ScaleRangeKwargs(
      ...       axes= (AxisId('y'), AxisId('x')),
      ...       max_percentile= 99.8,
      ...       min_percentile= 5.0,
      ...   )
      ... )]
      
implemented_id: ClassVar[Literal['scale_range']] = 'scale_range'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_range']
class ScaleMeanVarianceKwargs(bioimageio.spec.model.v0_4.ProcessingKwargs):
1400class ScaleMeanVarianceKwargs(ProcessingKwargs):
1401    """key word arguments for `ScaleMeanVarianceKwargs`"""
1402
1403    reference_tensor: TensorId
1404    """Name of tensor to match."""
1405
1406    axes: Annotated[
1407        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1408    ] = None
1409    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1410    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1411    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1412    To normalize samples independently, leave out the 'batch' axis.
1413    Default: Scale all axes jointly."""
1414
1415    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1416    """Epsilon for numeric stability:
1417    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

keyword arguments for ScaleMeanVarianceDescr

reference_tensor: TensorId

Name of tensor to match.

axes: Annotated[Optional[Sequence[AxisId]], FieldInfo(annotation=NoneType, required=True, examples=[('batch', 'x', 'y')])]

The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y'). To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.

eps: Annotated[float, Interval(gt=0, ge=None, lt=None, le=0.1)]

Epsilon for numeric stability: out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ScaleMeanVarianceDescr(ProcessingDescrBase):
1420class ScaleMeanVarianceDescr(ProcessingDescrBase):
1421    """Scale a tensor's data distribution to match another tensor's mean/std.
1422    `out  = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1423    """
1424
1425    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1426    if TYPE_CHECKING:
1427        id: Literal["scale_mean_variance"] = "scale_mean_variance"
1428    else:
1429        id: Literal["scale_mean_variance"]
1430    kwargs: ScaleMeanVarianceKwargs

Scale a tensor's data distribution to match another tensor's mean/std. out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.

implemented_id: ClassVar[Literal['scale_mean_variance']] = 'scale_mean_variance'
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

id: Literal['scale_mean_variance']
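
A minimal sketch of a postprocessing step that matches an output's distribution to a hypothetical input tensor with id 'raw':

    from bioimageio.spec.model.v0_5 import (
        AxisId,
        ScaleMeanVarianceDescr,
        ScaleMeanVarianceKwargs,
        TensorId,
    )

    # match mean/std computed over the spatial axes of the reference tensor
    postprocessing = [
        ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                reference_tensor=TensorId("raw"),
                axes=(AxisId("x"), AxisId("y")),
            )
        )
    ]
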
PreprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
PostprocessingDescr = typing.Annotated[typing.Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]
class TensorDescrBase(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1464class TensorDescrBase(Node, Generic[IO_AxisT]):
1465    id: TensorId
1466    """Tensor id. No duplicates are allowed."""
1467
1468    description: Annotated[str, MaxLen(128)] = ""
1469    """free text description"""
1470
1471    axes: NotEmpty[Sequence[IO_AxisT]]
1472    """tensor axes"""
1473
1474    @property
1475    def shape(self):
1476        return tuple(a.size for a in self.axes)
1477
1478    @field_validator("axes", mode="after", check_fields=False)
1479    @classmethod
1480    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1481        batch_axes = [a for a in axes if a.type == "batch"]
1482        if len(batch_axes) > 1:
1483            raise ValueError(
1484                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1485            )
1486
1487        seen_ids: Set[AxisId] = set()
1488        duplicate_axes_ids: Set[AxisId] = set()
1489        for a in axes:
1490            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1491
1492        if duplicate_axes_ids:
1493            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1494
1495        return axes
1496
1497    test_tensor: FileDescr_
1498    """An example tensor to use for testing.
1499    Using the model with the test input tensors is expected to yield the test output tensors.
1500    Each test tensor has to be an ndarray in the
1501    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1502    The file extension must be '.npy'."""
1503
1504    sample_tensor: Optional[FileDescr_] = None
1505    """A sample tensor to illustrate a possible input/output for the model.
1506    The sample image primarily serves to inform a human user about an example use case
1507    and is typically stored as .hdf5, .png or .tiff.
1508    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1509    (numpy's `.npy` format is not supported).
1510    The image dimensionality has to match the number of axes specified in this tensor description.
1511    """
1512
1513    @model_validator(mode="after")
1514    def _validate_sample_tensor(self) -> Self:
1515        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1516            return self
1517
1518        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1519        tensor: NDArray[Any] = imread(
1520            reader.read(),
1521            extension=PurePosixPath(reader.original_file_name).suffix,
1522        )
1523        n_dims = len(tensor.squeeze().shape)
1524        n_dims_min = n_dims_max = len(self.axes)
1525
1526        for a in self.axes:
1527            if isinstance(a, BatchAxis):
1528                n_dims_min -= 1
1529            elif isinstance(a.size, int):
1530                if a.size == 1:
1531                    n_dims_min -= 1
1532            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1533                if a.size.min == 1:
1534                    n_dims_min -= 1
1535            elif isinstance(a.size, SizeReference):
1536                if a.size.offset < 2:
1537                    # size reference may result in singleton axis
1538                    n_dims_min -= 1
1539            else:
1540                assert_never(a.size)
1541
1542        n_dims_min = max(0, n_dims_min)
1543        if n_dims < n_dims_min or n_dims > n_dims_max:
1544            raise ValueError(
1545                f"Expected sample tensor to have {n_dims_min} to"
1546                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1547            )
1548
1549        return self
1550
1551    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1552        IntervalOrRatioDataDescr()
1553    )
1554    """Description of the tensor's data values, optionally per channel.
1555    If specified per channel, the data `type` needs to match across channels."""
1556
1557    @property
1558    def dtype(
1559        self,
1560    ) -> Literal[
1561        "float32",
1562        "float64",
1563        "uint8",
1564        "int8",
1565        "uint16",
1566        "int16",
1567        "uint32",
1568        "int32",
1569        "uint64",
1570        "int64",
1571        "bool",
1572    ]:
1573        """dtype as specified under `data.type` or `data[i].type`"""
1574        if isinstance(self.data, collections.abc.Sequence):
1575            return self.data[0].type
1576        else:
1577            return self.data.type
1578
1579    @field_validator("data", mode="after")
1580    @classmethod
1581    def _check_data_type_across_channels(
1582        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1583    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1584        if not isinstance(value, list):
1585            return value
1586
1587        dtypes = {t.type for t in value}
1588        if len(dtypes) > 1:
1589            raise ValueError(
1590                "Tensor data descriptions per channel need to agree in their data"
1591                + f" `type`, but found {dtypes}."
1592            )
1593
1594        return value
1595
1596    @model_validator(mode="after")
1597    def _check_data_matches_channelaxis(self) -> Self:
1598        if not isinstance(self.data, (list, tuple)):
1599            return self
1600
1601        for a in self.axes:
1602            if isinstance(a, ChannelAxis):
1603                size = a.size
1604                assert isinstance(size, int)
1605                break
1606        else:
1607            return self
1608
1609        if len(self.data) != size:
1610            raise ValueError(
1611                f"Got tensor data descriptions for {len(self.data)} channels, but"
1612                + f" '{a.id}' axis has size {size}."
1613            )
1614
1615        return self
1616
1617    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1618        if len(array.shape) != len(self.axes):
1619            raise ValueError(
1620                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1621                + f" incompatible with {len(self.axes)} axes."
1622            )
1623        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
id: TensorId

Tensor id. No duplicates are allowed.

description: Annotated[str, MaxLen(max_length=128)]

free text description

axes: Annotated[Sequence[~IO_AxisT], MinLen(min_length=1)]

tensor axes

shape
1474    @property
1475    def shape(self):
1476        return tuple(a.size for a in self.axes)
test_tensor: Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>), WrapSerializer(func=<function package_file_descr_serializer at 0x7febd4daeca0>, return_type=PydanticUndefined, when_used='unless-none')]

An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
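
A minimal sketch for creating such a test tensor file (file name and shape are placeholders; the shape must match the declared axes):

    import numpy as np

    # e.g. a (batch, channel, y, x) test input, stored in the numpy .npy format
    test_input = np.random.rand(1, 3, 256, 256).astype("float32")
    np.save("test_input.npy", test_input)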

sample_tensor: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>), WrapSerializer(func=<function package_file_descr_serializer at 0x7febd4daeca0>, return_type=PydanticUndefined, when_used='unless-none')]]

A sample tensor to illustrate a possible input/output for the model. The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's .npy format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.

data: Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr, Annotated[Sequence[Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]], MinLen(min_length=1)]]

Description of the tensor's data values, optionally per channel. If specified per channel, the data type needs to match across channels.

dtype: Literal['float32', 'float64', 'uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64', 'bool']
1557    @property
1558    def dtype(
1559        self,
1560    ) -> Literal[
1561        "float32",
1562        "float64",
1563        "uint8",
1564        "int8",
1565        "uint16",
1566        "int16",
1567        "uint32",
1568        "int32",
1569        "uint64",
1570        "int64",
1571        "bool",
1572    ]:
1573        """dtype as specified under `data.type` or `data[i].type`"""
1574        if isinstance(self.data, collections.abc.Sequence):
1575            return self.data[0].type
1576        else:
1577            return self.data.type

dtype as specified under data.type or data[i].type

def get_axis_sizes_for_array( self, array: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[typing.Any]]) -> Dict[AxisId, int]:
1617    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1618        if len(array.shape) != len(self.axes):
1619            raise ValueError(
1620                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1621                + f" incompatible with {len(self.axes)} axes."
1622            )
1623        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1626class InputTensorDescr(TensorDescrBase[InputAxis]):
1627    id: TensorId = TensorId("input")
1628    """Input tensor id.
1629    No duplicates are allowed across all inputs and outputs."""
1630
1631    optional: bool = False
1632    """indicates that this tensor may be `None`"""
1633
1634    preprocessing: List[PreprocessingDescr] = Field(
1635        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1636    )
1637
1638    """Description of how this input should be preprocessed.
1639
1640    notes:
1641    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1642      to ensure an input tensor's data type matches the input tensor's data description.
1643    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1644      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1645      changing the data type.
1646    """
1647
1648    @model_validator(mode="after")
1649    def _validate_preprocessing_kwargs(self) -> Self:
1650        axes_ids = [a.id for a in self.axes]
1651        for p in self.preprocessing:
1652            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1653            if kwargs_axes is None:
1654                continue
1655
1656            if not isinstance(kwargs_axes, collections.abc.Sequence):
1657                raise ValueError(
1658                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1659                )
1660
1661            if any(a not in axes_ids for a in kwargs_axes):
1662                raise ValueError(
1663                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1664                )
1665
1666        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1667            dtype = self.data.type
1668        else:
1669            dtype = self.data[0].type
1670
1671        # ensure `preprocessing` begins with `EnsureDtypeDescr`
1672        if not self.preprocessing or not isinstance(
1673            self.preprocessing[0], EnsureDtypeDescr
1674        ):
1675            self.preprocessing.insert(
1676                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1677            )
1678
1679        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1680        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1681            self.preprocessing.append(
1682                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1683            )
1684
1685        return self
id: TensorId

Input tensor id. No duplicates are allowed across all inputs and outputs.

optional: bool

indicates that this tensor may be None

preprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this input should be preprocessed.

notes:

  • If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
  • If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
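
A minimal construction sketch illustrating these notes (file name is a placeholder; io checks are skipped via ValidationContext, assuming it is exposed as bioimageio.spec.ValidationContext):

    from bioimageio.spec import ValidationContext
    from bioimageio.spec.model.v0_5 import (
        AxisId,
        BatchAxis,
        ChannelAxis,
        FileDescr,
        Identifier,
        InputTensorDescr,
        SpaceInputAxis,
        TensorId,
    )

    with ValidationContext(perform_io_checks=False):  # skip file download/hash checks
        ipt = InputTensorDescr(
            id=TensorId("raw"),
            axes=[
                BatchAxis(),
                ChannelAxis(channel_names=[Identifier("c0")]),
                SpaceInputAxis(id=AxisId("y"), size=256),
                SpaceInputAxis(id=AxisId("x"), size=256),
            ],
            test_tensor=FileDescr(source="test_input.npy"),
        )

    # even without explicit preprocessing an 'ensure_dtype' step is inserted,
    # casting to the described data type (float32 by default)
    print([p.id for p in ipt.preprocessing])  # ['ensure_dtype']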
def convert_axes( axes: str, *, shape: Union[Sequence[int], bioimageio.spec.model.v0_4.ParameterizedInputShape, bioimageio.spec.model.v0_4.ImplicitOutputShape], tensor_type: Literal['input', 'output'], halo: Optional[Sequence[int]], size_refs: Mapping[bioimageio.spec.model.v0_4.TensorName, Mapping[str, int]]):
1688def convert_axes(
1689    axes: str,
1690    *,
1691    shape: Union[
1692        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1693    ],
1694    tensor_type: Literal["input", "output"],
1695    halo: Optional[Sequence[int]],
1696    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1697):
1698    ret: List[AnyAxis] = []
1699    for i, a in enumerate(axes):
1700        axis_type = _AXIS_TYPE_MAP.get(a, a)
1701        if axis_type == "batch":
1702            ret.append(BatchAxis())
1703            continue
1704
1705        scale = 1.0
1706        if isinstance(shape, _ParameterizedInputShape_v0_4):
1707            if shape.step[i] == 0:
1708                size = shape.min[i]
1709            else:
1710                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1711        elif isinstance(shape, _ImplicitOutputShape_v0_4):
1712            ref_t = str(shape.reference_tensor)
1713            if ref_t.count(".") == 1:
1714                t_id, orig_a_id = ref_t.split(".")
1715            else:
1716                t_id = ref_t
1717                orig_a_id = a
1718
1719            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1720            if not (orig_scale := shape.scale[i]):
1721                # old way to insert a new axis dimension
1722                size = int(2 * shape.offset[i])
1723            else:
1724                scale = 1 / orig_scale
1725                if axis_type in ("channel", "index"):
1726                    # these axes no longer have a scale
1727                    offset_from_scale = orig_scale * size_refs.get(
1728                        _TensorName_v0_4(t_id), {}
1729                    ).get(orig_a_id, 0)
1730                else:
1731                    offset_from_scale = 0
1732                size = SizeReference(
1733                    tensor_id=TensorId(t_id),
1734                    axis_id=AxisId(a_id),
1735                    offset=int(offset_from_scale + 2 * shape.offset[i]),
1736                )
1737        else:
1738            size = shape[i]
1739
1740        if axis_type == "time":
1741            if tensor_type == "input":
1742                ret.append(TimeInputAxis(size=size, scale=scale))
1743            else:
1744                assert not isinstance(size, ParameterizedSize)
1745                if halo is None:
1746                    ret.append(TimeOutputAxis(size=size, scale=scale))
1747                else:
1748                    assert not isinstance(size, int)
1749                    ret.append(
1750                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1751                    )
1752
1753        elif axis_type == "index":
1754            if tensor_type == "input":
1755                ret.append(IndexInputAxis(size=size))
1756            else:
1757                if isinstance(size, ParameterizedSize):
1758                    size = DataDependentSize(min=size.min)
1759
1760                ret.append(IndexOutputAxis(size=size))
1761        elif axis_type == "channel":
1762            assert not isinstance(size, ParameterizedSize)
1763            if isinstance(size, SizeReference):
1764                warnings.warn(
1765                    "Conversion of channel size from an implicit output shape may be"
1766                    + " wrong"
1767                )
1768                ret.append(
1769                    ChannelAxis(
1770                        channel_names=[
1771                            Identifier(f"channel{i}") for i in range(size.offset)
1772                        ]
1773                    )
1774                )
1775            else:
1776                ret.append(
1777                    ChannelAxis(
1778                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
1779                    )
1780                )
1781        elif axis_type == "space":
1782            if tensor_type == "input":
1783                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1784            else:
1785                assert not isinstance(size, ParameterizedSize)
1786                if halo is None or halo[i] == 0:
1787                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1788                elif isinstance(size, int):
1789                    raise NotImplementedError(
1790                        f"output axis with halo and fixed size (here {size}) not allowed"
1791                    )
1792                else:
1793                    ret.append(
1794                        SpaceOutputAxisWithHalo(
1795                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
1796                        )
1797                    )
1798
1799    return ret
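
A brief usage sketch for this conversion helper (illustrative values: a fixed input shape, no halo, no size references):

    from bioimageio.spec.model.v0_5 import convert_axes

    # convert a bioimage.io 0.4 axes string plus an explicit shape into 0.5 axis objects
    axes = convert_axes(
        "bcyx",
        shape=[1, 3, 512, 512],
        tensor_type="input",
        halo=None,
        size_refs={},
    )
    print([type(a).__name__ for a in axes])
    # ['BatchAxis', 'ChannelAxis', 'SpaceInputAxis', 'SpaceInputAxis']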
class OutputTensorDescr(bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
1959class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1960    id: TensorId = TensorId("output")
1961    """Output tensor id.
1962    No duplicates are allowed across all inputs and outputs."""
1963
1964    postprocessing: List[PostprocessingDescr] = Field(
1965        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
1966    )
1967    """Description of how this output should be postprocessed.
1968
1969    note: `postprocessing` always ends with an 'ensure_dtype' operation.
1970          If not given this is added to cast to this tensor's `data.type`.
1971    """
1972
1973    @model_validator(mode="after")
1974    def _validate_postprocessing_kwargs(self) -> Self:
1975        axes_ids = [a.id for a in self.axes]
1976        for p in self.postprocessing:
1977            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1978            if kwargs_axes is None:
1979                continue
1980
1981            if not isinstance(kwargs_axes, collections.abc.Sequence):
1982                raise ValueError(
1983                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
1984                )
1985
1986            if any(a not in axes_ids for a in kwargs_axes):
1987                raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1988
1989        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1990            dtype = self.data.type
1991        else:
1992            dtype = self.data[0].type
1993
1994        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1995        if not self.postprocessing or not isinstance(
1996            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1997        ):
1998            self.postprocessing.append(
1999                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2000            )
2001        return self
id: TensorId

Output tensor id. No duplicates are allowed across all inputs and outputs.

postprocessing: List[Annotated[Union[BinarizeDescr, ClipDescr, EnsureDtypeDescr, ScaleLinearDescr, SigmoidDescr, FixedZeroMeanUnitVarianceDescr, ZeroMeanUnitVarianceDescr, ScaleRangeDescr, ScaleMeanVarianceDescr], Discriminator(discriminator='id', custom_error_type=None, custom_error_message=None, custom_error_context=None)]]

Description of how this output should be postprocessed.

note: postprocessing always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's data.type.

TensorDescr = typing.Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors( tensors: Mapping[TensorId, Tuple[Union[InputTensorDescr, OutputTensorDescr], numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], tensor_origin: Literal['test_tensor']):
2051def validate_tensors(
2052    tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2053    tensor_origin: Literal[
2054        "test_tensor"
2055    ],  # for more precise error messages, e.g. 'test_tensor'
2056):
2057    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2058
2059    def e_msg(d: TensorDescr):
2060        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2061
2062    for descr, array in tensors.values():
2063        try:
2064            axis_sizes = descr.get_axis_sizes_for_array(array)
2065        except ValueError as e:
2066            raise ValueError(f"{e_msg(descr)} {e}")
2067        else:
2068            all_tensor_axes[descr.id] = {
2069                a.id: (a, axis_sizes[a.id]) for a in descr.axes
2070            }
2071
2072    for descr, array in tensors.values():
2073        if descr.dtype in ("float32", "float64"):
2074            invalid_test_tensor_dtype = array.dtype.name not in (
2075                "float32",
2076                "float64",
2077                "uint8",
2078                "int8",
2079                "uint16",
2080                "int16",
2081                "uint32",
2082                "int32",
2083                "uint64",
2084                "int64",
2085            )
2086        else:
2087            invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2088
2089        if invalid_test_tensor_dtype:
2090            raise ValueError(
2091                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2092                + f" match described dtype '{descr.dtype}'"
2093            )
2094
2095        if array.min() > -1e-4 and array.max() < 1e-4:
2096            raise ValueError(
2097                "Output values are too small for reliable testing."
2098                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
2099            )
2100
2101        for a in descr.axes:
2102            actual_size = all_tensor_axes[descr.id][a.id][1]
2103            if a.size is None:
2104                continue
2105
2106            if isinstance(a.size, int):
2107                if actual_size != a.size:
2108                    raise ValueError(
2109                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2110                        + f"has incompatible size {actual_size}, expected {a.size}"
2111                    )
2112            elif isinstance(a.size, ParameterizedSize):
2113                _ = a.size.validate_size(actual_size)
2114            elif isinstance(a.size, DataDependentSize):
2115                _ = a.size.validate_size(actual_size)
2116            elif isinstance(a.size, SizeReference):
2117                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2118                if ref_tensor_axes is None:
2119                    raise ValueError(
2120                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2121                        + f" reference '{a.size.tensor_id}'"
2122                    )
2123
2124                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2125                if ref_axis is None or ref_size is None:
2126                    raise ValueError(
2127                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2128                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2129                    )
2130
2131                if a.unit != ref_axis.unit:
2132                    raise ValueError(
2133                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2134                        + " axis and reference axis to have the same `unit`, but"
2135                        + f" {a.unit}!={ref_axis.unit}"
2136                    )
2137
2138                if actual_size != (
2139                    expected_size := (
2140                        ref_size * ref_axis.scale / a.scale + a.size.offset
2141                    )
2142                ):
2143                    raise ValueError(
2144                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2145                        + f" {actual_size} invalid for referenced size {ref_size};"
2146                        + f" expected {expected_size}"
2147                    )
2148            else:
2149                assert_never(a.size)
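
Continuing the hypothetical `ipt` sketch from InputTensorDescr above, a call could look like this (the random array stands in for the stored test tensor and matches the declared axes):

    import numpy as np

    from bioimageio.spec.model.v0_5 import validate_tensors

    # `ipt` has axes (batch, channel with 1 channel, y=256, x=256)
    array = np.random.rand(1, 1, 256, 256).astype("float32")
    validate_tensors({ipt.id: (ipt, array)}, tensor_origin="test_tensor")  # raises ValueError on mismatch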
FileDescr_dependencies = typing.Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name>), WrapSerializer(func=<function package_file_descr_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, bioimageio.spec._internal.io.FileDescr):
2169class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2170    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2171    """Architecture source file"""
2172
2173    @model_serializer(mode="wrap", when_used="unless-none")
2174    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2175        return package_file_descr_serializer(self, nxt, info)

A file description

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>)]

Architecture source file

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2178class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2179    import_from: str
2180    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
import_from: str

Where to import the callable from, i.e. from <import_from> import <callable>

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
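
A hedged sketch of such an architecture description (it assumes the `callable` and `kwargs` fields inherited from _ArchitectureCallableDescr; import path and arguments are placeholders):

    from bioimageio.spec.model.v0_5 import ArchitectureFromLibraryDescr, Identifier

    # describes `from my_package.models import UNet`, called with the given kwargs
    arch = ArchitectureFromLibraryDescr(
        callable=Identifier("UNet"),
        import_from="my_package.models",
        kwargs={"in_channels": 1, "out_channels": 1},
    )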

class WeightsEntryDescrBase(bioimageio.spec._internal.io.FileDescr):
2240class WeightsEntryDescrBase(FileDescr):
2241    type: ClassVar[WeightsFormat]
2242    weights_format_name: ClassVar[str]  # human readable
2243
2244    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2245    """Source of the weights file."""
2246
2247    authors: Optional[List[Author]] = None
2248    """Authors
2249    Either the person(s) that have trained this model resulting in the original weights file.
2250        (If this is the initial weights entry, i.e. it does not have a `parent`)
2251    Or the person(s) who have converted the weights to this weights format.
2252        (If this is a child weight, i.e. it has a `parent` field)
2253    """
2254
2255    parent: Annotated[
2256        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2257    ] = None
2258    """The source weights these weights were converted from.
2259    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2260    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2261    All weight entries except one (the initial set of weights resulting from training the model)
2262    need to have this field."""
2263
2264    comment: str = ""
2265    """A comment about this weights entry, for example how these weights were created."""
2266
2267    @model_validator(mode="after")
2268    def _validate(self) -> Self:
2269        if self.type == self.parent:
2270            raise ValueError("Weights entry can't be its own parent.")
2271
2272        return self
2273
2274    @model_serializer(mode="wrap", when_used="unless-none")
2275    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2276        return package_file_descr_serializer(self, nxt, info)

A file description

type: ClassVar[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]
weights_format_name: ClassVar[str]
source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>)]

Source of the weights file.

authors: Optional[List[bioimageio.spec.generic.v0_3.Author]]

Authors: either the person(s) that trained this model, resulting in the original weights file (if this is the initial weights entry, i.e. it does not have a parent), or the person(s) who converted the weights to this weights format (if this is a child weight, i.e. it has a parent field).

parent: Annotated[Optional[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']], FieldInfo(annotation=NoneType, required=True, examples=['pytorch_state_dict'])]

The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.

comment: str

A comment about this weights entry, for example how these weights were created.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2279class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2280    type = "keras_hdf5"
2281    weights_format_name: ClassVar[str] = "Keras HDF5"
2282    tensorflow_version: Version
2283    """TensorFlow version used to create these weights."""

A file description

type = 'keras_hdf5'
weights_format_name: ClassVar[str] = 'Keras HDF5'

TensorFlow version used to create these weights.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class OnnxWeightsDescr(WeightsEntryDescrBase):
2286class OnnxWeightsDescr(WeightsEntryDescrBase):
2287    type = "onnx"
2288    weights_format_name: ClassVar[str] = "ONNX"
2289    opset_version: Annotated[int, Ge(7)]
2290    """ONNX opset version"""

A file description

type = 'onnx'
weights_format_name: ClassVar[str] = 'ONNX'
opset_version: Annotated[int, Ge(ge=7)]

ONNX opset version

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2293class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2294    type = "pytorch_state_dict"
2295    weights_format_name: ClassVar[str] = "Pytorch State Dict"
2296    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2297    pytorch_version: Version
2298    """Version of the PyTorch library used.
2299    If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.
2300    """
2301    dependencies: Optional[FileDescr_dependencies] = None
2302    """Custom dependencies beyond pytorch described in a Conda environment file.
2303    Allows specifying custom dependencies; see the conda docs:
2304    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2305    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2306
2307    The conda environment file should include pytorch and any version pinning has to be compatible with
2308    **pytorch_version**.
2309    """

A file description

type = 'pytorch_state_dict'
weights_format_name: ClassVar[str] = 'Pytorch State Dict'

Version of the PyTorch library used. If architecture.dependencies is specified, it has to include pytorch, and any version pinning has to be compatible.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>), WrapSerializer(func=<function package_file_descr_serializer at 0x7febd4daeca0>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond pytorch described in a Conda environment file. Allows specifying custom dependencies; see the conda docs:

  • Exporting an environment file across platforms: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms
  • Creating an environment file manually: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually

The conda environment file should include pytorch, and any version pinning has to be compatible with pytorch_version.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2312class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2313    type = "tensorflow_js"
2314    weights_format_name: ClassVar[str] = "Tensorflow.js"
2315    tensorflow_version: Version
2316    """Version of the TensorFlow library used."""
2317
2318    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2319    """The multi-file weights.
2320    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_js'
weights_format_name: ClassVar[str] = 'Tensorflow.js'

Version of the TensorFlow library used.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>)]

The multi-file weights. All required files/folders should be packaged in a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2323class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2324    type = "tensorflow_saved_model_bundle"
2325    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2326    tensorflow_version: Version
2327    """Version of the TensorFlow library used."""
2328
2329    dependencies: Optional[FileDescr_dependencies] = None
2330    """Custom dependencies beyond tensorflow.
2331    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2332
2333    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2334    """The multi-file weights.
2335    All required files/folders should be a zip archive."""

A file description

type = 'tensorflow_saved_model_bundle'
weights_format_name: ClassVar[str] = 'Tensorflow Saved Model'

Version of the TensorFlow library used.

dependencies: Optional[Annotated[bioimageio.spec._internal.io.FileDescr, AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>), WrapSerializer(func=<function package_file_descr_serializer at 0x7febd4daeca0>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix=('.yaml', '.yml'), case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=[{'source': 'environment.yaml'}])]]

Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.

source: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name at 0x7febd4d13c40>)]

The multi-file weights. All required files/folders should be packaged in a zip archive.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2338class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2339    type = "torchscript"
2340    weights_format_name: ClassVar[str] = "TorchScript"
2341    pytorch_version: Version
2342    """Version of the PyTorch library used."""

A file description

type = 'torchscript'
weights_format_name: ClassVar[str] = 'TorchScript'

Version of the PyTorch library used.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WeightsDescr(bioimageio.spec._internal.node.Node):
2345class WeightsDescr(Node):
2346    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2347    onnx: Optional[OnnxWeightsDescr] = None
2348    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2349    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2350    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2351        None
2352    )
2353    torchscript: Optional[TorchscriptWeightsDescr] = None
2354
2355    @model_validator(mode="after")
2356    def check_entries(self) -> Self:
2357        entries = {wtype for wtype, entry in self if entry is not None}
2358
2359        if not entries:
2360            raise ValueError("Missing weights entry")
2361
2362        entries_wo_parent = {
2363            wtype
2364            for wtype, entry in self
2365            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2366        }
2367        if len(entries_wo_parent) != 1:
2368            issue_warning(
2369                "Exactly one weights entry may not specify the `parent` field (got"
2370                + " {value}). That entry is considered the original set of model weights."
2371                + " Other weight formats are created through conversion of the original or"
2372                + " already converted weights. They have to reference the weights format"
2373                + " they were converted from as their `parent`.",
2374                value=len(entries_wo_parent),
2375                field="weights",
2376            )
2377
2378        for wtype, entry in self:
2379            if entry is None:
2380                continue
2381
2382            assert hasattr(entry, "type")
2383            assert hasattr(entry, "parent")
2384            assert wtype == entry.type
2385            if (
2386                entry.parent is not None and entry.parent not in entries
2387            ):  # self reference checked for `parent` field
2388                raise ValueError(
2389                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2390                    + f" formats: {entries}"
2391                )
2392
2393        return self
2394
2395    def __getitem__(
2396        self,
2397        key: Literal[
2398            "keras_hdf5",
2399            "onnx",
2400            "pytorch_state_dict",
2401            "tensorflow_js",
2402            "tensorflow_saved_model_bundle",
2403            "torchscript",
2404        ],
2405    ):
2406        if key == "keras_hdf5":
2407            ret = self.keras_hdf5
2408        elif key == "onnx":
2409            ret = self.onnx
2410        elif key == "pytorch_state_dict":
2411            ret = self.pytorch_state_dict
2412        elif key == "tensorflow_js":
2413            ret = self.tensorflow_js
2414        elif key == "tensorflow_saved_model_bundle":
2415            ret = self.tensorflow_saved_model_bundle
2416        elif key == "torchscript":
2417            ret = self.torchscript
2418        else:
2419            raise KeyError(key)
2420
2421        if ret is None:
2422            raise KeyError(key)
2423
2424        return ret
2425
2426    @property
2427    def available_formats(self):
2428        return {
2429            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2430            **({} if self.onnx is None else {"onnx": self.onnx}),
2431            **(
2432                {}
2433                if self.pytorch_state_dict is None
2434                else {"pytorch_state_dict": self.pytorch_state_dict}
2435            ),
2436            **(
2437                {}
2438                if self.tensorflow_js is None
2439                else {"tensorflow_js": self.tensorflow_js}
2440            ),
2441            **(
2442                {}
2443                if self.tensorflow_saved_model_bundle is None
2444                else {
2445                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2446                }
2447            ),
2448            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2449        }
2450
2451    @property
2452    def missing_formats(self):
2453        return {
2454            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2455        }
keras_hdf5: Optional[KerasHdf5WeightsDescr]
onnx: Optional[OnnxWeightsDescr]
pytorch_state_dict: Optional[PytorchStateDictWeightsDescr]
tensorflow_js: Optional[TensorflowJsWeightsDescr]
tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr]
torchscript: Optional[TorchscriptWeightsDescr]
@model_validator(mode='after')
def check_entries(self) -> Self:
2355    @model_validator(mode="after")
2356    def check_entries(self) -> Self:
2357        entries = {wtype for wtype, entry in self if entry is not None}
2358
2359        if not entries:
2360            raise ValueError("Missing weights entry")
2361
2362        entries_wo_parent = {
2363            wtype
2364            for wtype, entry in self
2365            if entry is not None and hasattr(entry, "parent") and entry.parent is None
2366        }
2367        if len(entries_wo_parent) != 1:
2368            issue_warning(
2369                "Exactly one weights entry may not specify the `parent` field (got"
2370                + " {value}). That entry is considered the original set of model weights."
2371                + " Other weight formats are created through conversion of the original or"
2372                + " already converted weights. They have to reference the weights format"
2373                + " they were converted from as their `parent`.",
2374                value=len(entries_wo_parent),
2375                field="weights",
2376            )
2377
2378        for wtype, entry in self:
2379            if entry is None:
2380                continue
2381
2382            assert hasattr(entry, "type")
2383            assert hasattr(entry, "parent")
2384            assert wtype == entry.type
2385            if (
2386                entry.parent is not None and entry.parent not in entries
2387            ):  # self reference checked for `parent` field
2388                raise ValueError(
2389                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2390                    + f" formats: {entries}"
2391                )
2392
2393        return self
available_formats
2426    @property
2427    def available_formats(self):
2428        return {
2429            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2430            **({} if self.onnx is None else {"onnx": self.onnx}),
2431            **(
2432                {}
2433                if self.pytorch_state_dict is None
2434                else {"pytorch_state_dict": self.pytorch_state_dict}
2435            ),
2436            **(
2437                {}
2438                if self.tensorflow_js is None
2439                else {"tensorflow_js": self.tensorflow_js}
2440            ),
2441            **(
2442                {}
2443                if self.tensorflow_saved_model_bundle is None
2444                else {
2445                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2446                }
2447            ),
2448            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2449        }
missing_formats
2451    @property
2452    def missing_formats(self):
2453        return {
2454            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2455        }
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
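
A minimal sketch constructing a weights description with a single, original (hence parent-less) entry; file name and version are placeholders and io checks are skipped:

    from bioimageio.spec import ValidationContext
    from bioimageio.spec.model.v0_5 import TorchscriptWeightsDescr, Version, WeightsDescr

    with ValidationContext(perform_io_checks=False):
        weights = WeightsDescr(
            torchscript=TorchscriptWeightsDescr(
                source="weights.pt",
                pytorch_version=Version("2.1.2"),
            )
        )

    print(sorted(weights.available_formats))  # ['torchscript']
    print(sorted(weights.missing_formats))    # the other five weights formats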

class ModelId(bioimageio.spec.generic.v0_3.ResourceId):
2458class ModelId(ResourceId):
2459    pass


class LinkedModel(bioimageio.spec.generic.v0_3.LinkedResourceBase):
2462class LinkedModel(LinkedResourceBase):
2463    """Reference to a bioimage.io model."""
2464
2465    id: ModelId
2466    """A valid model `id` from the bioimage.io collection."""

Reference to a bioimage.io model.

id: ModelId

A valid model id from the bioimage.io collection.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ReproducibilityTolerance(bioimageio.spec._internal.node.Node):
2488class ReproducibilityTolerance(Node, extra="allow"):
2489    """Describes what small numerical differences -- if any -- may be tolerated
2490    in the generated output when executing in different environments.
2491
2492    A tensor element *output* is considered mismatched to the **test_tensor** if
2493    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2494    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2495
2496    Motivation:
2497        For testing we can request the respective deep learning frameworks to be as
2498        reproducible as possible by setting seeds and choosing deterministic algorithms,
2499        but differences in operating systems, available hardware and installed drivers
2500        may still lead to numerical differences.
2501    """
2502
2503    relative_tolerance: RelativeTolerance = 1e-3
2504    """Maximum relative tolerance of reproduced test tensor."""
2505
2506    absolute_tolerance: AbsoluteTolerance = 1e-4
2507    """Maximum absolute tolerance of reproduced test tensor."""
2508
2509    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2510    """Maximum number of mismatched elements/pixels per million to tolerate."""
2511
2512    output_ids: Sequence[TensorId] = ()
2513    """Limits the output tensor IDs these reproducibility details apply to."""
2514
2515    weights_formats: Sequence[WeightsFormat] = ()
2516    """Limits the weights formats these details apply to."""

Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.

A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)

Motivation:

For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
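
A small NumPy sketch of this mismatch criterion (example values only):

    import numpy as np

    test_tensor = np.array([1.0, 2.0, 3.0], dtype="float32")
    output = np.array([1.0005, 2.001, 3.0], dtype="float32")
    atol, rtol = 1e-4, 1e-3  # absolute_tolerance, relative_tolerance

    mismatched = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
    print(int(mismatched.sum()), "of", mismatched.size, "elements mismatched")  # 0 of 3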

relative_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=0.01)]

Maximum relative tolerance of reproduced test tensor.

absolute_tolerance: Annotated[float, Interval(gt=None, ge=0, lt=None, le=None)]

Maximum absolute tolerance of reproduced test tensor.

mismatched_elements_per_million: Annotated[int, Interval(gt=None, ge=0, lt=None, le=1000)]

Maximum number of mismatched elements/pixels per million to tolerate.

output_ids: Sequence[TensorId]

Limits the output tensor IDs these reproducibility details apply to.

weights_formats: Sequence[Literal['keras_hdf5', 'onnx', 'pytorch_state_dict', 'tensorflow_js', 'tensorflow_saved_model_bundle', 'torchscript']]

Limits the weights formats these details apply to.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'allow', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class BioimageioConfig(bioimageio.spec._internal.node.Node):
2519class BioimageioConfig(Node, extra="allow"):
2520    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2521    """Tolerances to allow when reproducing the model's test outputs
2522    from the model's test inputs.
2523    Only the first entry matching tensor id and weights format is considered.
2524    """
reproducibility_tolerance: Sequence[ReproducibilityTolerance]

Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.

model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'allow', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Config(bioimageio.spec._internal.node.Node):
2527class Config(Node, extra="allow"):
2528    bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
bioimageio: BioimageioConfig
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'allow', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
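
A brief sketch of how a custom tolerance could be recorded in a model's config (values are placeholders):

    from bioimageio.spec.model.v0_5 import (
        BioimageioConfig,
        Config,
        ReproducibilityTolerance,
        TensorId,
    )

    config = Config(
        bioimageio=BioimageioConfig(
            reproducibility_tolerance=[
                ReproducibilityTolerance(
                    relative_tolerance=1e-2,
                    absolute_tolerance=1e-3,
                    mismatched_elements_per_million=200,
                    output_ids=[TensorId("output")],
                    weights_formats=["onnx"],
                )
            ]
        )
    )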

2531class ModelDescr(GenericModelDescrBase):
2532    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2533    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2534    """
2535
2536    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2537    if TYPE_CHECKING:
2538        format_version: Literal["0.5.4"] = "0.5.4"
2539    else:
2540        format_version: Literal["0.5.4"]
2541        """Version of the bioimage.io model description specification used.
2542        When creating a new model always use the latest micro/patch version described here.
2543        The `format_version` is important for any consumer software to understand how to parse the fields.
2544        """
2545
2546    implemented_type: ClassVar[Literal["model"]] = "model"
2547    if TYPE_CHECKING:
2548        type: Literal["model"] = "model"
2549    else:
2550        type: Literal["model"]
2551        """Specialized resource type 'model'"""
2552
2553    id: Optional[ModelId] = None
2554    """bioimage.io-wide unique resource identifier
2555    assigned by bioimage.io; version **un**specific."""
2556
2557    authors: NotEmpty[List[Author]]
2558    """The authors are the creators of the model RDF and the primary points of contact."""
2559
2560    documentation: FileSource_documentation
2561    """URL or relative path to a markdown file with additional documentation.
2562    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2563    The documentation should include a '#[#] Validation' (sub)section
2564    with details on how to quantitatively validate the model on unseen data."""
2565
2566    @field_validator("documentation", mode="after")
2567    @classmethod
2568    def _validate_documentation(
2569        cls, value: FileSource_documentation
2570    ) -> FileSource_documentation:
2571        if not get_validation_context().perform_io_checks:
2572            return value
2573
2574        doc_reader = get_reader(value)
2575        doc_content = doc_reader.read().decode(encoding="utf-8")
2576        if not re.search("#.*[vV]alidation", doc_content):
2577            issue_warning(
2578                "No '# Validation' (sub)section found in {value}.",
2579                value=value,
2580                field="documentation",
2581            )
2582
2583        return value
2584
2585    inputs: NotEmpty[Sequence[InputTensorDescr]]
2586    """Describes the input tensors expected by this model."""
2587
2588    @field_validator("inputs", mode="after")
2589    @classmethod
2590    def _validate_input_axes(
2591        cls, inputs: Sequence[InputTensorDescr]
2592    ) -> Sequence[InputTensorDescr]:
2593        input_size_refs = cls._get_axes_with_independent_size(inputs)
2594
2595        for i, ipt in enumerate(inputs):
2596            valid_independent_refs: Dict[
2597                Tuple[TensorId, AxisId],
2598                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2599            ] = {
2600                **{
2601                    (ipt.id, a.id): (ipt, a, a.size)
2602                    for a in ipt.axes
2603                    if not isinstance(a, BatchAxis)
2604                    and isinstance(a.size, (int, ParameterizedSize))
2605                },
2606                **input_size_refs,
2607            }
2608            for a, ax in enumerate(ipt.axes):
2609                cls._validate_axis(
2610                    "inputs",
2611                    i=i,
2612                    tensor_id=ipt.id,
2613                    a=a,
2614                    axis=ax,
2615                    valid_independent_refs=valid_independent_refs,
2616                )
2617        return inputs
2618
2619    @staticmethod
2620    def _validate_axis(
2621        field_name: str,
2622        i: int,
2623        tensor_id: TensorId,
2624        a: int,
2625        axis: AnyAxis,
2626        valid_independent_refs: Dict[
2627            Tuple[TensorId, AxisId],
2628            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2629        ],
2630    ):
2631        if isinstance(axis, BatchAxis) or isinstance(
2632            axis.size, (int, ParameterizedSize, DataDependentSize)
2633        ):
2634            return
2635        elif not isinstance(axis.size, SizeReference):
2636            assert_never(axis.size)
2637
2638        # validate axis.size SizeReference
2639        ref = (axis.size.tensor_id, axis.size.axis_id)
2640        if ref not in valid_independent_refs:
2641            raise ValueError(
2642                "Invalid tensor axis reference at"
2643                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2644            )
2645        if ref == (tensor_id, axis.id):
2646            raise ValueError(
2647                "Self-referencing not allowed for"
2648                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2649            )
2650        if axis.type == "channel":
2651            if valid_independent_refs[ref][1].type != "channel":
2652                raise ValueError(
2653                    "A channel axis' size may only reference another fixed size"
2654                    + " channel axis."
2655                )
2656            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2657                ref_size = valid_independent_refs[ref][2]
2658                assert isinstance(ref_size, int), (
2659                    "channel axis ref (another channel axis) has to specify fixed"
2660                    + " size"
2661                )
2662                generated_channel_names = [
2663                    Identifier(axis.channel_names.format(i=i))
2664                    for i in range(1, ref_size + 1)
2665                ]
2666                axis.channel_names = generated_channel_names
2667
2668        if (ax_unit := getattr(axis, "unit", None)) != (
2669            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2670        ):
2671            raise ValueError(
2672                "The units of an axis and its reference axis need to match, but"
2673                + f" '{ax_unit}' != '{ref_unit}'."
2674            )
2675        ref_axis = valid_independent_refs[ref][1]
2676        if isinstance(ref_axis, BatchAxis):
2677            raise ValueError(
2678                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2679                + " (a batch axis is not allowed as reference)."
2680            )
2681
2682        if isinstance(axis, WithHalo):
2683            min_size = axis.size.get_size(axis, ref_axis, n=0)
2684            if (min_size - 2 * axis.halo) < 1:
2685                raise ValueError(
2686                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2687                    + f" {axis.halo}."
2688                )
2689
2690            input_halo = axis.halo * axis.scale / ref_axis.scale
2691            if input_halo != int(input_halo) or input_halo % 2 == 1:
2692                raise ValueError(
2693                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2694                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2695                    + f"     {tensor_id}.{axis.id}."
2696                )
2697
2698    @model_validator(mode="after")
2699    def _validate_test_tensors(self) -> Self:
2700        if not get_validation_context().perform_io_checks:
2701            return self
2702
2703        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2704        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2705
2706        tensors = {
2707            descr.id: (descr, array)
2708            for descr, array in zip(
2709                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2710            )
2711        }
2712        validate_tensors(tensors, tensor_origin="test_tensor")
2713
2714        output_arrays = {
2715            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2716        }
2717        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2718            if not rep_tol.absolute_tolerance:
2719                continue
2720
2721            if rep_tol.output_ids:
2722                out_arrays = {
2723                    oid: a
2724                    for oid, a in output_arrays.items()
2725                    if oid in rep_tol.output_ids
2726                }
2727            else:
2728                out_arrays = output_arrays
2729
2730            for out_id, array in out_arrays.items():
2731                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2732                    raise ValueError(
2733                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2734                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2735                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2736                    )
2737
2738        return self
2739
2740    @model_validator(mode="after")
2741    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2742        ipt_refs = {t.id for t in self.inputs}
2743        out_refs = {t.id for t in self.outputs}
2744        for ipt in self.inputs:
2745            for p in ipt.preprocessing:
2746                ref = p.kwargs.get("reference_tensor")
2747                if ref is None:
2748                    continue
2749                if ref not in ipt_refs:
2750                    raise ValueError(
2751                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2752                        + f" references are: {ipt_refs}."
2753                    )
2754
2755        for out in self.outputs:
2756            for p in out.postprocessing:
2757                ref = p.kwargs.get("reference_tensor")
2758                if ref is None:
2759                    continue
2760
2761                if ref not in ipt_refs and ref not in out_refs:
2762                    raise ValueError(
2763                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2764                        + f" are: {ipt_refs | out_refs}."
2765                    )
2766
2767        return self
2768
2769    # TODO: use validate funcs in validate_test_tensors
2770    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2771
2772    name: Annotated[
2773        Annotated[
2774            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2775        ],
2776        MinLen(5),
2777        MaxLen(128),
2778        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2779    ]
2780    """A human-readable name of this model.
2781    It should be no longer than 64 characters
2782    and may only contain letter, number, underscore, minus, parentheses and spaces.
2783    We recommend to chose a name that refers to the model's task and image modality.
2784    """
2785
2786    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2787    """Describes the output tensors."""
2788
2789    @field_validator("outputs", mode="after")
2790    @classmethod
2791    def _validate_tensor_ids(
2792        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2793    ) -> Sequence[OutputTensorDescr]:
2794        tensor_ids = [
2795            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2796        ]
2797        duplicate_tensor_ids: List[str] = []
2798        seen: Set[str] = set()
2799        for t in tensor_ids:
2800            if t in seen:
2801                duplicate_tensor_ids.append(t)
2802
2803            seen.add(t)
2804
2805        if duplicate_tensor_ids:
2806            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2807
2808        return outputs
2809
2810    @staticmethod
2811    def _get_axes_with_parameterized_size(
2812        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2813    ):
2814        return {
2815            f"{t.id}.{a.id}": (t, a, a.size)
2816            for t in io
2817            for a in t.axes
2818            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2819        }
2820
2821    @staticmethod
2822    def _get_axes_with_independent_size(
2823        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2824    ):
2825        return {
2826            (t.id, a.id): (t, a, a.size)
2827            for t in io
2828            for a in t.axes
2829            if not isinstance(a, BatchAxis)
2830            and isinstance(a.size, (int, ParameterizedSize))
2831        }
2832
2833    @field_validator("outputs", mode="after")
2834    @classmethod
2835    def _validate_output_axes(
2836        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2837    ) -> List[OutputTensorDescr]:
2838        input_size_refs = cls._get_axes_with_independent_size(
2839            info.data.get("inputs", [])
2840        )
2841        output_size_refs = cls._get_axes_with_independent_size(outputs)
2842
2843        for i, out in enumerate(outputs):
2844            valid_independent_refs: Dict[
2845                Tuple[TensorId, AxisId],
2846                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2847            ] = {
2848                **{
2849                    (out.id, a.id): (out, a, a.size)
2850                    for a in out.axes
2851                    if not isinstance(a, BatchAxis)
2852                    and isinstance(a.size, (int, ParameterizedSize))
2853                },
2854                **input_size_refs,
2855                **output_size_refs,
2856            }
2857            for a, ax in enumerate(out.axes):
2858                cls._validate_axis(
2859                    "outputs",
2860                    i,
2861                    out.id,
2862                    a,
2863                    ax,
2864                    valid_independent_refs=valid_independent_refs,
2865                )
2866
2867        return outputs
2868
2869    packaged_by: List[Author] = Field(
2870        default_factory=cast(Callable[[], List[Author]], list)
2871    )
2872    """The persons that have packaged and uploaded this model.
2873    Only required if those persons differ from the `authors`."""
2874
2875    parent: Optional[LinkedModel] = None
2876    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2877
2878    @model_validator(mode="after")
2879    def _validate_parent_is_not_self(self) -> Self:
2880        if self.parent is not None and self.parent.id == self.id:
2881            raise ValueError("A model description may not reference itself as parent.")
2882
2883        return self
2884
2885    run_mode: Annotated[
2886        Optional[RunMode],
2887        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2888    ] = None
2889    """Custom run mode for this model: for more complex prediction procedures like test time
2890    data augmentation that currently cannot be expressed in the specification.
2891    No standard run modes are defined yet."""
2892
2893    timestamp: Datetime = Field(default_factory=Datetime.now)
2894    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2895    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2896    (In Python a datetime object is valid, too)."""
2897
2898    training_data: Annotated[
2899        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2900        Field(union_mode="left_to_right"),
2901    ] = None
2902    """The dataset used to train this model"""
2903
2904    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2905    """The weights for this model.
2906    Weights can be given for different formats, but should otherwise be equivalent.
2907    The available weight formats determine which consumers can use this model."""
2908
2909    config: Config = Field(default_factory=Config)
2910
2911    @model_validator(mode="after")
2912    def _add_default_cover(self) -> Self:
2913        if not get_validation_context().perform_io_checks or self.covers:
2914            return self
2915
2916        try:
2917            generated_covers = generate_covers(
2918                [(t, load_array(t.test_tensor)) for t in self.inputs],
2919                [(t, load_array(t.test_tensor)) for t in self.outputs],
2920            )
2921        except Exception as e:
2922            issue_warning(
2923                "Failed to generate cover image(s): {e}",
2924                value=self.covers,
2925                msg_context=dict(e=e),
2926                field="covers",
2927            )
2928        else:
2929            self.covers.extend(generated_covers)
2930
2931        return self
2932
2933    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2934        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2935        assert all(isinstance(d, np.ndarray) for d in data)
2936        return data
2937
2938    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2939        data = [load_array(out.test_tensor) for out in self.outputs]
2940        assert all(isinstance(d, np.ndarray) for d in data)
2941        return data
2942
2943    @staticmethod
2944    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2945        batch_size = 1
2946        tensor_with_batchsize: Optional[TensorId] = None
2947        for tid in tensor_sizes:
2948            for aid, s in tensor_sizes[tid].items():
2949                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2950                    continue
2951
2952                if batch_size != 1:
2953                    assert tensor_with_batchsize is not None
2954                    raise ValueError(
2955                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2956                    )
2957
2958                batch_size = s
2959                tensor_with_batchsize = tid
2960
2961        return batch_size
2962
2963    def get_output_tensor_sizes(
2964        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2965    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2966        """Returns the tensor output sizes for given **input_sizes**.
2967        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2968        Otherwise it might be larger than the actual (valid) output"""
2969        batch_size = self.get_batch_size(input_sizes)
2970        ns = self.get_ns(input_sizes)
2971
2972        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2973        return tensor_sizes.outputs
2974
2975    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2976        """get parameter `n` for each parameterized axis
2977        such that the valid input size is >= the given input size"""
2978        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2979        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2980        for tid in input_sizes:
2981            for aid, s in input_sizes[tid].items():
2982                size_descr = axes[tid][aid].size
2983                if isinstance(size_descr, ParameterizedSize):
2984                    ret[(tid, aid)] = size_descr.get_n(s)
2985                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2986                    pass
2987                else:
2988                    assert_never(size_descr)
2989
2990        return ret
2991
2992    def get_tensor_sizes(
2993        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2994    ) -> _TensorSizes:
2995        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2996        return _TensorSizes(
2997            {
2998                t: {
2999                    aa: axis_sizes.inputs[(tt, aa)]
3000                    for tt, aa in axis_sizes.inputs
3001                    if tt == t
3002                }
3003                for t in {tt for tt, _ in axis_sizes.inputs}
3004            },
3005            {
3006                t: {
3007                    aa: axis_sizes.outputs[(tt, aa)]
3008                    for tt, aa in axis_sizes.outputs
3009                    if tt == t
3010                }
3011                for t in {tt for tt, _ in axis_sizes.outputs}
3012            },
3013        )
3014
3015    def get_axis_sizes(
3016        self,
3017        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3018        batch_size: Optional[int] = None,
3019        *,
3020        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3021    ) -> _AxisSizes:
3022        """Determine input and output block shape for scale factors **ns**
3023        of parameterized input sizes.
3024
3025        Args:
3026            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3027                that is parameterized as `size = min + n * step`.
3028            batch_size: The desired size of the batch dimension.
3029                If given **batch_size** overwrites any batch size present in
3030                **max_input_shape**. Default 1.
3031            max_input_shape: Limits the derived block shapes.
3032                Each axis for which the input size, parameterized by `n`, is larger
3033                than **max_input_shape** is set to the minimal value `n_min` for which
3034                this is still true.
3035                Use this for small input samples or large values of **ns**.
3036                Or simply whenever you know the full input shape.
3037
3038        Returns:
3039            Resolved axis sizes for model inputs and outputs.
3040        """
3041        max_input_shape = max_input_shape or {}
3042        if batch_size is None:
3043            for (_t_id, a_id), s in max_input_shape.items():
3044                if a_id == BATCH_AXIS_ID:
3045                    batch_size = s
3046                    break
3047            else:
3048                batch_size = 1
3049
3050        all_axes = {
3051            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3052        }
3053
3054        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3055        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3056
3057        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3058            if isinstance(a, BatchAxis):
3059                if (t_descr.id, a.id) in ns:
3060                    logger.warning(
3061                        "Ignoring unexpected size increment factor (n) for batch axis"
3062                        + " of tensor '{}'.",
3063                        t_descr.id,
3064                    )
3065                return batch_size
3066            elif isinstance(a.size, int):
3067                if (t_descr.id, a.id) in ns:
3068                    logger.warning(
3069                        "Ignoring unexpected size increment factor (n) for fixed size"
3070                        + " axis '{}' of tensor '{}'.",
3071                        a.id,
3072                        t_descr.id,
3073                    )
3074                return a.size
3075            elif isinstance(a.size, ParameterizedSize):
3076                if (t_descr.id, a.id) not in ns:
3077                    raise ValueError(
3078                        "Size increment factor (n) missing for parametrized axis"
3079                        + f" '{a.id}' of tensor '{t_descr.id}'."
3080                    )
3081                n = ns[(t_descr.id, a.id)]
3082                s_max = max_input_shape.get((t_descr.id, a.id))
3083                if s_max is not None:
3084                    n = min(n, a.size.get_n(s_max))
3085
3086                return a.size.get_size(n)
3087
3088            elif isinstance(a.size, SizeReference):
3089                if (t_descr.id, a.id) in ns:
3090                    logger.warning(
3091                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3092                        + " of tensor '{}' with size reference.",
3093                        a.id,
3094                        t_descr.id,
3095                    )
3096                assert not isinstance(a, BatchAxis)
3097                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3098                assert not isinstance(ref_axis, BatchAxis)
3099                ref_key = (a.size.tensor_id, a.size.axis_id)
3100                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3101                assert ref_size is not None, ref_key
3102                assert not isinstance(ref_size, _DataDepSize), ref_key
3103                return a.size.get_size(
3104                    axis=a,
3105                    ref_axis=ref_axis,
3106                    ref_size=ref_size,
3107                )
3108            elif isinstance(a.size, DataDependentSize):
3109                if (t_descr.id, a.id) in ns:
3110                    logger.warning(
3111                        "Ignoring unexpected increment factor (n) for data dependent"
3112                        + " size axis '{}' of tensor '{}'.",
3113                        a.id,
3114                        t_descr.id,
3115                    )
3116                return _DataDepSize(a.size.min, a.size.max)
3117            else:
3118                assert_never(a.size)
3119
3120        # first resolve all but the `SizeReference` input sizes
3121        for t_descr in self.inputs:
3122            for a in t_descr.axes:
3123                if not isinstance(a.size, SizeReference):
3124                    s = get_axis_size(a)
3125                    assert not isinstance(s, _DataDepSize)
3126                    inputs[t_descr.id, a.id] = s
3127
3128        # resolve all other input axis sizes
3129        for t_descr in self.inputs:
3130            for a in t_descr.axes:
3131                if isinstance(a.size, SizeReference):
3132                    s = get_axis_size(a)
3133                    assert not isinstance(s, _DataDepSize)
3134                    inputs[t_descr.id, a.id] = s
3135
3136        # resolve all output axis sizes
3137        for t_descr in self.outputs:
3138            for a in t_descr.axes:
3139                assert not isinstance(a.size, ParameterizedSize)
3140                s = get_axis_size(a)
3141                outputs[t_descr.id, a.id] = s
3142
3143        return _AxisSizes(inputs=inputs, outputs=outputs)
3144
3145    @model_validator(mode="before")
3146    @classmethod
3147    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3148        cls.convert_from_old_format_wo_validation(data)
3149        return data
3150
3151    @classmethod
3152    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3153        """Convert metadata following an older format version to this classes' format
3154        without validating the result.
3155        """
3156        if (
3157            data.get("type") == "model"
3158            and isinstance(fv := data.get("format_version"), str)
3159            and fv.count(".") == 2
3160        ):
3161            fv_parts = fv.split(".")
3162            if any(not p.isdigit() for p in fv_parts):
3163                return
3164
3165            fv_tuple = tuple(map(int, fv_parts))
3166
3167            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3168            if fv_tuple[:2] in ((0, 3), (0, 4)):
3169                m04 = _ModelDescr_v0_4.load(data)
3170                if isinstance(m04, InvalidDescr):
3171                    try:
3172                        updated = _model_conv.convert_as_dict(
3173                            m04  # pyright: ignore[reportArgumentType]
3174                        )
3175                    except Exception as e:
3176                        logger.error(
3177                            "Failed to convert from invalid model 0.4 description."
3178                            + f"\nerror: {e}"
3179                            + "\nProceeding with model 0.5 validation without conversion."
3180                        )
3181                        updated = None
3182                else:
3183                    updated = _model_conv.convert_as_dict(m04)
3184
3185                if updated is not None:
3186                    data.clear()
3187                    data.update(updated)
3188
3189            elif fv_tuple[:2] == (0, 5):
3190                # bump patch version
3191                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
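
A minimal usage sketch (the RDF source is a placeholder; bioimageio.spec.load_description is assumed to be the public loading entry point):

from bioimageio.spec import load_description
from bioimageio.spec.model.v0_5 import ModelDescr

descr = load_description("path/to/rdf.yaml")  # placeholder path, URL or zip package
if isinstance(descr, ModelDescr):
    print(descr.name, descr.format_version)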

implemented_format_version: ClassVar[Literal['0.5.4']] = '0.5.4'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>), PlainSerializer(func=<function _package_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
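
The heading requirement is checked with the same regular expression used in _validate_documentation above; a short illustration with a made-up README:

import re

readme = """# Example Model

## Validation
Describe how to quantitatively validate the model on unseen data here.
"""
assert re.search("#.*[vV]alidation", readme) is not None  # no warning would be issued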

inputs: Annotated[Sequence[InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.
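
A rough sketch of a single input tensor description (field names as defined in this module; the test tensor path is a placeholder, so io checks are disabled via the validation context, which is assumed to be usable as shown):

from bioimageio.spec import ValidationContext
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    FileDescr,
    Identifier,
    InputTensorDescr,
    ParameterizedSize,
    SpaceInputAxis,
    TensorId,
)

with ValidationContext(perform_io_checks=False):  # skip reading the placeholder file
    ipt = InputTensorDescr(
        id=TensorId("raw"),
        axes=[
            BatchAxis(),
            ChannelAxis(channel_names=[Identifier("c0")]),
            SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
            SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
        ],
        test_tensor=FileDescr(source="test_input.npy"),  # placeholder file name
    )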

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces. We recommend choosing a name that refers to the model's task and image modality.
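
A quick check against these constraints (using the RestrictCharacters alphabet and length bounds shown above; the example name is made up):

import string

allowed = set(string.ascii_letters + string.digits + "_+- ()")
name = "2D UNet (nuclei)"
assert 5 <= len(name) <= 128 and set(name) <= allowed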

outputs: Annotated[Sequence[OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author]

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=30, msg="Run mode '{value}' has limited support across consumer softwares.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in ISO 8601 format with a few restrictions listed at https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat. (In Python a datetime object is valid, too.)

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model.

weights: Annotated[WeightsDescr, WrapSerializer(func=<function package_weights>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: Config
def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
2933    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2934        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2935        assert all(isinstance(d, np.ndarray) for d in data)
2936        return data
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
2938    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2939        data = [load_array(out.test_tensor) for out in self.outputs]
2940        assert all(isinstance(d, np.ndarray) for d in data)
2941        return data
@staticmethod
def get_batch_size( tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2943    @staticmethod
2944    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2945        batch_size = 1
2946        tensor_with_batchsize: Optional[TensorId] = None
2947        for tid in tensor_sizes:
2948            for aid, s in tensor_sizes[tid].items():
2949                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2950                    continue
2951
2952                if batch_size != 1:
2953                    assert tensor_with_batchsize is not None
2954                    raise ValueError(
2955                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2956                    )
2957
2958                batch_size = s
2959                tensor_with_batchsize = tid
2960
2961        return batch_size
def get_output_tensor_sizes( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:
2963    def get_output_tensor_sizes(
2964        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2965    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2966        """Returns the tensor output sizes for given **input_sizes**.
2967        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2968        Otherwise it might be larger than the actual (valid) output"""
2969        batch_size = self.get_batch_size(input_sizes)
2970        ns = self.get_ns(input_sizes)
2971
2972        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2973        return tensor_sizes.outputs

Returns the tensor output sizes for given input_sizes. The output sizes are exact only if input_sizes describes a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
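
Usage sketch (assuming descr is a ModelDescr instance and the placeholder tensor/axis ids match its declared inputs):

from bioimageio.spec.model.v0_5 import AxisId, TensorId

output_sizes = descr.get_output_tensor_sizes(
    {TensorId("raw"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}
)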

def get_ns( self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2975    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2976        """get parameter `n` for each parameterized axis
2977        such that the valid input size is >= the given input size"""
2978        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2979        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2980        for tid in input_sizes:
2981            for aid, s in input_sizes[tid].items():
2982                size_descr = axes[tid][aid].size
2983                if isinstance(size_descr, ParameterizedSize):
2984                    ret[(tid, aid)] = size_descr.get_n(s)
2985                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2986                    pass
2987                else:
2988                    assert_never(size_descr)
2989
2990        return ret

Get the parameter n for each parameterized axis such that the valid input size is >= the given input size.
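
Continuing the placeholder example above: for an axis parameterized as size = 64 + n * 16, an input size of 100 yields n = 3, the smallest n for which 64 + n * 16 >= 100.

ns = descr.get_ns({TensorId("raw"): {AxisId("x"): 100}})
# e.g. {(TensorId("raw"), AxisId("x")): 3}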

def get_tensor_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
2992    def get_tensor_sizes(
2993        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2994    ) -> _TensorSizes:
2995        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2996        return _TensorSizes(
2997            {
2998                t: {
2999                    aa: axis_sizes.inputs[(tt, aa)]
3000                    for tt, aa in axis_sizes.inputs
3001                    if tt == t
3002                }
3003                for t in {tt for tt, _ in axis_sizes.inputs}
3004            },
3005            {
3006                t: {
3007                    aa: axis_sizes.outputs[(tt, aa)]
3008                    for tt, aa in axis_sizes.outputs
3009                    if tt == t
3010                }
3011                for t in {tt for tt, _ in axis_sizes.outputs}
3012            },
3013        )
def get_axis_sizes( self, ns: Mapping[Tuple[TensorId, AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:
3015    def get_axis_sizes(
3016        self,
3017        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3018        batch_size: Optional[int] = None,
3019        *,
3020        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3021    ) -> _AxisSizes:
3022        """Determine input and output block shape for scale factors **ns**
3023        of parameterized input sizes.
3024
3025        Args:
3026            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3027                that is parameterized as `size = min + n * step`.
3028            batch_size: The desired size of the batch dimension.
3029                If given **batch_size** overwrites any batch size present in
3030                **max_input_shape**. Default 1.
3031            max_input_shape: Limits the derived block shapes.
3032                Each axis for which the input size, parameterized by `n`, is larger
3033                than **max_input_shape** is set to the minimal value `n_min` for which
3034                this is still true.
3035                Use this for small input samples or large values of **ns**.
3036                Or simply whenever you know the full input shape.
3037
3038        Returns:
3039            Resolved axis sizes for model inputs and outputs.
3040        """
3041        max_input_shape = max_input_shape or {}
3042        if batch_size is None:
3043            for (_t_id, a_id), s in max_input_shape.items():
3044                if a_id == BATCH_AXIS_ID:
3045                    batch_size = s
3046                    break
3047            else:
3048                batch_size = 1
3049
3050        all_axes = {
3051            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3052        }
3053
3054        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3055        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3056
3057        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3058            if isinstance(a, BatchAxis):
3059                if (t_descr.id, a.id) in ns:
3060                    logger.warning(
3061                        "Ignoring unexpected size increment factor (n) for batch axis"
3062                        + " of tensor '{}'.",
3063                        t_descr.id,
3064                    )
3065                return batch_size
3066            elif isinstance(a.size, int):
3067                if (t_descr.id, a.id) in ns:
3068                    logger.warning(
3069                        "Ignoring unexpected size increment factor (n) for fixed size"
3070                        + " axis '{}' of tensor '{}'.",
3071                        a.id,
3072                        t_descr.id,
3073                    )
3074                return a.size
3075            elif isinstance(a.size, ParameterizedSize):
3076                if (t_descr.id, a.id) not in ns:
3077                    raise ValueError(
3078                        "Size increment factor (n) missing for parametrized axis"
3079                        + f" '{a.id}' of tensor '{t_descr.id}'."
3080                    )
3081                n = ns[(t_descr.id, a.id)]
3082                s_max = max_input_shape.get((t_descr.id, a.id))
3083                if s_max is not None:
3084                    n = min(n, a.size.get_n(s_max))
3085
3086                return a.size.get_size(n)
3087
3088            elif isinstance(a.size, SizeReference):
3089                if (t_descr.id, a.id) in ns:
3090                    logger.warning(
3091                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3092                        + " of tensor '{}' with size reference.",
3093                        a.id,
3094                        t_descr.id,
3095                    )
3096                assert not isinstance(a, BatchAxis)
3097                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3098                assert not isinstance(ref_axis, BatchAxis)
3099                ref_key = (a.size.tensor_id, a.size.axis_id)
3100                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3101                assert ref_size is not None, ref_key
3102                assert not isinstance(ref_size, _DataDepSize), ref_key
3103                return a.size.get_size(
3104                    axis=a,
3105                    ref_axis=ref_axis,
3106                    ref_size=ref_size,
3107                )
3108            elif isinstance(a.size, DataDependentSize):
3109                if (t_descr.id, a.id) in ns:
3110                    logger.warning(
3111                        "Ignoring unexpected increment factor (n) for data dependent"
3112                        + " size axis '{}' of tensor '{}'.",
3113                        a.id,
3114                        t_descr.id,
3115                    )
3116                return _DataDepSize(a.size.min, a.size.max)
3117            else:
3118                assert_never(a.size)
3119
3120        # first resolve all but the `SizeReference` input sizes
3121        for t_descr in self.inputs:
3122            for a in t_descr.axes:
3123                if not isinstance(a.size, SizeReference):
3124                    s = get_axis_size(a)
3125                    assert not isinstance(s, _DataDepSize)
3126                    inputs[t_descr.id, a.id] = s
3127
3128        # resolve all other input axis sizes
3129        for t_descr in self.inputs:
3130            for a in t_descr.axes:
3131                if isinstance(a.size, SizeReference):
3132                    s = get_axis_size(a)
3133                    assert not isinstance(s, _DataDepSize)
3134                    inputs[t_descr.id, a.id] = s
3135
3136        # resolve all output axis sizes
3137        for t_descr in self.outputs:
3138            for a in t_descr.axes:
3139                assert not isinstance(a.size, ParameterizedSize)
3140                s = get_axis_size(a)
3141                outputs[t_descr.id, a.id] = s
3142
3143        return _AxisSizes(inputs=inputs, outputs=outputs)

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Defaults to 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
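
Usage sketch (same placeholder ids as above; descr is an assumed ModelDescr instance):

axis_sizes = descr.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 2, (TensorId("raw"), AxisId("y")): 2},
    batch_size=1,
)
print(axis_sizes.inputs)   # {(tensor_id, axis_id): int}
print(axis_sizes.outputs)  # int or _DataDepSize per (tensor_id, axis_id)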

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3151    @classmethod
3152    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3153        """Convert metadata following an older format version to this classes' format
3154        without validating the result.
3155        """
3156        if (
3157            data.get("type") == "model"
3158            and isinstance(fv := data.get("format_version"), str)
3159            and fv.count(".") == 2
3160        ):
3161            fv_parts = fv.split(".")
3162            if any(not p.isdigit() for p in fv_parts):
3163                return
3164
3165            fv_tuple = tuple(map(int, fv_parts))
3166
3167            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3168            if fv_tuple[:2] in ((0, 3), (0, 4)):
3169                m04 = _ModelDescr_v0_4.load(data)
3170                if isinstance(m04, InvalidDescr):
3171                    try:
3172                        updated = _model_conv.convert_as_dict(
3173                            m04  # pyright: ignore[reportArgumentType]
3174                        )
3175                    except Exception as e:
3176                        logger.error(
3177                            "Failed to convert from invalid model 0.4 description."
3178                            + f"\nerror: {e}"
3179                            + "\nProceeding with model 0.5 validation without conversion."
3180                        )
3181                        updated = None
3182                else:
3183                    updated = _model_conv.convert_as_dict(m04)
3184
3185                if updated is not None:
3186                    data.clear()
3187                    data.update(updated)
3188
3189            elif fv_tuple[:2] == (0, 5):
3190                # bump patch version
3191                data["format_version"] = cls.implemented_format_version

Convert metadata following an older format version to this class's format without validating the result.
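
A small, deterministic sketch of the in-place conversion (the dictionary is a made-up, incomplete RDF; for an already 0.5.x description only the patch version is bumped):

from bioimageio.spec.model.v0_5 import ModelDescr

data = {"type": "model", "format_version": "0.5.0"}
ModelDescr.convert_from_old_format_wo_validation(data)
print(data["format_version"])  # "0.5.4", the implemented format version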

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 4)
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
337def init_private_attributes(self: BaseModel, context: Any, /) -> None:
338    """This function is meant to behave like a BaseModel method to initialise private attributes.
339
340    It takes context as an argument since that's what pydantic-core passes when calling it.
341
342    Args:
343        self: The BaseModel instance.
344        context: The context.
345    """
346    if getattr(self, '__pydantic_private__', None) is None:
347        pydantic_private = {}
348        for name, private_attr in self.__private_attributes__.items():
349            default = private_attr.get_default()
350            if default is not PydanticUndefined:
351                pydantic_private[name] = default
352        object_setattr(self, '__pydantic_private__', pydantic_private)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
def generate_covers( inputs: Sequence[Tuple[InputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]], outputs: Sequence[Tuple[OutputTensorDescr, numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]]) -> List[pathlib.Path]:
3416def generate_covers(
3417    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3418    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3419) -> List[Path]:
3420    def squeeze(
3421        data: NDArray[Any], axes: Sequence[AnyAxis]
3422    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3423        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3424        if data.ndim != len(axes):
3425            raise ValueError(
3426                f"tensor shape {data.shape} does not match described axes"
3427                + f" {[a.id for a in axes]}"
3428            )
3429
3430        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3431        return data.squeeze(), axes
3432
3433    def normalize(
3434        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3435    ) -> NDArray[np.float32]:
3436        data = data.astype("float32")
3437        data -= data.min(axis=axis, keepdims=True)
3438        data /= data.max(axis=axis, keepdims=True) + eps
3439        return data
3440
3441    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3442        original_shape = data.shape
3443        data, axes = squeeze(data, axes)
3444
3445        # take slice from any batch or index axis if needed
3446        # and convert the first channel axis and take a slice from any additional channel axes
3447        slices: Tuple[slice, ...] = ()
3448        ndim = data.ndim
3449        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3450        has_c_axis = False
3451        for i, a in enumerate(axes):
3452            s = data.shape[i]
3453            assert s > 1
3454            if (
3455                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3456                and ndim > ndim_need
3457            ):
3458                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3459                ndim -= 1
3460            elif isinstance(a, ChannelAxis):
3461                if has_c_axis:
3462                    # second channel axis
3463                    data = data[slices + (slice(0, 1),)]
3464                    ndim -= 1
3465                else:
3466                    has_c_axis = True
3467                    if s == 2:
3468                        # visualize two channels with cyan and magenta
3469                        data = np.concatenate(
3470                            [
3471                                data[slices + (slice(1, 2),)],
3472                                data[slices + (slice(0, 1),)],
3473                                (
3474                                    data[slices + (slice(0, 1),)]
3475                                    + data[slices + (slice(1, 2),)]
3476                                )
3477                                / 2,  # TODO: take maximum instead?
3478                            ],
3479                            axis=i,
3480                        )
3481                    elif data.shape[i] == 3:
3482                        pass  # visualize 3 channels as RGB
3483                    else:
3484                        # visualize first 3 channels as RGB
3485                        data = data[slices + (slice(3),)]
3486
3487                    assert data.shape[i] == 3
3488
3489            slices += (slice(None),)
3490
3491        data, axes = squeeze(data, axes)
3492        assert len(axes) == ndim
3493        # take slice from z axis if needed
3494        slices = ()
3495        if ndim > ndim_need:
3496            for i, a in enumerate(axes):
3497                s = data.shape[i]
3498                if a.id == AxisId("z"):
3499                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3500                    data, axes = squeeze(data, axes)
3501                    ndim -= 1
3502                    break
3503
3504            slices += (slice(None),)
3505
3506        # take slice from any space or time axis
3507        slices = ()
3508
3509        for i, a in enumerate(axes):
3510            if ndim <= ndim_need:
3511                break
3512
3513            s = data.shape[i]
3514            assert s > 1
3515            if isinstance(
3516                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3517            ):
3518                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3519                ndim -= 1
3520
3521            slices += (slice(None),)
3522
3523        del slices
3524        data, axes = squeeze(data, axes)
3525        assert len(axes) == ndim
3526
3527        if (has_c_axis and ndim != 3) or ndim != 2:
3528            raise ValueError(
3529                f"Failed to construct cover image from shape {original_shape}"
3530            )
3531
3532        if not has_c_axis:
3533            assert ndim == 2
3534            data = np.repeat(data[:, :, None], 3, axis=2)
3535            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3536            ndim += 1
3537
3538        assert ndim == 3
3539
3540        # transpose axis order such that longest axis comes first...
3541        axis_order: List[int] = list(np.argsort(list(data.shape)))
3542        axis_order.reverse()
3543        # ... and channel axis is last
3544        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3545        axis_order.append(axis_order.pop(c))
3546        axes = [axes[ao] for ao in axis_order]
3547        data = data.transpose(axis_order)
3548
3549        # h, w = data.shape[:2]
3550        # if h / w  in (1.0 or 2.0):
3551        #     pass
3552        # elif h / w < 2:
3553        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3554
3555        norm_along = (
3556            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3557        )
3558        # normalize the data and map to 8 bit
3559        data = normalize(data, norm_along)
3560        data = (data * 255).astype("uint8")
3561
3562        return data
3563
3564    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3565        assert im0.dtype == im1.dtype == np.uint8
3566        assert im0.shape == im1.shape
3567        assert im0.ndim == 3
3568        N, M, C = im0.shape
3569        assert C == 3
3570        out = np.ones((N, M, C), dtype="uint8")
3571        for c in range(C):
3572            outc = np.tril(im0[..., c])
3573            mask = outc == 0
3574            outc[mask] = np.triu(im1[..., c])[mask]
3575            out[..., c] = outc
3576
3577        return out
3578
3579    ipt_descr, ipt = inputs[0]
3580    out_descr, out = outputs[0]
3581
3582    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3583    out_img = to_2d_image(out, out_descr.axes)
3584
3585    cover_folder = Path(mkdtemp())
3586    if ipt_img.shape == out_img.shape:
3587        covers = [cover_folder / "cover.png"]
3588        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3589    else:
3590        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3591        imwrite(covers[0], ipt_img)
3592        imwrite(covers[1], out_img)
3593
3594    return covers
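
A hedged usage sketch mirroring the call in _add_default_cover (descr is an assumed, validated ModelDescr whose test tensors are readable; load_array comes from an internal module):

from bioimageio.spec._internal.io_utils import load_array

covers = generate_covers(
    [(t, load_array(t.test_tensor)) for t in descr.inputs],
    [(t, load_array(t.test_tensor)) for t in descr.outputs],
)
# one or two PNG file paths in a temporary folder
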
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):
class TensorDescrBase[Annotated[Union[BatchAxis, ChannelAxis, IndexOutputAxis, Annotated[Union[Annotated[TimeOutputAxis, Tag], Annotated[TimeOutputAxisWithHalo, Tag]], Discriminator], Annotated[Union[Annotated[SpaceOutputAxis, Tag], Annotated[SpaceOutputAxisWithHalo, Tag]], Discriminator]], Discriminator]](bioimageio.spec._internal.node.Node, typing.Generic[~IO_AxisT]):