bioimageio.spec.model.v0_5
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from datetime import datetime 10from itertools import chain 11from math import ceil 12from pathlib import Path, PurePosixPath 13from tempfile import mkdtemp 14from typing import ( 15 TYPE_CHECKING, 16 Any, 17 ClassVar, 18 Dict, 19 FrozenSet, 20 Generic, 21 List, 22 Literal, 23 Mapping, 24 NamedTuple, 25 Optional, 26 Sequence, 27 Set, 28 Tuple, 29 Type, 30 TypeVar, 31 Union, 32 cast, 33) 34 35import numpy as np 36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 38from loguru import logger 39from numpy.typing import NDArray 40from pydantic import ( 41 Discriminator, 42 Field, 43 RootModel, 44 Tag, 45 ValidationInfo, 46 WrapSerializer, 47 field_validator, 48 model_validator, 49) 50from typing_extensions import Annotated, LiteralString, Self, assert_never 51 52from .._internal.common_nodes import ( 53 Converter, 54 InvalidDescr, 55 Node, 56 NodeWithExplicitlySetFields, 57) 58from .._internal.constants import DTYPE_LIMITS 59from .._internal.field_warning import issue_warning, warn 60from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 61from .._internal.io import FileDescr as FileDescr 62from .._internal.io import WithSuffix, YamlValue, download 63from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 64from .._internal.io_basics import Sha256 as Sha256 65from .._internal.io_utils import load_array 66from .._internal.types import Datetime as Datetime 67from .._internal.types import Identifier as Identifier 68from .._internal.types import ( 69 ImportantFileSource, 70 LowerCaseIdentifier, 71 LowerCaseIdentifierAnno, 72 SiUnit, 73) 74from .._internal.types import NotEmpty as NotEmpty 75from .._internal.url import HttpUrl as HttpUrl 76from .._internal.validation_context import validation_context_var 77from .._internal.validator_annotations import RestrictCharacters 78from .._internal.version_type import Version as Version 79from .._internal.warning_levels import INFO 80from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 81from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 82from ..dataset.v0_3 import DatasetDescr as DatasetDescr 83from ..dataset.v0_3 import DatasetId as DatasetId 84from ..dataset.v0_3 import LinkedDataset as LinkedDataset 85from ..dataset.v0_3 import Uploader as Uploader 86from ..generic.v0_3 import ( 87 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 88) 89from ..generic.v0_3 import Author as Author 90from ..generic.v0_3 import BadgeDescr as BadgeDescr 91from ..generic.v0_3 import CiteEntry as CiteEntry 92from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 93from ..generic.v0_3 import ( 94 DocumentationSource, 95 GenericModelDescrBase, 96 LinkedResourceNode, 97 _author_conv, # pyright: ignore[reportPrivateUsage] 98 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 99) 100from ..generic.v0_3 import Doi as Doi 101from ..generic.v0_3 import LicenseId as LicenseId 102from ..generic.v0_3 import LinkedResource as LinkedResource 103from ..generic.v0_3 import Maintainer as Maintainer 104from ..generic.v0_3 import OrcidId as OrcidId 105from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 106from ..generic.v0_3 import ResourceId as ResourceId 107from .v0_4 import Author as _Author_v0_4 108from .v0_4 import BinarizeDescr as 
_BinarizeDescr_v0_4 109from .v0_4 import CallableFromDepencency as CallableFromDepencency 110from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 111from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 112from .v0_4 import ClipDescr as _ClipDescr_v0_4 113from .v0_4 import ClipKwargs as ClipKwargs 114from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 115from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 116from .v0_4 import KnownRunMode as KnownRunMode 117from .v0_4 import ModelDescr as _ModelDescr_v0_4 118from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 119from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 120from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 121from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 122from .v0_4 import ProcessingKwargs as ProcessingKwargs 123from .v0_4 import RunMode as RunMode 124from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 125from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 126from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 127from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 128from .v0_4 import TensorName as _TensorName_v0_4 129from .v0_4 import WeightsFormat as WeightsFormat 130from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 131from .v0_4 import package_weights 132 133# unit names from https://ngff.openmicroscopy.org/latest/#axes-md 134SpaceUnit = Literal[ 135 "attometer", 136 "angstrom", 137 "centimeter", 138 "decimeter", 139 "exameter", 140 "femtometer", 141 "foot", 142 "gigameter", 143 "hectometer", 144 "inch", 145 "kilometer", 146 "megameter", 147 "meter", 148 "micrometer", 149 "mile", 150 "millimeter", 151 "nanometer", 152 "parsec", 153 "petameter", 154 "picometer", 155 "terameter", 156 "yard", 157 "yoctometer", 158 "yottameter", 159 "zeptometer", 160 "zettameter", 161] 162 163TimeUnit = Literal[ 164 "attosecond", 165 "centisecond", 166 "day", 167 "decisecond", 168 "exasecond", 169 "femtosecond", 170 "gigasecond", 171 "hectosecond", 172 "hour", 173 "kilosecond", 174 "megasecond", 175 "microsecond", 176 "millisecond", 177 "minute", 178 "nanosecond", 179 "petasecond", 180 "picosecond", 181 "second", 182 "terasecond", 183 "yoctosecond", 184 "yottasecond", 185 "zeptosecond", 186 "zettasecond", 187] 188 189AxisType = Literal["batch", "channel", "index", "time", "space"] 190 191 192class TensorId(LowerCaseIdentifier): 193 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 194 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 195 ] 196 197 198class AxisId(LowerCaseIdentifier): 199 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 200 Annotated[LowerCaseIdentifierAnno, MaxLen(16)] 201 ] 202 203 204def _is_batch(a: str) -> bool: 205 return a == BATCH_AXIS_ID 206 207 208def _is_not_batch(a: str) -> bool: 209 return not _is_batch(a) 210 211 212NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 213 214PostprocessingId = Literal[ 215 "binarize", 216 "clip", 217 "ensure_dtype", 218 "fixed_zero_mean_unit_variance", 219 "scale_linear", 220 "scale_mean_variance", 221 "scale_range", 222 "sigmoid", 223 "zero_mean_unit_variance", 224] 225PreprocessingId = Literal[ 226 "binarize", 227 "clip", 228 "ensure_dtype", 229 "scale_linear", 230 "sigmoid", 231 "zero_mean_unit_variance", 232 "scale_range", 233] 234 235 236SAME_AS_TYPE = "<same as type>" 237 238 239ParameterizedSize_N = int 240 241 242class ParameterizedSize(Node): 243 """Describes a range of 
valid tensor axis sizes as `size = min + n*step`.""" 244 245 N: ClassVar[Type[int]] = ParameterizedSize_N 246 """integer to parameterize this axis""" 247 248 min: Annotated[int, Gt(0)] 249 step: Annotated[int, Gt(0)] 250 251 def validate_size(self, size: int) -> int: 252 if size < self.min: 253 raise ValueError(f"size {size} < {self.min}") 254 if (size - self.min) % self.step != 0: 255 raise ValueError( 256 f"axis of size {size} is not parameterized by `min + n*step` =" 257 + f" `{self.min} + n*{self.step}`" 258 ) 259 260 return size 261 262 def get_size(self, n: ParameterizedSize_N) -> int: 263 return self.min + self.step * n 264 265 def get_n(self, s: int) -> ParameterizedSize_N: 266 """return smallest n parameterizing a size greater or equal than `s`""" 267 return ceil((s - self.min) / self.step) 268 269 270class DataDependentSize(Node): 271 min: Annotated[int, Gt(0)] = 1 272 max: Annotated[Optional[int], Gt(1)] = None 273 274 @model_validator(mode="after") 275 def _validate_max_gt_min(self): 276 if self.max is not None and self.min >= self.max: 277 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 278 279 return self 280 281 def validate_size(self, size: int) -> int: 282 if size < self.min: 283 raise ValueError(f"size {size} < {self.min}") 284 285 if self.max is not None and size > self.max: 286 raise ValueError(f"size {size} > {self.max}") 287 288 return size 289 290 291class SizeReference(Node): 292 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 293 294 `axis.size = reference.size * reference.scale / axis.scale + offset` 295 296 note: 297 1. The axis and the referenced axis need to have the same unit (or no unit). 298 2. Batch axes may not be referenced. 299 3. Fractions are rounded down. 300 4. If the reference axis is `concatenable` the referencing axis is assumed to be 301 `concatenable` as well with the same block order. 302 303 example: 304 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 305 Let's assume that we want to express the image height h in relation to its width w 306 instead of only accepting input images of exactly 100*49 pixels 307 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 308 309 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 310 >>> h = SpaceInputAxis( 311 ... id=AxisId("h"), 312 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 313 ... unit="millimeter", 314 ... scale=4, 315 ... ) 316 >>> print(h.size.compute(h, w)) 317 49 318 319 -> h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 320 """ 321 322 tensor_id: TensorId 323 """tensor id of the reference axis""" 324 325 axis_id: AxisId 326 """axis id of the reference axis""" 327 328 offset: int = 0 329 330 def get_size( 331 self, 332 axis: Union[ 333 ChannelAxis, 334 IndexInputAxis, 335 IndexOutputAxis, 336 TimeInputAxis, 337 SpaceInputAxis, 338 TimeOutputAxis, 339 TimeOutputAxisWithHalo, 340 SpaceOutputAxis, 341 SpaceOutputAxisWithHalo, 342 ], 343 ref_axis: Union[ 344 ChannelAxis, 345 IndexInputAxis, 346 IndexOutputAxis, 347 TimeInputAxis, 348 SpaceInputAxis, 349 TimeOutputAxis, 350 TimeOutputAxisWithHalo, 351 SpaceOutputAxis, 352 SpaceOutputAxisWithHalo, 353 ], 354 n: ParameterizedSize_N = 0, 355 ref_size: Optional[int] = None, 356 ): 357 """Compute the concrete size for a given axis and its reference axis. 358 359 Args: 360 axis: The axis this `SizeReference` is the size of. 
361 ref_axis: The reference axis to compute the size from. 362 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 363 and no fixed **ref_size** is given, 364 **n** is used to compute the size of the parameterized **ref_axis**. 365 ref_size: Overwrite the reference size instead of deriving it from 366 **ref_axis** 367 (**ref_axis.scale** is still used; any given **n** is ignored). 368 """ 369 assert ( 370 axis.size == self 371 ), "Given `axis.size` is not defined by this `SizeReference`" 372 373 assert ( 374 ref_axis.id == self.axis_id 375 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 376 377 assert axis.unit == ref_axis.unit, ( 378 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 379 f" but {axis.unit}!={ref_axis.unit}" 380 ) 381 if ref_size is None: 382 if isinstance(ref_axis.size, (int, float)): 383 ref_size = ref_axis.size 384 elif isinstance(ref_axis.size, ParameterizedSize): 385 ref_size = ref_axis.size.get_size(n) 386 elif isinstance(ref_axis.size, DataDependentSize): 387 raise ValueError( 388 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 389 ) 390 elif isinstance(ref_axis.size, SizeReference): 391 raise ValueError( 392 "Reference axis referenced in `SizeReference` may not be sized by a" 393 + " `SizeReference` itself." 394 ) 395 else: 396 assert_never(ref_axis.size) 397 398 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 399 400 @staticmethod 401 def _get_unit( 402 axis: Union[ 403 ChannelAxis, 404 IndexInputAxis, 405 IndexOutputAxis, 406 TimeInputAxis, 407 SpaceInputAxis, 408 TimeOutputAxis, 409 TimeOutputAxisWithHalo, 410 SpaceOutputAxis, 411 SpaceOutputAxisWithHalo, 412 ], 413 ): 414 return axis.unit 415 416 417# this Axis definition is compatible with the NGFF draft from July 10, 2023 418# https://ngff.openmicroscopy.org/latest/#axes-md 419class AxisBase(NodeWithExplicitlySetFields): 420 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"}) 421 422 id: AxisId 423 """An axis id unique across all axes of one tensor.""" 424 425 description: Annotated[str, MaxLen(128)] = "" 426 427 428class WithHalo(Node): 429 halo: Annotated[int, Ge(1)] 430 """The halo should be cropped from the output tensor to avoid boundary effects. 431 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 
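# Illustrative sketch (hypothetical ids and numbers): how `halo` relates to the reported
# output size. For a reference size of 128 and halo=16, a consumer crops 16 entries from
# both sides, keeping 128 - 2 * 16 = 96 valid entries. `SpaceOutputAxisWithHalo` is
# defined further below.
# >>> out_x = SpaceOutputAxisWithHalo(
# ...     id=AxisId("x"),
# ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
# ...     halo=16,
# ... )
# >>> 128 - 2 * out_x.halo  # size_after_crop
# 96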
432 To document a halo that is already cropped by the model use `size.offset` instead.""" 433 434 size: Annotated[ 435 SizeReference, 436 Field( 437 examples=[ 438 10, 439 SizeReference( 440 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 441 ).model_dump(mode="json"), 442 ] 443 ), 444 ] 445 """reference to another axis with an optional offset (see `SizeReference`)""" 446 447 448BATCH_AXIS_ID = AxisId("batch") 449 450 451class BatchAxis(AxisBase): 452 type: Literal["batch"] = "batch" 453 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 454 size: Optional[Literal[1]] = None 455 """The batch size may be fixed to 1, 456 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 457 458 @property 459 def scale(self): 460 return 1.0 461 462 @property 463 def concatenable(self): 464 return True 465 466 @property 467 def unit(self): 468 return None 469 470 471class ChannelAxis(AxisBase): 472 type: Literal["channel"] = "channel" 473 id: NonBatchAxisId = AxisId("channel") 474 channel_names: NotEmpty[List[Identifier]] 475 476 @property 477 def size(self) -> int: 478 return len(self.channel_names) 479 480 @property 481 def concatenable(self): 482 return False 483 484 @property 485 def scale(self) -> float: 486 return 1.0 487 488 @property 489 def unit(self): 490 return None 491 492 493class IndexAxisBase(AxisBase): 494 type: Literal["index"] = "index" 495 id: NonBatchAxisId = AxisId("index") 496 497 @property 498 def scale(self) -> float: 499 return 1.0 500 501 @property 502 def unit(self): 503 return None 504 505 506class _WithInputAxisSize(Node): 507 size: Annotated[ 508 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 509 Field( 510 examples=[ 511 10, 512 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 513 SizeReference( 514 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 515 ).model_dump(mode="json"), 516 ] 517 ), 518 ] 519 """The size/length of this axis can be specified as 520 - fixed integer 521 - parameterized series of valid sizes (`ParameterizedSize`) 522 - reference to another axis with an optional offset (`SizeReference`) 523 """ 524 525 526class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 527 concatenable: bool = False 528 """If a model has a `concatenable` input axis, it can be processed blockwise, 529 splitting a longer sample axis into blocks matching its input tensor description. 530 Output axes are concatenable if they have a `SizeReference` to a concatenable 531 input axis. 532 """ 533 534 535class IndexOutputAxis(IndexAxisBase): 536 size: Annotated[ 537 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 538 Field( 539 examples=[ 540 10, 541 SizeReference( 542 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 543 ).model_dump(mode="json"), 544 ] 545 ), 546 ] 547 """The size/length of this axis can be specified as 548 - fixed integer 549 - reference to another axis with an optional offset (`SizeReference`) 550 - data dependent size using `DataDependentSize` (size is only known after model inference) 551 """ 552 553 554class TimeAxisBase(AxisBase): 555 type: Literal["time"] = "time" 556 id: NonBatchAxisId = AxisId("time") 557 unit: Optional[TimeUnit] = None 558 scale: Annotated[float, Gt(0)] = 1.0 559 560 561class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 562 concatenable: bool = False 563 """If a model has a `concatenable` input axis, it can be processed blockwise, 564 splitting a longer sample axis into blocks matching its input tensor description. 
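# Illustrative sketch (hypothetical channel names and numbers): building a list of input
# axes and evaluating a `ParameterizedSize`.
# >>> p = ParameterizedSize(min=32, step=16)
# >>> p.get_size(3)  # min + step * n = 32 + 16 * 3
# 80
# >>> p.get_n(40)  # smallest n yielding a size >= 40
# 1
# >>> axes = [
# ...     BatchAxis(),
# ...     ChannelAxis(channel_names=[Identifier("dapi"), Identifier("gfp")]),
# ...     TimeInputAxis(size=p, unit="second", scale=0.5),
# ... ]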
565 Output axes are concatenable if they have a `SizeReference` to a concatenable 566 input axis. 567 """ 568 569 570class SpaceAxisBase(AxisBase): 571 type: Literal["space"] = "space" 572 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 573 unit: Optional[SpaceUnit] = None 574 scale: Annotated[float, Gt(0)] = 1.0 575 576 577class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 578 concatenable: bool = False 579 """If a model has a `concatenable` input axis, it can be processed blockwise, 580 splitting a longer sample axis into blocks matching its input tensor description. 581 Output axes are concatenable if they have a `SizeReference` to a concatenable 582 input axis. 583 """ 584 585 586_InputAxisUnion = Union[ 587 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 588] 589InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 590 591 592class _WithOutputAxisSize(Node): 593 size: Annotated[ 594 Union[Annotated[int, Gt(0)], SizeReference], 595 Field( 596 examples=[ 597 10, 598 SizeReference( 599 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 600 ).model_dump(mode="json"), 601 ] 602 ), 603 ] 604 """The size/length of this axis can be specified as 605 - fixed integer 606 - reference to another axis with an optional offset (see `SizeReference`) 607 """ 608 609 610class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 611 pass 612 613 614class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 615 pass 616 617 618def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 619 if isinstance(v, dict): 620 return "with_halo" if "halo" in v else "wo_halo" 621 else: 622 return "with_halo" if hasattr(v, "halo") else "wo_halo" 623 624 625_TimeOutputAxisUnion = Annotated[ 626 Union[ 627 Annotated[TimeOutputAxis, Tag("wo_halo")], 628 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 629 ], 630 Discriminator(_get_halo_axis_discriminator_value), 631] 632 633 634class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 635 pass 636 637 638class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 639 pass 640 641 642_SpaceOutputAxisUnion = Annotated[ 643 Union[ 644 Annotated[SpaceOutputAxis, Tag("wo_halo")], 645 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 646 ], 647 Discriminator(_get_halo_axis_discriminator_value), 648] 649 650 651_OutputAxisUnion = Union[ 652 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 653] 654OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 655 656AnyAxis = Union[InputAxis, OutputAxis] 657 658TVs = Union[ 659 NotEmpty[List[int]], 660 NotEmpty[List[float]], 661 NotEmpty[List[bool]], 662 NotEmpty[List[str]], 663] 664 665 666NominalOrOrdinalDType = Literal[ 667 "float32", 668 "float64", 669 "uint8", 670 "int8", 671 "uint16", 672 "int16", 673 "uint32", 674 "int32", 675 "uint64", 676 "int64", 677 "bool", 678] 679 680 681class NominalOrOrdinalDataDescr(Node): 682 values: TVs 683 """A fixed set of nominal or an ascending sequence of ordinal values. 684 In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'. 685 String `values` are interpreted as labels for tensor values 0, ..., N. 686 Note: as YAML 1.2 does not natively support a "set" datatype, 687 nominal values should be given as a sequence (aka list/array) as well. 
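# Illustrative sketch (hypothetical label names): nominal labels stored as uint8 values
# 0..N, and a boolean mask description.
# >>> labels = NominalOrOrdinalDataDescr(
# ...     values=["background", "nucleus", "membrane"], type="uint8"
# ... )
# >>> labels.range
# (0, 2)
# >>> mask = NominalOrOrdinalDataDescr(values=[False, True], type="bool")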
688 """ 689 690 type: Annotated[ 691 NominalOrOrdinalDType, 692 Field( 693 examples=[ 694 "float32", 695 "uint8", 696 "uint16", 697 "int64", 698 "bool", 699 ], 700 ), 701 ] = "uint8" 702 703 @model_validator(mode="after") 704 def _validate_values_match_type( 705 self, 706 ) -> Self: 707 incompatible: List[Any] = [] 708 for v in self.values: 709 if self.type == "bool": 710 if not isinstance(v, bool): 711 incompatible.append(v) 712 elif self.type in DTYPE_LIMITS: 713 if ( 714 isinstance(v, (int, float)) 715 and ( 716 v < DTYPE_LIMITS[self.type].min 717 or v > DTYPE_LIMITS[self.type].max 718 ) 719 or (isinstance(v, str) and "uint" not in self.type) 720 or (isinstance(v, float) and "int" in self.type) 721 ): 722 incompatible.append(v) 723 else: 724 incompatible.append(v) 725 726 if len(incompatible) == 5: 727 incompatible.append("...") 728 break 729 730 if incompatible: 731 raise ValueError( 732 f"data type '{self.type}' incompatible with values {incompatible}" 733 ) 734 735 return self 736 737 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 738 739 @property 740 def range(self): 741 if isinstance(self.values[0], str): 742 return 0, len(self.values) - 1 743 else: 744 return min(self.values), max(self.values) 745 746 747IntervalOrRatioDType = Literal[ 748 "float32", 749 "float64", 750 "uint8", 751 "int8", 752 "uint16", 753 "int16", 754 "uint32", 755 "int32", 756 "uint64", 757 "int64", 758] 759 760 761class IntervalOrRatioDataDescr(Node): 762 type: Annotated[ # todo: rename to dtype 763 IntervalOrRatioDType, 764 Field( 765 examples=["float32", "float64", "uint8", "uint16"], 766 ), 767 ] = "float32" 768 range: Tuple[Optional[float], Optional[float]] = ( 769 None, 770 None, 771 ) 772 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 773 `None` corresponds to min/max of what can be expressed by `data_type`.""" 774 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 775 scale: float = 1.0 776 """Scale for data on an interval (or ratio) scale.""" 777 offset: Optional[float] = None 778 """Offset for data on a ratio scale.""" 779 780 781TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 782 783 784class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 785 """processing base class""" 786 787 # id: Literal[PreprocessingId, PostprocessingId] # make abstract field 788 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"}) 789 790 791class BinarizeKwargs(ProcessingKwargs): 792 """key word arguments for `BinarizeDescr`""" 793 794 threshold: float 795 """The fixed threshold""" 796 797 798class BinarizeAlongAxisKwargs(ProcessingKwargs): 799 """key word arguments for `BinarizeDescr`""" 800 801 threshold: NotEmpty[List[float]] 802 """The fixed threshold values along `axis`""" 803 804 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 805 """The `threshold` axis""" 806 807 808class BinarizeDescr(ProcessingDescrBase): 809 """Binarize the tensor with a fixed threshold. 810 811 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 812 will be set to one, values below the threshold to zero. 
813 """ 814 815 id: Literal["binarize"] = "binarize" 816 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 817 818 819class ClipDescr(ProcessingDescrBase): 820 """Set tensor values below min to min and above max to max.""" 821 822 id: Literal["clip"] = "clip" 823 kwargs: ClipKwargs 824 825 826class EnsureDtypeKwargs(ProcessingKwargs): 827 """key word arguments for `EnsureDtypeDescr`""" 828 829 dtype: Literal[ 830 "float32", 831 "float64", 832 "uint8", 833 "int8", 834 "uint16", 835 "int16", 836 "uint32", 837 "int32", 838 "uint64", 839 "int64", 840 "bool", 841 ] 842 843 844class EnsureDtypeDescr(ProcessingDescrBase): 845 """cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching)""" 846 847 id: Literal["ensure_dtype"] = "ensure_dtype" 848 kwargs: EnsureDtypeKwargs 849 850 851class ScaleLinearKwargs(ProcessingKwargs): 852 """key word arguments for `ScaleLinearDescr`""" 853 854 gain: float = 1.0 855 """multiplicative factor""" 856 857 offset: float = 0.0 858 """additive term""" 859 860 @model_validator(mode="after") 861 def _validate(self) -> Self: 862 if self.gain == 1.0 and self.offset == 0.0: 863 raise ValueError( 864 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 865 + " != 0.0." 866 ) 867 868 return self 869 870 871class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 872 """key word arguments for `ScaleLinearDescr`""" 873 874 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 875 """The axis of of gains/offsets values.""" 876 877 gain: Union[float, NotEmpty[List[float]]] = 1.0 878 """multiplicative factor""" 879 880 offset: Union[float, NotEmpty[List[float]]] = 0.0 881 """additive term""" 882 883 @model_validator(mode="after") 884 def _validate(self) -> Self: 885 886 if isinstance(self.gain, list): 887 if isinstance(self.offset, list): 888 if len(self.gain) != len(self.offset): 889 raise ValueError( 890 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 891 ) 892 else: 893 self.offset = [float(self.offset)] * len(self.gain) 894 elif isinstance(self.offset, list): 895 self.gain = [float(self.gain)] * len(self.offset) 896 else: 897 raise ValueError( 898 "Do not specify an `axis` for scalar gain and offset values." 899 ) 900 901 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 902 raise ValueError( 903 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 904 + " != 0.0." 905 ) 906 907 return self 908 909 910class ScaleLinearDescr(ProcessingDescrBase): 911 """Fixed linear scaling.""" 912 913 id: Literal["scale_linear"] = "scale_linear" 914 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 915 916 917class SigmoidDescr(ProcessingDescrBase): 918 """The logistic sigmoid funciton, a.k.a. 
expit function.""" 919 920 id: Literal["sigmoid"] = "sigmoid" 921 922 @property 923 def kwargs(self) -> ProcessingKwargs: 924 """empty kwargs""" 925 return ProcessingKwargs() 926 927 928class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 929 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 930 931 mean: float 932 """The mean value to normalize with.""" 933 934 std: Annotated[float, Ge(1e-6)] 935 """The standard deviation value to normalize with.""" 936 937 938class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 939 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 940 941 mean: NotEmpty[List[float]] 942 """The mean value(s) to normalize with.""" 943 944 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 945 """The standard deviation value(s) to normalize with. 946 Size must match `mean` values.""" 947 948 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 949 """The axis of the mean/std values to normalize each entry along that dimension 950 separately.""" 951 952 @model_validator(mode="after") 953 def _mean_and_std_match(self) -> Self: 954 if len(self.mean) != len(self.std): 955 raise ValueError( 956 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 957 + " must match." 958 ) 959 960 return self 961 962 963class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 964 """Subtract a given mean and divide by the standard deviation. 965 966 Normalize with fixed, precomputed values for 967 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 968 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 969 axes. 970 """ 971 972 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 973 kwargs: Union[ 974 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 975 ] 976 977 978class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 979 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 980 981 axes: Annotated[ 982 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 983 ] = None 984 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 985 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 986 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 987 To normalize each sample independently leave out the 'batch' axis. 988 Default: Scale all axes jointly.""" 989 990 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 991 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 992 993 994class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 995 """Subtract mean and divide by variance.""" 996 997 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 998 kwargs: ZeroMeanUnitVarianceKwargs 999 1000 1001class ScaleRangeKwargs(ProcessingKwargs): 1002 """key word arguments for `ScaleRangeDescr` 1003 1004 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1005 this processing step normalizes data to the [0, 1] intervall. 1006 For other percentiles the normalized values will partially be outside the [0, 1] 1007 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1008 normalized values to a range. 1009 """ 1010 1011 axes: Annotated[ 1012 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1013 ] = None 1014 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 
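# Illustrative sketch (hypothetical axis layout): choosing `axes` for a tensor with axes
# ('batch', 'channel', 'y', 'x'). Reducing over batch/x/y yields per-channel statistics;
# leaving out 'batch' normalizes each sample independently.
# >>> per_channel = ZeroMeanUnitVarianceKwargs(
# ...     axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
# ... )
# >>> per_sample_and_channel = ZeroMeanUnitVarianceKwargs(
# ...     axes=(AxisId("x"), AxisId("y"))
# ... )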
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value


class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    id: Literal["scale_range"] = "scale_range"
    kwargs: ScaleRangeKwargs


class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""


class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale a tensor's data distribution to match another tensor's mean/std.
1076 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1077 """ 1078 1079 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1080 kwargs: ScaleMeanVarianceKwargs 1081 1082 1083PreprocessingDescr = Annotated[ 1084 Union[ 1085 BinarizeDescr, 1086 ClipDescr, 1087 EnsureDtypeDescr, 1088 ScaleLinearDescr, 1089 SigmoidDescr, 1090 FixedZeroMeanUnitVarianceDescr, 1091 ZeroMeanUnitVarianceDescr, 1092 ScaleRangeDescr, 1093 ], 1094 Discriminator("id"), 1095] 1096PostprocessingDescr = Annotated[ 1097 Union[ 1098 BinarizeDescr, 1099 ClipDescr, 1100 EnsureDtypeDescr, 1101 ScaleLinearDescr, 1102 SigmoidDescr, 1103 FixedZeroMeanUnitVarianceDescr, 1104 ZeroMeanUnitVarianceDescr, 1105 ScaleRangeDescr, 1106 ScaleMeanVarianceDescr, 1107 ], 1108 Discriminator("id"), 1109] 1110 1111IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 1112 1113 1114class TensorDescrBase(Node, Generic[IO_AxisT]): 1115 id: TensorId 1116 """Tensor id. No duplicates are allowed.""" 1117 1118 description: Annotated[str, MaxLen(128)] = "" 1119 """free text description""" 1120 1121 axes: NotEmpty[Sequence[IO_AxisT]] 1122 """tensor axes""" 1123 1124 @property 1125 def shape(self): 1126 return tuple(a.size for a in self.axes) 1127 1128 @field_validator("axes", mode="after", check_fields=False) 1129 @classmethod 1130 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1131 batch_axes = [a for a in axes if a.type == "batch"] 1132 if len(batch_axes) > 1: 1133 raise ValueError( 1134 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1135 ) 1136 1137 seen_ids: Set[AxisId] = set() 1138 duplicate_axes_ids: Set[AxisId] = set() 1139 for a in axes: 1140 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1141 1142 if duplicate_axes_ids: 1143 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1144 1145 return axes 1146 1147 test_tensor: FileDescr 1148 """An example tensor to use for testing. 1149 Using the model with the test input tensors is expected to yield the test output tensors. 1150 Each test tensor has be a an ndarray in the 1151 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1152 The file extension must be '.npy'.""" 1153 1154 sample_tensor: Optional[FileDescr] = None 1155 """A sample tensor to illustrate a possible input/output for the model, 1156 The sample image primarily serves to inform a human user about an example use case 1157 and is typically stored as .hdf5, .png or .tiff. 1158 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1159 (numpy's `.npy` format is not supported). 1160 The image dimensionality has to match the number of axes specified in this tensor description. 
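# Illustrative sketch (hypothetical file names, channel name and sizes): a minimal input
# tensor description built from the fields introduced above. `InputTensorDescr` is
# defined further below.
# >>> raw = InputTensorDescr(
# ...     id=TensorId("raw"),
# ...     axes=[
# ...         BatchAxis(),
# ...         ChannelAxis(channel_names=[Identifier("dapi")]),
# ...         SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
# ...         SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
# ...     ],
# ...     test_tensor=FileDescr(source="test_input.npy"),
# ...     data=IntervalOrRatioDataDescr(type="float32"),
# ... )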
1161 """ 1162 1163 @model_validator(mode="after") 1164 def _validate_sample_tensor(self) -> Self: 1165 if ( 1166 self.sample_tensor is None 1167 or not validation_context_var.get().perform_io_checks 1168 ): 1169 return self 1170 1171 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1172 tensor: NDArray[Any] = imread( 1173 local.path.read_bytes(), 1174 extension=PurePosixPath(local.original_file_name).suffix, 1175 ) 1176 n_dims = len(tensor.squeeze().shape) 1177 n_dims_min = n_dims_max = len(self.axes) 1178 1179 for a in self.axes: 1180 if isinstance(a, BatchAxis): 1181 n_dims_min -= 1 1182 elif isinstance(a.size, int): 1183 if a.size == 1: 1184 n_dims_min -= 1 1185 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1186 if a.size.min == 1: 1187 n_dims_min -= 1 1188 elif isinstance(a.size, SizeReference): 1189 if a.size.offset < 2: 1190 # size reference may result in singleton axis 1191 n_dims_min -= 1 1192 else: 1193 assert_never(a.size) 1194 1195 n_dims_min = max(0, n_dims_min) 1196 if n_dims < n_dims_min or n_dims > n_dims_max: 1197 raise ValueError( 1198 f"Expected sample tensor to have {n_dims_min} to" 1199 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1200 ) 1201 1202 return self 1203 1204 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1205 IntervalOrRatioDataDescr() 1206 ) 1207 """Description of the tensor's data values, optionally per channel. 1208 If specified per channel, the data `type` needs to match across channels.""" 1209 1210 @property 1211 def dtype( 1212 self, 1213 ) -> Literal[ 1214 "float32", 1215 "float64", 1216 "uint8", 1217 "int8", 1218 "uint16", 1219 "int16", 1220 "uint32", 1221 "int32", 1222 "uint64", 1223 "int64", 1224 "bool", 1225 ]: 1226 """dtype as specified under `data.type` or `data[i].type`""" 1227 if isinstance(self.data, collections.abc.Sequence): 1228 return self.data[0].type 1229 else: 1230 return self.data.type 1231 1232 @field_validator("data", mode="after") 1233 @classmethod 1234 def _check_data_type_across_channels( 1235 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1236 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1237 if not isinstance(value, list): 1238 return value 1239 1240 dtypes = {t.type for t in value} 1241 if len(dtypes) > 1: 1242 raise ValueError( 1243 "Tensor data descriptions per channel need to agree in their data" 1244 + f" `type`, but found {dtypes}." 1245 ) 1246 1247 return value 1248 1249 @model_validator(mode="after") 1250 def _check_data_matches_channelaxis(self) -> Self: 1251 if not isinstance(self.data, (list, tuple)): 1252 return self 1253 1254 for a in self.axes: 1255 if isinstance(a, ChannelAxis): 1256 size = a.size 1257 assert isinstance(size, int) 1258 break 1259 else: 1260 return self 1261 1262 if len(self.data) != size: 1263 raise ValueError( 1264 f"Got tensor data descriptions for {len(self.data)} channels, but" 1265 + f" '{a.id}' axis has size {size}." 1266 ) 1267 1268 return self 1269 1270 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1271 if len(array.shape) != len(self.axes): 1272 raise ValueError( 1273 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1274 + f" incompatible with {len(self.axes)} axes." 1275 ) 1276 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1277 1278 1279class InputTensorDescr(TensorDescrBase[InputAxis]): 1280 id: TensorId = TensorId("input") 1281 """Input tensor id. 
1282 No duplicates are allowed across all inputs and outputs.""" 1283 1284 optional: bool = False 1285 """indicates that this tensor may be `None`""" 1286 1287 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 1288 """Description of how this input should be preprocessed. 1289 1290 notes: 1291 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1292 to ensure an input tensor's data type matches the input tensor's data description. 1293 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1294 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1295 changing the data type. 1296 """ 1297 1298 @model_validator(mode="after") 1299 def _validate_preprocessing_kwargs(self) -> Self: 1300 axes_ids = [a.id for a in self.axes] 1301 for p in self.preprocessing: 1302 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1303 if kwargs_axes is None: 1304 continue 1305 1306 if not isinstance(kwargs_axes, collections.abc.Sequence): 1307 raise ValueError( 1308 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1309 ) 1310 1311 if any(a not in axes_ids for a in kwargs_axes): 1312 raise ValueError( 1313 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1314 ) 1315 1316 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1317 dtype = self.data.type 1318 else: 1319 dtype = self.data[0].type 1320 1321 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1322 if not self.preprocessing or not isinstance( 1323 self.preprocessing[0], EnsureDtypeDescr 1324 ): 1325 self.preprocessing.insert( 1326 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1327 ) 1328 1329 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1330 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1331 self.preprocessing.append( 1332 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1333 ) 1334 1335 return self 1336 1337 1338def convert_axes( 1339 axes: str, 1340 *, 1341 shape: Union[ 1342 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1343 ], 1344 tensor_type: Literal["input", "output"], 1345 halo: Optional[Sequence[int]], 1346 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1347): 1348 ret: List[AnyAxis] = [] 1349 for i, a in enumerate(axes): 1350 axis_type = _AXIS_TYPE_MAP.get(a, a) 1351 if axis_type == "batch": 1352 ret.append(BatchAxis()) 1353 continue 1354 1355 scale = 1.0 1356 if isinstance(shape, _ParameterizedInputShape_v0_4): 1357 if shape.step[i] == 0: 1358 size = shape.min[i] 1359 else: 1360 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1361 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1362 ref_t = str(shape.reference_tensor) 1363 if ref_t.count(".") == 1: 1364 t_id, orig_a_id = ref_t.split(".") 1365 else: 1366 t_id = ref_t 1367 orig_a_id = a 1368 1369 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1370 if not (orig_scale := shape.scale[i]): 1371 # old way to insert a new axis dimension 1372 size = int(2 * shape.offset[i]) 1373 else: 1374 scale = 1 / orig_scale 1375 if axis_type in ("channel", "index"): 1376 # these axes no longer have a scale 1377 offset_from_scale = orig_scale * size_refs.get( 1378 _TensorName_v0_4(t_id), {} 1379 ).get(orig_a_id, 0) 1380 else: 1381 offset_from_scale = 0 1382 size = SizeReference( 1383 tensor_id=TensorId(t_id), 1384 axis_id=AxisId(a_id), 1385 offset=int(offset_from_scale + 2 * shape.offset[i]), 1386 ) 1387 
else: 1388 size = shape[i] 1389 1390 if axis_type == "time": 1391 if tensor_type == "input": 1392 ret.append(TimeInputAxis(size=size, scale=scale)) 1393 else: 1394 assert not isinstance(size, ParameterizedSize) 1395 if halo is None: 1396 ret.append(TimeOutputAxis(size=size, scale=scale)) 1397 else: 1398 assert not isinstance(size, int) 1399 ret.append( 1400 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1401 ) 1402 1403 elif axis_type == "index": 1404 if tensor_type == "input": 1405 ret.append(IndexInputAxis(size=size)) 1406 else: 1407 if isinstance(size, ParameterizedSize): 1408 size = DataDependentSize(min=size.min) 1409 1410 ret.append(IndexOutputAxis(size=size)) 1411 elif axis_type == "channel": 1412 assert not isinstance(size, ParameterizedSize) 1413 if isinstance(size, SizeReference): 1414 warnings.warn( 1415 "Conversion of channel size from an implicit output shape may be" 1416 + " wrong" 1417 ) 1418 ret.append( 1419 ChannelAxis( 1420 channel_names=[ 1421 Identifier(f"channel{i}") for i in range(size.offset) 1422 ] 1423 ) 1424 ) 1425 else: 1426 ret.append( 1427 ChannelAxis( 1428 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1429 ) 1430 ) 1431 elif axis_type == "space": 1432 if tensor_type == "input": 1433 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1434 else: 1435 assert not isinstance(size, ParameterizedSize) 1436 if halo is None or halo[i] == 0: 1437 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1438 elif isinstance(size, int): 1439 raise NotImplementedError( 1440 f"output axis with halo and fixed size (here {size}) not allowed" 1441 ) 1442 else: 1443 ret.append( 1444 SpaceOutputAxisWithHalo( 1445 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1446 ) 1447 ) 1448 1449 return ret 1450 1451 1452_AXIS_TYPE_MAP = { 1453 "b": "batch", 1454 "t": "time", 1455 "i": "index", 1456 "c": "channel", 1457 "x": "space", 1458 "y": "space", 1459 "z": "space", 1460} 1461 1462_AXIS_ID_MAP = { 1463 "b": "batch", 1464 "t": "time", 1465 "i": "index", 1466 "c": "channel", 1467} 1468 1469 1470def _axes_letters_to_ids( 1471 axes: Optional[str], 1472) -> Optional[List[AxisId]]: 1473 if axes is None: 1474 return None 1475 return [AxisId(_AXIS_ID_MAP.get(a, a)) for a in map(str, axes)] 1476 1477 1478def _get_complement_v04_axis( 1479 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1480) -> Optional[AxisId]: 1481 if axes is None: 1482 return None 1483 1484 non_complement_axes = set(axes) | {"b"} 1485 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1486 if len(complement_axes) > 1: 1487 raise ValueError( 1488 f"Expected none or a single complement axis, but axes '{axes}' " 1489 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1490 ) 1491 1492 return None if not complement_axes else AxisId(complement_axes[0]) 1493 1494 1495def _convert_proc( 1496 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1497 tensor_axes: Sequence[str], 1498) -> Union[PreprocessingDescr, PostprocessingDescr]: 1499 if isinstance(p, _BinarizeDescr_v0_4): 1500 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1501 elif isinstance(p, _ClipDescr_v0_4): 1502 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1503 elif isinstance(p, _SigmoidDescr_v0_4): 1504 return SigmoidDescr() 1505 elif isinstance(p, _ScaleLinearDescr_v0_4): 1506 axes = _axes_letters_to_ids(p.kwargs.axes) 1507 if p.kwargs.axes is None: 1508 axis = None 1509 else: 1510 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1511 1512 if axis is None: 1513 assert not isinstance(p.kwargs.gain, list) 1514 assert not isinstance(p.kwargs.offset, list) 1515 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1516 else: 1517 kwargs = ScaleLinearAlongAxisKwargs( 1518 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1519 ) 1520 return ScaleLinearDescr(kwargs=kwargs) 1521 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1522 return ScaleMeanVarianceDescr( 1523 kwargs=ScaleMeanVarianceKwargs( 1524 axes=_axes_letters_to_ids(p.kwargs.axes), 1525 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1526 eps=p.kwargs.eps, 1527 ) 1528 ) 1529 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1530 if p.kwargs.mode == "fixed": 1531 mean = p.kwargs.mean 1532 std = p.kwargs.std 1533 assert mean is not None 1534 assert std is not None 1535 1536 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1537 1538 if axis is None: 1539 return FixedZeroMeanUnitVarianceDescr( 1540 kwargs=FixedZeroMeanUnitVarianceKwargs( 1541 mean=mean, std=std # pyright: ignore[reportArgumentType] 1542 ) 1543 ) 1544 else: 1545 if not isinstance(mean, list): 1546 mean = [float(mean)] 1547 if not isinstance(std, list): 1548 std = [float(std)] 1549 1550 return FixedZeroMeanUnitVarianceDescr( 1551 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1552 axis=axis, mean=mean, std=std 1553 ) 1554 ) 1555 1556 else: 1557 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1558 if p.kwargs.mode == "per_dataset": 1559 axes = [AxisId("batch")] + axes 1560 if not axes: 1561 axes = None 1562 return ZeroMeanUnitVarianceDescr( 1563 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1564 ) 1565 1566 elif isinstance(p, _ScaleRangeDescr_v0_4): 1567 return ScaleRangeDescr( 1568 kwargs=ScaleRangeKwargs( 1569 axes=_axes_letters_to_ids(p.kwargs.axes), 1570 min_percentile=p.kwargs.min_percentile, 1571 max_percentile=p.kwargs.max_percentile, 1572 eps=p.kwargs.eps, 1573 ) 1574 ) 1575 else: 1576 assert_never(p) 1577 1578 1579class _InputTensorConv( 1580 Converter[ 1581 _InputTensorDescr_v0_4, 1582 InputTensorDescr, 1583 ImportantFileSource, 1584 Optional[ImportantFileSource], 1585 Mapping[_TensorName_v0_4, Mapping[str, int]], 1586 ] 1587): 1588 def _convert( 1589 self, 1590 src: _InputTensorDescr_v0_4, 1591 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1592 test_tensor: ImportantFileSource, 1593 sample_tensor: Optional[ImportantFileSource], 1594 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1595 ) -> "InputTensorDescr | dict[str, Any]": 1596 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1597 src.axes, 1598 shape=src.shape, 1599 tensor_type="input", 1600 halo=None, 1601 size_refs=size_refs, 
1602 ) 1603 prep: List[PreprocessingDescr] = [] 1604 for p in src.preprocessing: 1605 cp = _convert_proc(p, src.axes) 1606 assert not isinstance(cp, ScaleMeanVarianceDescr) 1607 prep.append(cp) 1608 1609 return tgt( 1610 axes=axes, 1611 id=TensorId(str(src.name)), 1612 test_tensor=FileDescr(source=test_tensor), 1613 sample_tensor=( 1614 None if sample_tensor is None else FileDescr(source=sample_tensor) 1615 ), 1616 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 1617 preprocessing=prep, 1618 ) 1619 1620 1621_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 1622 1623 1624class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1625 id: TensorId = TensorId("output") 1626 """Output tensor id. 1627 No duplicates are allowed across all inputs and outputs.""" 1628 1629 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 1630 """Description of how this output should be postprocessed. 1631 1632 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1633 If not given this is added to cast to this tensor's `data.type`. 1634 """ 1635 1636 @model_validator(mode="after") 1637 def _validate_postprocessing_kwargs(self) -> Self: 1638 axes_ids = [a.id for a in self.axes] 1639 for p in self.postprocessing: 1640 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1641 if kwargs_axes is None: 1642 continue 1643 1644 if not isinstance(kwargs_axes, collections.abc.Sequence): 1645 raise ValueError( 1646 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1647 ) 1648 1649 if any(a not in axes_ids for a in kwargs_axes): 1650 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1651 1652 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1653 dtype = self.data.type 1654 else: 1655 dtype = self.data[0].type 1656 1657 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1658 if not self.postprocessing or not isinstance( 1659 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1660 ): 1661 self.postprocessing.append( 1662 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1663 ) 1664 return self 1665 1666 1667class _OutputTensorConv( 1668 Converter[ 1669 _OutputTensorDescr_v0_4, 1670 OutputTensorDescr, 1671 ImportantFileSource, 1672 Optional[ImportantFileSource], 1673 Mapping[_TensorName_v0_4, Mapping[str, int]], 1674 ] 1675): 1676 def _convert( 1677 self, 1678 src: _OutputTensorDescr_v0_4, 1679 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 1680 test_tensor: ImportantFileSource, 1681 sample_tensor: Optional[ImportantFileSource], 1682 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1683 ) -> "OutputTensorDescr | dict[str, Any]": 1684 # TODO: split convert_axes into convert_output_axes and convert_input_axes 1685 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1686 src.axes, 1687 shape=src.shape, 1688 tensor_type="output", 1689 halo=src.halo, 1690 size_refs=size_refs, 1691 ) 1692 data_descr: Dict[str, Any] = dict(type=src.data_type) 1693 if data_descr["type"] == "bool": 1694 data_descr["values"] = [False, True] 1695 1696 return tgt( 1697 axes=axes, 1698 id=TensorId(str(src.name)), 1699 test_tensor=FileDescr(source=test_tensor), 1700 sample_tensor=( 1701 None if sample_tensor is None else FileDescr(source=sample_tensor) 1702 ), 1703 data=data_descr, # pyright: ignore[reportArgumentType] 1704 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 1705 ) 1706 1707 1708_output_tensor_conv = 
_OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 1709 1710 1711TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 1712 1713 1714def validate_tensors( 1715 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 1716 tensor_origin: str, # for more precise error messages, e.g. 'test_tensor' 1717): 1718 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 1719 1720 def e_msg(d: TensorDescr): 1721 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 1722 1723 for descr, array in tensors.values(): 1724 try: 1725 axis_sizes = descr.get_axis_sizes_for_array(array) 1726 except ValueError as e: 1727 raise ValueError(f"{e_msg(descr)} {e}") 1728 else: 1729 all_tensor_axes[descr.id] = { 1730 a.id: (a, axis_sizes[a.id]) for a in descr.axes 1731 } 1732 1733 for descr, array in tensors.values(): 1734 if array.dtype.name != descr.dtype: 1735 raise ValueError( 1736 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 1737 + f" match described dtype '{descr.dtype}'" 1738 ) 1739 1740 for a in descr.axes: 1741 actual_size = all_tensor_axes[descr.id][a.id][1] 1742 if a.size is None: 1743 continue 1744 1745 if isinstance(a.size, int): 1746 if actual_size != a.size: 1747 raise ValueError( 1748 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 1749 + f"has incompatible size {actual_size}, expected {a.size}" 1750 ) 1751 elif isinstance(a.size, ParameterizedSize): 1752 _ = a.size.validate_size(actual_size) 1753 elif isinstance(a.size, DataDependentSize): 1754 _ = a.size.validate_size(actual_size) 1755 elif isinstance(a.size, SizeReference): 1756 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 1757 if ref_tensor_axes is None: 1758 raise ValueError( 1759 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 1760 + f" reference '{a.size.tensor_id}'" 1761 ) 1762 1763 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 1764 if ref_axis is None or ref_size is None: 1765 raise ValueError( 1766 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 1767 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 1768 ) 1769 1770 if a.unit != ref_axis.unit: 1771 raise ValueError( 1772 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 1773 + " axis and reference axis to have the same `unit`, but" 1774 + f" {a.unit}!={ref_axis.unit}" 1775 ) 1776 1777 if actual_size != ( 1778 expected_size := ( 1779 ref_size * ref_axis.scale / a.scale + a.size.offset 1780 ) 1781 ): 1782 raise ValueError( 1783 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 1784 + f" {actual_size} invalid for referenced size {ref_size};" 1785 + f" expected {expected_size}" 1786 ) 1787 else: 1788 assert_never(a.size) 1789 1790 1791class EnvironmentFileDescr(FileDescr): 1792 source: Annotated[ 1793 ImportantFileSource, 1794 WithSuffix((".yaml", ".yml"), case_sensitive=True), 1795 Field( 1796 examples=["environment.yaml"], 1797 ), 1798 ] 1799 """∈📦 Conda environment file. 
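# Illustrative sketch (hypothetical descriptor and file names): checking test arrays
# against their descriptions with `validate_tensors`; the second argument only labels
# error messages. `raw_descr`/`out_descr` are assumed to be existing tensor descriptions.
# >>> validate_tensors(
# ...     {
# ...         raw_descr.id: (raw_descr, np.load("test_input.npy")),
# ...         out_descr.id: (out_descr, np.load("test_output.npy")),
# ...     },
# ...     tensor_origin="test_tensor",
# ... )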
1800 Allows to specify custom dependencies, see conda docs: 1801 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 1802 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 1803 """ 1804 1805 1806class _ArchitectureCallableDescr(Node): 1807 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 1808 """Identifier of the callable that returns a torch.nn.Module instance.""" 1809 1810 kwargs: Dict[str, YamlValue] = Field(default_factory=dict) 1811 """key word arguments for the `callable`""" 1812 1813 1814class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 1815 pass 1816 1817 1818class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 1819 import_from: str 1820 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 1821 1822 1823ArchitectureDescr = Annotated[ 1824 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], 1825 Field(union_mode="left_to_right"), 1826] 1827 1828 1829class _ArchFileConv( 1830 Converter[ 1831 _CallableFromFile_v0_4, 1832 ArchitectureFromFileDescr, 1833 Optional[Sha256], 1834 Dict[str, Any], 1835 ] 1836): 1837 def _convert( 1838 self, 1839 src: _CallableFromFile_v0_4, 1840 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 1841 sha256: Optional[Sha256], 1842 kwargs: Dict[str, Any], 1843 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 1844 if src.startswith("http") and src.count(":") == 2: 1845 http, source, callable_ = src.split(":") 1846 source = ":".join((http, source)) 1847 elif not src.startswith("http") and src.count(":") == 1: 1848 source, callable_ = src.split(":") 1849 else: 1850 source = str(src) 1851 callable_ = str(src) 1852 return tgt( 1853 callable=Identifier(callable_), 1854 source=cast(ImportantFileSource, source), 1855 sha256=sha256, 1856 kwargs=kwargs, 1857 ) 1858 1859 1860_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 1861 1862 1863class _ArchLibConv( 1864 Converter[ 1865 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 1866 ] 1867): 1868 def _convert( 1869 self, 1870 src: _CallableFromDepencency_v0_4, 1871 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 1872 kwargs: Dict[str, Any], 1873 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 1874 *mods, callable_ = src.split(".") 1875 import_from = ".".join(mods) 1876 return tgt( 1877 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 1878 ) 1879 1880 1881_arch_lib_conv = _ArchLibConv( 1882 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 1883) 1884 1885 1886class WeightsEntryDescrBase(FileDescr): 1887 type: ClassVar[WeightsFormat] 1888 weights_format_name: ClassVar[str] # human readable 1889 1890 source: ImportantFileSource 1891 """∈📦 The weights file.""" 1892 1893 authors: Optional[List[Author]] = None 1894 """Authors 1895 Either the person(s) that have trained this model resulting in the original weights file. 1896 (If this is the initial weights entry, i.e. it does not have a `parent`) 1897 Or the person(s) who have converted the weights to this weights format. 1898 (If this is a child weight, i.e. 
it has a `parent` field) 1899 """ 1900 1901 parent: Annotated[ 1902 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 1903 ] = None 1904 """The source weights these weights were converted from. 1905 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 1906 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 1907 All weight entries except one (the initial set of weights resulting from training the model), 1908 need to have this field.""" 1909 1910 @model_validator(mode="after") 1911 def check_parent_is_not_self(self) -> Self: 1912 if self.type == self.parent: 1913 raise ValueError("Weights entry can't be it's own parent.") 1914 1915 return self 1916 1917 1918class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 1919 type = "keras_hdf5" 1920 weights_format_name: ClassVar[str] = "Keras HDF5" 1921 tensorflow_version: Version 1922 """TensorFlow version used to create these weights.""" 1923 1924 1925class OnnxWeightsDescr(WeightsEntryDescrBase): 1926 type = "onnx" 1927 weights_format_name: ClassVar[str] = "ONNX" 1928 opset_version: Annotated[int, Ge(7)] 1929 """ONNX opset version""" 1930 1931 1932class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 1933 type = "pytorch_state_dict" 1934 weights_format_name: ClassVar[str] = "Pytorch State Dict" 1935 architecture: ArchitectureDescr 1936 pytorch_version: Version 1937 """Version of the PyTorch library used. 1938 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 1939 """ 1940 dependencies: Optional[EnvironmentFileDescr] = None 1941 """Custom depencies beyond pytorch. 1942 The conda environment file should include pytorch and any version pinning has to be compatible with 1943 `pytorch_version`. 1944 """ 1945 1946 1947class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 1948 type = "tensorflow_js" 1949 weights_format_name: ClassVar[str] = "Tensorflow.js" 1950 tensorflow_version: Version 1951 """Version of the TensorFlow library used.""" 1952 1953 source: ImportantFileSource 1954 """∈📦 The multi-file weights. 1955 All required files/folders should be a zip archive.""" 1956 1957 1958class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 1959 type = "tensorflow_saved_model_bundle" 1960 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 1961 tensorflow_version: Version 1962 """Version of the TensorFlow library used.""" 1963 1964 dependencies: Optional[EnvironmentFileDescr] = None 1965 """Custom dependencies beyond tensorflow. 1966 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 1967 1968 source: ImportantFileSource 1969 """∈📦 The multi-file weights. 
1970 All required files/folders should be a zip archive.""" 1971 1972 1973class TorchscriptWeightsDescr(WeightsEntryDescrBase): 1974 type = "torchscript" 1975 weights_format_name: ClassVar[str] = "TorchScript" 1976 pytorch_version: Version 1977 """Version of the PyTorch library used.""" 1978 1979 1980class WeightsDescr(Node): 1981 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 1982 onnx: Optional[OnnxWeightsDescr] = None 1983 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 1984 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 1985 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 1986 None 1987 ) 1988 torchscript: Optional[TorchscriptWeightsDescr] = None 1989 1990 @model_validator(mode="after") 1991 def check_entries(self) -> Self: 1992 entries = {wtype for wtype, entry in self if entry is not None} 1993 1994 if not entries: 1995 raise ValueError("Missing weights entry") 1996 1997 entries_wo_parent = { 1998 wtype 1999 for wtype, entry in self 2000 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2001 } 2002 if len(entries_wo_parent) != 1: 2003 issue_warning( 2004 "Exactly one weights entry may not specify the `parent` field (got" 2005 + " {value}). That entry is considered the original set of model weights." 2006 + " Other weight formats are created through conversion of the orignal or" 2007 + " already converted weights. They have to reference the weights format" 2008 + " they were converted from as their `parent`.", 2009 value=len(entries_wo_parent), 2010 field="weights", 2011 ) 2012 2013 for wtype, entry in self: 2014 if entry is None: 2015 continue 2016 2017 assert hasattr(entry, "type") 2018 assert hasattr(entry, "parent") 2019 assert wtype == entry.type 2020 if ( 2021 entry.parent is not None and entry.parent not in entries 2022 ): # self reference checked for `parent` field 2023 raise ValueError( 2024 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2025 + f" formats: {entries}" 2026 ) 2027 2028 return self 2029 2030 2031class ModelId(ResourceId): 2032 pass 2033 2034 2035class LinkedModel(LinkedResourceNode): 2036 """Reference to a bioimage.io model.""" 2037 2038 id: ModelId 2039 """A valid model `id` from the bioimage.io collection.""" 2040 2041 2042class _DataDepSize(NamedTuple): 2043 min: int 2044 max: Optional[int] 2045 2046 2047class _AxisSizes(NamedTuple): 2048 """the lenghts of all axes of model inputs and outputs""" 2049 2050 inputs: Dict[Tuple[TensorId, AxisId], int] 2051 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2052 2053 2054class _TensorSizes(NamedTuple): 2055 """_AxisSizes as nested dicts""" 2056 2057 inputs: Dict[TensorId, Dict[AxisId, int]] 2058 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2059 2060 2061class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"): 2062 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2063 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2064 """ 2065 2066 format_version: Literal["0.5.3"] = "0.5.3" 2067 """Version of the bioimage.io model description specification used. 2068 When creating a new model always use the latest micro/patch version described here. 2069 The `format_version` is important for any consumer software to understand how to parse the fields. 
2070 """ 2071 2072 type: Literal["model"] = "model" 2073 """Specialized resource type 'model'""" 2074 2075 id: Optional[ModelId] = None 2076 """bioimage.io-wide unique resource identifier 2077 assigned by bioimage.io; version **un**specific.""" 2078 2079 authors: NotEmpty[List[Author]] 2080 """The authors are the creators of the model RDF and the primary points of contact.""" 2081 2082 documentation: Annotated[ 2083 DocumentationSource, 2084 Field( 2085 examples=[ 2086 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 2087 "README.md", 2088 ], 2089 ), 2090 ] 2091 """∈📦 URL or relative path to a markdown file with additional documentation. 2092 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2093 The documentation should include a '#[#] Validation' (sub)section 2094 with details on how to quantitatively validate the model on unseen data.""" 2095 2096 @field_validator("documentation", mode="after") 2097 @classmethod 2098 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 2099 if not validation_context_var.get().perform_io_checks: 2100 return value 2101 2102 doc_path = download(value).path 2103 doc_content = doc_path.read_text(encoding="utf-8") 2104 assert isinstance(doc_content, str) 2105 if not re.match("#.*[vV]alidation", doc_content): 2106 issue_warning( 2107 "No '# Validation' (sub)section found in {value}.", 2108 value=value, 2109 field="documentation", 2110 ) 2111 2112 return value 2113 2114 inputs: NotEmpty[Sequence[InputTensorDescr]] 2115 """Describes the input tensors expected by this model.""" 2116 2117 @field_validator("inputs", mode="after") 2118 @classmethod 2119 def _validate_input_axes( 2120 cls, inputs: Sequence[InputTensorDescr] 2121 ) -> Sequence[InputTensorDescr]: 2122 input_size_refs = cls._get_axes_with_independent_size(inputs) 2123 2124 for i, ipt in enumerate(inputs): 2125 valid_independent_refs: Dict[ 2126 Tuple[TensorId, AxisId], 2127 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2128 ] = { 2129 **{ 2130 (ipt.id, a.id): (ipt, a, a.size) 2131 for a in ipt.axes 2132 if not isinstance(a, BatchAxis) 2133 and isinstance(a.size, (int, ParameterizedSize)) 2134 }, 2135 **input_size_refs, 2136 } 2137 for a, ax in enumerate(ipt.axes): 2138 cls._validate_axis( 2139 "inputs", 2140 i=i, 2141 tensor_id=ipt.id, 2142 a=a, 2143 axis=ax, 2144 valid_independent_refs=valid_independent_refs, 2145 ) 2146 return inputs 2147 2148 @staticmethod 2149 def _validate_axis( 2150 field_name: str, 2151 i: int, 2152 tensor_id: TensorId, 2153 a: int, 2154 axis: AnyAxis, 2155 valid_independent_refs: Dict[ 2156 Tuple[TensorId, AxisId], 2157 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2158 ], 2159 ): 2160 if isinstance(axis, BatchAxis) or isinstance( 2161 axis.size, (int, ParameterizedSize, DataDependentSize) 2162 ): 2163 return 2164 elif not isinstance(axis.size, SizeReference): 2165 assert_never(axis.size) 2166 2167 # validate axis.size SizeReference 2168 ref = (axis.size.tensor_id, axis.size.axis_id) 2169 if ref not in valid_independent_refs: 2170 raise ValueError( 2171 "Invalid tensor axis reference at" 2172 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 
2173 ) 2174 if ref == (tensor_id, axis.id): 2175 raise ValueError( 2176 "Self-referencing not allowed for" 2177 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2178 ) 2179 if axis.type == "channel": 2180 if valid_independent_refs[ref][1].type != "channel": 2181 raise ValueError( 2182 "A channel axis' size may only reference another fixed size" 2183 + " channel axis." 2184 ) 2185 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2186 ref_size = valid_independent_refs[ref][2] 2187 assert isinstance(ref_size, int), ( 2188 "channel axis ref (another channel axis) has to specify fixed" 2189 + " size" 2190 ) 2191 generated_channel_names = [ 2192 Identifier(axis.channel_names.format(i=i)) 2193 for i in range(1, ref_size + 1) 2194 ] 2195 axis.channel_names = generated_channel_names 2196 2197 if (ax_unit := getattr(axis, "unit", None)) != ( 2198 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2199 ): 2200 raise ValueError( 2201 "The units of an axis and its reference axis need to match, but" 2202 + f" '{ax_unit}' != '{ref_unit}'." 2203 ) 2204 ref_axis = valid_independent_refs[ref][1] 2205 if isinstance(ref_axis, BatchAxis): 2206 raise ValueError( 2207 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2208 + " (a batch axis is not allowed as reference)." 2209 ) 2210 2211 if isinstance(axis, WithHalo): 2212 min_size = axis.size.get_size(axis, ref_axis, n=0) 2213 if (min_size - 2 * axis.halo) < 1: 2214 raise ValueError( 2215 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2216 + f" {axis.halo}." 2217 ) 2218 2219 input_halo = axis.halo * axis.scale / ref_axis.scale 2220 if input_halo != int(input_halo) or input_halo % 2 == 1: 2221 raise ValueError( 2222 f"input_halo {input_halo} (output_halo {axis.halo} *" 2223 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2224 + f" is not an even integer for {tensor_id}.{axis.id}." 2225 ) 2226 2227 @model_validator(mode="after") 2228 def _validate_test_tensors(self) -> Self: 2229 if not validation_context_var.get().perform_io_checks: 2230 return self 2231 2232 test_arrays = [ 2233 load_array(descr.test_tensor.download().path) 2234 for descr in chain(self.inputs, self.outputs) 2235 ] 2236 tensors = { 2237 descr.id: (descr, array) 2238 for descr, array in zip(chain(self.inputs, self.outputs), test_arrays) 2239 } 2240 validate_tensors(tensors, tensor_origin="test_tensor") 2241 return self 2242 2243 @model_validator(mode="after") 2244 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2245 ipt_refs = {t.id for t in self.inputs} 2246 out_refs = {t.id for t in self.outputs} 2247 for ipt in self.inputs: 2248 for p in ipt.preprocessing: 2249 ref = p.kwargs.get("reference_tensor") 2250 if ref is None: 2251 continue 2252 if ref not in ipt_refs: 2253 raise ValueError( 2254 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2255 + f" references are: {ipt_refs}." 2256 ) 2257 2258 for out in self.outputs: 2259 for p in out.postprocessing: 2260 ref = p.kwargs.get("reference_tensor") 2261 if ref is None: 2262 continue 2263 2264 if ref not in ipt_refs and ref not in out_refs: 2265 raise ValueError( 2266 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2267 + f" are: {ipt_refs | out_refs}." 
2268 ) 2269 2270 return self 2271 2272 # TODO: use validate funcs in validate_test_tensors 2273 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2274 2275 name: Annotated[ 2276 Annotated[ 2277 str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()") 2278 ], 2279 MinLen(5), 2280 MaxLen(128), 2281 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2282 ] 2283 """A human-readable name of this model. 2284 It should be no longer than 64 characters 2285 and may only contain letter, number, underscore, minus, parentheses and spaces. 2286 We recommend to chose a name that refers to the model's task and image modality. 2287 """ 2288 2289 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2290 """Describes the output tensors.""" 2291 2292 @field_validator("outputs", mode="after") 2293 @classmethod 2294 def _validate_tensor_ids( 2295 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2296 ) -> Sequence[OutputTensorDescr]: 2297 tensor_ids = [ 2298 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2299 ] 2300 duplicate_tensor_ids: List[str] = [] 2301 seen: Set[str] = set() 2302 for t in tensor_ids: 2303 if t in seen: 2304 duplicate_tensor_ids.append(t) 2305 2306 seen.add(t) 2307 2308 if duplicate_tensor_ids: 2309 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2310 2311 return outputs 2312 2313 @staticmethod 2314 def _get_axes_with_parameterized_size( 2315 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2316 ): 2317 return { 2318 f"{t.id}.{a.id}": (t, a, a.size) 2319 for t in io 2320 for a in t.axes 2321 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2322 } 2323 2324 @staticmethod 2325 def _get_axes_with_independent_size( 2326 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2327 ): 2328 return { 2329 (t.id, a.id): (t, a, a.size) 2330 for t in io 2331 for a in t.axes 2332 if not isinstance(a, BatchAxis) 2333 and isinstance(a.size, (int, ParameterizedSize)) 2334 } 2335 2336 @field_validator("outputs", mode="after") 2337 @classmethod 2338 def _validate_output_axes( 2339 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2340 ) -> List[OutputTensorDescr]: 2341 input_size_refs = cls._get_axes_with_independent_size( 2342 info.data.get("inputs", []) 2343 ) 2344 output_size_refs = cls._get_axes_with_independent_size(outputs) 2345 2346 for i, out in enumerate(outputs): 2347 valid_independent_refs: Dict[ 2348 Tuple[TensorId, AxisId], 2349 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2350 ] = { 2351 **{ 2352 (out.id, a.id): (out, a, a.size) 2353 for a in out.axes 2354 if not isinstance(a, BatchAxis) 2355 and isinstance(a.size, (int, ParameterizedSize)) 2356 }, 2357 **input_size_refs, 2358 **output_size_refs, 2359 } 2360 for a, ax in enumerate(out.axes): 2361 cls._validate_axis( 2362 "outputs", 2363 i, 2364 out.id, 2365 a, 2366 ax, 2367 valid_independent_refs=valid_independent_refs, 2368 ) 2369 2370 return outputs 2371 2372 packaged_by: List[Author] = Field(default_factory=list) 2373 """The persons that have packaged and uploaded this model. 2374 Only required if those persons differ from the `authors`.""" 2375 2376 parent: Optional[LinkedModel] = None 2377 """The model from which this model is derived, e.g. 
by fine-tuning the weights.""" 2378 2379 # todo: add parent self check once we have `id` 2380 # @model_validator(mode="after") 2381 # def validate_parent_is_not_self(self) -> Self: 2382 # if self.parent is not None and self.parent == self.id: 2383 # raise ValueError("The model may not reference itself as parent model") 2384 2385 # return self 2386 2387 run_mode: Annotated[ 2388 Optional[RunMode], 2389 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2390 ] = None 2391 """Custom run mode for this model: for more complex prediction procedures like test time 2392 data augmentation that currently cannot be expressed in the specification. 2393 No standard run modes are defined yet.""" 2394 2395 timestamp: Datetime = Datetime(datetime.now()) 2396 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2397 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2398 (In Python a datetime object is valid, too).""" 2399 2400 training_data: Annotated[ 2401 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2402 Field(union_mode="left_to_right"), 2403 ] = None 2404 """The dataset used to train this model""" 2405 2406 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2407 """The weights for this model. 2408 Weights can be given for different formats, but should otherwise be equivalent. 2409 The available weight formats determine which consumers can use this model.""" 2410 2411 @model_validator(mode="after") 2412 def _add_default_cover(self) -> Self: 2413 if not validation_context_var.get().perform_io_checks or self.covers: 2414 return self 2415 2416 try: 2417 generated_covers = generate_covers( 2418 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 2419 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 2420 ) 2421 except Exception as e: 2422 issue_warning( 2423 "Failed to generate cover image(s): {e}", 2424 value=self.covers, 2425 msg_context=dict(e=e), 2426 field="covers", 2427 ) 2428 else: 2429 self.covers.extend(generated_covers) 2430 2431 return self 2432 2433 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2434 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 2435 assert all(isinstance(d, np.ndarray) for d in data) 2436 return data 2437 2438 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2439 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 2440 assert all(isinstance(d, np.ndarray) for d in data) 2441 return data 2442 2443 @staticmethod 2444 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2445 batch_size = 1 2446 tensor_with_batchsize: Optional[TensorId] = None 2447 for tid in tensor_sizes: 2448 for aid, s in tensor_sizes[tid].items(): 2449 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2450 continue 2451 2452 if batch_size != 1: 2453 assert tensor_with_batchsize is not None 2454 raise ValueError( 2455 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2456 ) 2457 2458 batch_size = s 2459 tensor_with_batchsize = tid 2460 2461 return batch_size 2462 2463 def get_output_tensor_sizes( 2464 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2465 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2466 """Returns the tensor output sizes for given **input_sizes**. 2467 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2468 Otherwise it might be larger than the actual (valid) output""" 2469 batch_size = self.get_batch_size(input_sizes) 2470 ns = self.get_ns(input_sizes) 2471 2472 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2473 return tensor_sizes.outputs 2474 2475 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2476 """get parameter `n` for each parameterized axis 2477 such that the valid input size is >= the given input size""" 2478 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2479 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2480 for tid in input_sizes: 2481 for aid, s in input_sizes[tid].items(): 2482 size_descr = axes[tid][aid].size 2483 if isinstance(size_descr, ParameterizedSize): 2484 ret[(tid, aid)] = size_descr.get_n(s) 2485 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2486 pass 2487 else: 2488 assert_never(size_descr) 2489 2490 return ret 2491 2492 def get_tensor_sizes( 2493 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2494 ) -> _TensorSizes: 2495 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2496 return _TensorSizes( 2497 { 2498 t: { 2499 aa: axis_sizes.inputs[(tt, aa)] 2500 for tt, aa in axis_sizes.inputs 2501 if tt == t 2502 } 2503 for t in {tt for tt, _ in axis_sizes.inputs} 2504 }, 2505 { 2506 t: { 2507 aa: axis_sizes.outputs[(tt, aa)] 2508 for tt, aa in axis_sizes.outputs 2509 if tt == t 2510 } 2511 for t in {tt for tt, _ in axis_sizes.outputs} 2512 }, 2513 ) 2514 2515 def get_axis_sizes( 2516 self, 2517 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 2518 batch_size: Optional[int] = None, 2519 *, 2520 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 2521 ) -> _AxisSizes: 2522 """Determine input and output block shape for scale factors **ns** 2523 of parameterized input sizes. 2524 2525 Args: 2526 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 2527 that is parameterized as `size = min + n * step`. 2528 batch_size: The desired size of the batch dimension. 2529 If given **batch_size** overwrites any batch size present in 2530 **max_input_shape**. Default 1. 2531 max_input_shape: Limits the derived block shapes. 2532 Each axis for which the input size, parameterized by `n`, is larger 2533 than **max_input_shape** is set to the minimal value `n_min` for which 2534 this is still true. 2535 Use this for small input samples or large values of **ns**. 2536 Or simply whenever you know the full input shape. 2537 2538 Returns: 2539 Resolved axis sizes for model inputs and outputs. 
2540 """ 2541 max_input_shape = max_input_shape or {} 2542 if batch_size is None: 2543 for (_t_id, a_id), s in max_input_shape.items(): 2544 if a_id == BATCH_AXIS_ID: 2545 batch_size = s 2546 break 2547 else: 2548 batch_size = 1 2549 2550 all_axes = { 2551 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 2552 } 2553 2554 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 2555 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 2556 2557 def get_axis_size(a: Union[InputAxis, OutputAxis]): 2558 if isinstance(a, BatchAxis): 2559 if (t_descr.id, a.id) in ns: 2560 logger.warning( 2561 "Ignoring unexpected size increment factor (n) for batch axis" 2562 + " of tensor '{}'.", 2563 t_descr.id, 2564 ) 2565 return batch_size 2566 elif isinstance(a.size, int): 2567 if (t_descr.id, a.id) in ns: 2568 logger.warning( 2569 "Ignoring unexpected size increment factor (n) for fixed size" 2570 + " axis '{}' of tensor '{}'.", 2571 a.id, 2572 t_descr.id, 2573 ) 2574 return a.size 2575 elif isinstance(a.size, ParameterizedSize): 2576 if (t_descr.id, a.id) not in ns: 2577 raise ValueError( 2578 "Size increment factor (n) missing for parametrized axis" 2579 + f" '{a.id}' of tensor '{t_descr.id}'." 2580 ) 2581 n = ns[(t_descr.id, a.id)] 2582 s_max = max_input_shape.get((t_descr.id, a.id)) 2583 if s_max is not None: 2584 n = min(n, a.size.get_n(s_max)) 2585 2586 return a.size.get_size(n) 2587 2588 elif isinstance(a.size, SizeReference): 2589 if (t_descr.id, a.id) in ns: 2590 logger.warning( 2591 "Ignoring unexpected size increment factor (n) for axis '{}'" 2592 + " of tensor '{}' with size reference.", 2593 a.id, 2594 t_descr.id, 2595 ) 2596 assert not isinstance(a, BatchAxis) 2597 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 2598 assert not isinstance(ref_axis, BatchAxis) 2599 ref_key = (a.size.tensor_id, a.size.axis_id) 2600 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 2601 assert ref_size is not None, ref_key 2602 assert not isinstance(ref_size, _DataDepSize), ref_key 2603 return a.size.get_size( 2604 axis=a, 2605 ref_axis=ref_axis, 2606 ref_size=ref_size, 2607 ) 2608 elif isinstance(a.size, DataDependentSize): 2609 if (t_descr.id, a.id) in ns: 2610 logger.warning( 2611 "Ignoring unexpected increment factor (n) for data dependent" 2612 + " size axis '{}' of tensor '{}'.", 2613 a.id, 2614 t_descr.id, 2615 ) 2616 return _DataDepSize(a.size.min, a.size.max) 2617 else: 2618 assert_never(a.size) 2619 2620 # first resolve all , but the `SizeReference` input sizes 2621 for t_descr in self.inputs: 2622 for a in t_descr.axes: 2623 if not isinstance(a.size, SizeReference): 2624 s = get_axis_size(a) 2625 assert not isinstance(s, _DataDepSize) 2626 inputs[t_descr.id, a.id] = s 2627 2628 # resolve all other input axis sizes 2629 for t_descr in self.inputs: 2630 for a in t_descr.axes: 2631 if isinstance(a.size, SizeReference): 2632 s = get_axis_size(a) 2633 assert not isinstance(s, _DataDepSize) 2634 inputs[t_descr.id, a.id] = s 2635 2636 # resolve all output axis sizes 2637 for t_descr in self.outputs: 2638 for a in t_descr.axes: 2639 assert not isinstance(a.size, ParameterizedSize) 2640 s = get_axis_size(a) 2641 outputs[t_descr.id, a.id] = s 2642 2643 return _AxisSizes(inputs=inputs, outputs=outputs) 2644 2645 @model_validator(mode="before") 2646 @classmethod 2647 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 2648 if ( 2649 data.get("type") == "model" 2650 and isinstance(fv := data.get("format_version"), str) 2651 and fv.count(".") == 2 2652 ): 2653 
fv_parts = fv.split(".") 2654 if any(not p.isdigit() for p in fv_parts): 2655 return data 2656 2657 fv_tuple = tuple(map(int, fv_parts)) 2658 2659 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 2660 if fv_tuple[:2] in ((0, 3), (0, 4)): 2661 m04 = _ModelDescr_v0_4.load(data) 2662 if not isinstance(m04, InvalidDescr): 2663 return _model_conv.convert_as_dict(m04) 2664 elif fv_tuple[:2] == (0, 5): 2665 # bump patch version 2666 data["format_version"] = cls.implemented_format_version 2667 2668 return data 2669 2670 2671class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 2672 def _convert( 2673 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 2674 ) -> "ModelDescr | dict[str, Any]": 2675 name = "".join( 2676 c if c in string.ascii_letters + string.digits + "_- ()" else " " 2677 for c in src.name 2678 ) 2679 2680 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 2681 conv = ( 2682 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 2683 ) 2684 return None if auths is None else [conv(a) for a in auths] 2685 2686 if TYPE_CHECKING: 2687 arch_file_conv = _arch_file_conv.convert 2688 arch_lib_conv = _arch_lib_conv.convert 2689 else: 2690 arch_file_conv = _arch_file_conv.convert_as_dict 2691 arch_lib_conv = _arch_lib_conv.convert_as_dict 2692 2693 input_size_refs = { 2694 ipt.name: { 2695 a: s 2696 for a, s in zip( 2697 ipt.axes, 2698 ( 2699 ipt.shape.min 2700 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 2701 else ipt.shape 2702 ), 2703 ) 2704 } 2705 for ipt in src.inputs 2706 if ipt.shape 2707 } 2708 output_size_refs = { 2709 **{ 2710 out.name: {a: s for a, s in zip(out.axes, out.shape)} 2711 for out in src.outputs 2712 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 2713 }, 2714 **input_size_refs, 2715 } 2716 2717 return tgt( 2718 attachments=( 2719 [] 2720 if src.attachments is None 2721 else [FileDescr(source=f) for f in src.attachments.files] 2722 ), 2723 authors=[ 2724 _author_conv.convert_as_dict(a) for a in src.authors 2725 ], # pyright: ignore[reportArgumentType] 2726 cite=[ 2727 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 2728 ], # pyright: ignore[reportArgumentType] 2729 config=src.config, 2730 covers=src.covers, 2731 description=src.description, 2732 documentation=src.documentation, 2733 format_version="0.5.3", 2734 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 2735 icon=src.icon, 2736 id=None if src.id is None else ModelId(src.id), 2737 id_emoji=src.id_emoji, 2738 license=src.license, # type: ignore 2739 links=src.links, 2740 maintainers=[ 2741 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 2742 ], # pyright: ignore[reportArgumentType] 2743 name=name, 2744 tags=src.tags, 2745 type=src.type, 2746 uploader=src.uploader, 2747 version=src.version, 2748 inputs=[ # pyright: ignore[reportArgumentType] 2749 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 2750 for ipt, tt, st, in zip( 2751 src.inputs, 2752 src.test_inputs, 2753 src.sample_inputs or [None] * len(src.test_inputs), 2754 ) 2755 ], 2756 outputs=[ # pyright: ignore[reportArgumentType] 2757 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 2758 for out, tt, st, in zip( 2759 src.outputs, 2760 src.test_outputs, 2761 src.sample_outputs or [None] * len(src.test_outputs), 2762 ) 2763 ], 2764 parent=( 2765 None 2766 if src.parent is None 2767 else LinkedModel( 2768 id=ModelId( 2769 str(src.parent.id) 2770 + ( 2771 "" 2772 if src.parent.version_number is None 2773 else 
f"/{src.parent.version_number}" 2774 ) 2775 ) 2776 ) 2777 ), 2778 training_data=( 2779 None 2780 if src.training_data is None 2781 else ( 2782 LinkedDataset( 2783 id=DatasetId( 2784 str(src.training_data.id) 2785 + ( 2786 "" 2787 if src.training_data.version_number is None 2788 else f"/{src.training_data.version_number}" 2789 ) 2790 ) 2791 ) 2792 if isinstance(src.training_data, LinkedDataset02) 2793 else src.training_data 2794 ) 2795 ), 2796 packaged_by=[ 2797 _author_conv.convert_as_dict(a) for a in src.packaged_by 2798 ], # pyright: ignore[reportArgumentType] 2799 run_mode=src.run_mode, 2800 timestamp=src.timestamp, 2801 weights=(WeightsDescr if TYPE_CHECKING else dict)( 2802 keras_hdf5=(w := src.weights.keras_hdf5) 2803 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 2804 authors=conv_authors(w.authors), 2805 source=w.source, 2806 tensorflow_version=w.tensorflow_version or Version("1.15"), 2807 parent=w.parent, 2808 ), 2809 onnx=(w := src.weights.onnx) 2810 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 2811 source=w.source, 2812 authors=conv_authors(w.authors), 2813 parent=w.parent, 2814 opset_version=w.opset_version or 15, 2815 ), 2816 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 2817 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 2818 source=w.source, 2819 authors=conv_authors(w.authors), 2820 parent=w.parent, 2821 architecture=( 2822 arch_file_conv( 2823 w.architecture, 2824 w.architecture_sha256, 2825 w.kwargs, 2826 ) 2827 if isinstance(w.architecture, _CallableFromFile_v0_4) 2828 else arch_lib_conv(w.architecture, w.kwargs) 2829 ), 2830 pytorch_version=w.pytorch_version or Version("1.10"), 2831 dependencies=( 2832 None 2833 if w.dependencies is None 2834 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 2835 source=cast( 2836 ImportantFileSource, 2837 str(deps := w.dependencies)[ 2838 ( 2839 len("conda:") 2840 if str(deps).startswith("conda:") 2841 else 0 2842 ) : 2843 ], 2844 ) 2845 ) 2846 ), 2847 ), 2848 tensorflow_js=(w := src.weights.tensorflow_js) 2849 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 2850 source=w.source, 2851 authors=conv_authors(w.authors), 2852 parent=w.parent, 2853 tensorflow_version=w.tensorflow_version or Version("1.15"), 2854 ), 2855 tensorflow_saved_model_bundle=( 2856 w := src.weights.tensorflow_saved_model_bundle 2857 ) 2858 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 2859 authors=conv_authors(w.authors), 2860 parent=w.parent, 2861 source=w.source, 2862 tensorflow_version=w.tensorflow_version or Version("1.15"), 2863 dependencies=( 2864 None 2865 if w.dependencies is None 2866 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 2867 source=cast( 2868 ImportantFileSource, 2869 ( 2870 str(w.dependencies)[len("conda:") :] 2871 if str(w.dependencies).startswith("conda:") 2872 else str(w.dependencies) 2873 ), 2874 ) 2875 ) 2876 ), 2877 ), 2878 torchscript=(w := src.weights.torchscript) 2879 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 2880 source=w.source, 2881 authors=conv_authors(w.authors), 2882 parent=w.parent, 2883 pytorch_version=w.pytorch_version or Version("1.10"), 2884 ), 2885 ), 2886 ) 2887 2888 2889_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 2890 2891 2892# create better cover images for 3d data and non-image outputs 2893def generate_covers( 2894 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 2895 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 2896) -> List[Path]: 2897 def squeeze( 2898 data: NDArray[Any], 
axes: Sequence[AnyAxis] 2899 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 2900 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 2901 if data.ndim != len(axes): 2902 raise ValueError( 2903 f"tensor shape {data.shape} does not match described axes" 2904 + f" {[a.id for a in axes]}" 2905 ) 2906 2907 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 2908 return data.squeeze(), axes 2909 2910 def normalize( 2911 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 2912 ) -> NDArray[np.float32]: 2913 data = data.astype("float32") 2914 data -= data.min(axis=axis, keepdims=True) 2915 data /= data.max(axis=axis, keepdims=True) + eps 2916 return data 2917 2918 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 2919 original_shape = data.shape 2920 data, axes = squeeze(data, axes) 2921 2922 # take slice fom any batch or index axis if needed 2923 # and convert the first channel axis and take a slice from any additional channel axes 2924 slices: Tuple[slice, ...] = () 2925 ndim = data.ndim 2926 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 2927 has_c_axis = False 2928 for i, a in enumerate(axes): 2929 s = data.shape[i] 2930 assert s > 1 2931 if ( 2932 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 2933 and ndim > ndim_need 2934 ): 2935 data = data[slices + (slice(s // 2 - 1, s // 2),)] 2936 ndim -= 1 2937 elif isinstance(a, ChannelAxis): 2938 if has_c_axis: 2939 # second channel axis 2940 data = data[slices + (slice(0, 1),)] 2941 ndim -= 1 2942 else: 2943 has_c_axis = True 2944 if s == 2: 2945 # visualize two channels with cyan and magenta 2946 data = np.concatenate( 2947 [ 2948 data[slices + (slice(1, 2),)], 2949 data[slices + (slice(0, 1),)], 2950 ( 2951 data[slices + (slice(0, 1),)] 2952 + data[slices + (slice(1, 2),)] 2953 ) 2954 / 2, # TODO: take maximum instead? 2955 ], 2956 axis=i, 2957 ) 2958 elif data.shape[i] == 3: 2959 pass # visualize 3 channels as RGB 2960 else: 2961 # visualize first 3 channels as RGB 2962 data = data[slices + (slice(3),)] 2963 2964 assert data.shape[i] == 3 2965 2966 slices += (slice(None),) 2967 2968 data, axes = squeeze(data, axes) 2969 assert len(axes) == ndim 2970 # take slice from z axis if needed 2971 slices = () 2972 if ndim > ndim_need: 2973 for i, a in enumerate(axes): 2974 s = data.shape[i] 2975 if a.id == AxisId("z"): 2976 data = data[slices + (slice(s // 2 - 1, s // 2),)] 2977 data, axes = squeeze(data, axes) 2978 ndim -= 1 2979 break 2980 2981 slices += (slice(None),) 2982 2983 # take slice from any space or time axis 2984 slices = () 2985 2986 for i, a in enumerate(axes): 2987 if ndim <= ndim_need: 2988 break 2989 2990 s = data.shape[i] 2991 assert s > 1 2992 if isinstance( 2993 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 2994 ): 2995 data = data[slices + (slice(s // 2 - 1, s // 2),)] 2996 ndim -= 1 2997 2998 slices += (slice(None),) 2999 3000 del slices 3001 data, axes = squeeze(data, axes) 3002 assert len(axes) == ndim 3003 3004 if (has_c_axis and ndim != 3) or ndim != 2: 3005 raise ValueError( 3006 f"Failed to construct cover image from shape {original_shape}" 3007 ) 3008 3009 if not has_c_axis: 3010 assert ndim == 2 3011 data = np.repeat(data[:, :, None], 3, axis=2) 3012 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3013 ndim += 1 3014 3015 assert ndim == 3 3016 3017 # transpose axis order such that longest axis comes first... 
3018 axis_order = list(np.argsort(list(data.shape))) 3019 axis_order.reverse() 3020 # ... and channel axis is last 3021 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3022 axis_order.append(axis_order.pop(c)) 3023 axes = [axes[ao] for ao in axis_order] 3024 data = data.transpose(axis_order) 3025 3026 # h, w = data.shape[:2] 3027 # if h / w in (1.0 or 2.0): 3028 # pass 3029 # elif h / w < 2: 3030 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3031 3032 norm_along = ( 3033 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3034 ) 3035 # normalize the data and map to 8 bit 3036 data = normalize(data, norm_along) 3037 data = (data * 255).astype("uint8") 3038 3039 return data 3040 3041 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3042 assert im0.dtype == im1.dtype == np.uint8 3043 assert im0.shape == im1.shape 3044 assert im0.ndim == 3 3045 N, M, C = im0.shape 3046 assert C == 3 3047 out = np.ones((N, M, C), dtype="uint8") 3048 for c in range(C): 3049 outc = np.tril(im0[..., c]) 3050 mask = outc == 0 3051 outc[mask] = np.triu(im1[..., c])[mask] 3052 out[..., c] = outc 3053 3054 return out 3055 3056 ipt_descr, ipt = inputs[0] 3057 out_descr, out = outputs[0] 3058 3059 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3060 out_img = to_2d_image(out, out_descr.axes) 3061 3062 cover_folder = Path(mkdtemp()) 3063 if ipt_img.shape == out_img.shape: 3064 covers = [cover_folder / "cover.png"] 3065 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3066 else: 3067 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3068 imwrite(covers[0], ipt_img) 3069 imwrite(covers[1], out_img) 3070 3071 return covers
193class TensorId(LowerCaseIdentifier): 194 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 195 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 196 ]
199class AxisId(LowerCaseIdentifier): 200 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 201 Annotated[LowerCaseIdentifierAnno, MaxLen(16)] 202 ]
243class ParameterizedSize(Node): 244 """Describes a range of valid tensor axis sizes as `size = min + n*step`.""" 245 246 N: ClassVar[Type[int]] = ParameterizedSize_N 247 """integer to parameterize this axis""" 248 249 min: Annotated[int, Gt(0)] 250 step: Annotated[int, Gt(0)] 251 252 def validate_size(self, size: int) -> int: 253 if size < self.min: 254 raise ValueError(f"size {size} < {self.min}") 255 if (size - self.min) % self.step != 0: 256 raise ValueError( 257 f"axis of size {size} is not parameterized by `min + n*step` =" 258 + f" `{self.min} + n*{self.step}`" 259 ) 260 261 return size 262 263 def get_size(self, n: ParameterizedSize_N) -> int: 264 return self.min + self.step * n 265 266 def get_n(self, s: int) -> ParameterizedSize_N: 267 """return smallest n parameterizing a size greater or equal than `s`""" 268 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as `size = min + n*step`.
252 def validate_size(self, size: int) -> int: 253 if size < self.min: 254 raise ValueError(f"size {size} < {self.min}") 255 if (size - self.min) % self.step != 0: 256 raise ValueError( 257 f"axis of size {size} is not parameterized by `min + n*step` =" 258 + f" `{self.min} + n*{self.step}`" 259 ) 260 261 return size
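A minimal usage sketch (assuming `ParameterizedSize` is imported from `bioimageio.spec.model.v0_5`); the methods above simply evaluate `size = min + n*step` and its inverse:

>>> p = ParameterizedSize(min=32, step=16)
>>> p.get_size(3)        # 32 + 3*16
80
>>> p.get_n(70)          # smallest n such that min + n*step >= 70
3
>>> p.validate_size(80)  # 80 lies on the grid 32 + n*16
80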
271class DataDependentSize(Node): 272 min: Annotated[int, Gt(0)] = 1 273 max: Annotated[Optional[int], Gt(1)] = None 274 275 @model_validator(mode="after") 276 def _validate_max_gt_min(self): 277 if self.max is not None and self.min >= self.max: 278 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 279 280 return self 281 282 def validate_size(self, size: int) -> int: 283 if size < self.min: 284 raise ValueError(f"size {size} < {self.min}") 285 286 if self.max is not None and size > self.max: 287 raise ValueError(f"size {size} > {self.max}") 288 289 return size
Subpart of a resource description
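A short sketch (same import assumption as above): the concrete size is only known after inference, but the declared bounds can still be checked with `validate_size`:

>>> d = DataDependentSize(min=1, max=512)
>>> d.validate_size(100)
100
>>> d.validate_size(600)
Traceback (most recent call last):
...
ValueError: size 600 > 512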
292class SizeReference(Node): 293 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 294 295 `axis.size = reference.size * reference.scale / axis.scale + offset` 296 297 note: 298 1. The axis and the referenced axis need to have the same unit (or no unit). 299 2. Batch axes may not be referenced. 300 3. Fractions are rounded down. 301 4. If the reference axis is `concatenable` the referencing axis is assumed to be 302 `concatenable` as well with the same block order. 303 304 example: 305 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 306 Let's assume that we want to express the image height h in relation to its width w 307 instead of only accepting input images of exactly 100*49 pixels 308 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 309 310 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 311 >>> h = SpaceInputAxis( 312 ... id=AxisId("h"), 313 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 314 ... unit="millimeter", 315 ... scale=4, 316 ... ) 317 >>> print(h.size.compute(h, w)) 318 49 319 320 -> h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 321 """ 322 323 tensor_id: TensorId 324 """tensor id of the reference axis""" 325 326 axis_id: AxisId 327 """axis id of the reference axis""" 328 329 offset: int = 0 330 331 def get_size( 332 self, 333 axis: Union[ 334 ChannelAxis, 335 IndexInputAxis, 336 IndexOutputAxis, 337 TimeInputAxis, 338 SpaceInputAxis, 339 TimeOutputAxis, 340 TimeOutputAxisWithHalo, 341 SpaceOutputAxis, 342 SpaceOutputAxisWithHalo, 343 ], 344 ref_axis: Union[ 345 ChannelAxis, 346 IndexInputAxis, 347 IndexOutputAxis, 348 TimeInputAxis, 349 SpaceInputAxis, 350 TimeOutputAxis, 351 TimeOutputAxisWithHalo, 352 SpaceOutputAxis, 353 SpaceOutputAxisWithHalo, 354 ], 355 n: ParameterizedSize_N = 0, 356 ref_size: Optional[int] = None, 357 ): 358 """Compute the concrete size for a given axis and its reference axis. 359 360 Args: 361 axis: The axis this `SizeReference` is the size of. 362 ref_axis: The reference axis to compute the size from. 363 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 364 and no fixed **ref_size** is given, 365 **n** is used to compute the size of the parameterized **ref_axis**. 366 ref_size: Overwrite the reference size instead of deriving it from 367 **ref_axis** 368 (**ref_axis.scale** is still used; any given **n** is ignored). 369 """ 370 assert ( 371 axis.size == self 372 ), "Given `axis.size` is not defined by this `SizeReference`" 373 374 assert ( 375 ref_axis.id == self.axis_id 376 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 377 378 assert axis.unit == ref_axis.unit, ( 379 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 380 f" but {axis.unit}!={ref_axis.unit}" 381 ) 382 if ref_size is None: 383 if isinstance(ref_axis.size, (int, float)): 384 ref_size = ref_axis.size 385 elif isinstance(ref_axis.size, ParameterizedSize): 386 ref_size = ref_axis.size.get_size(n) 387 elif isinstance(ref_axis.size, DataDependentSize): 388 raise ValueError( 389 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 390 ) 391 elif isinstance(ref_axis.size, SizeReference): 392 raise ValueError( 393 "Reference axis referenced in `SizeReference` may not be sized by a" 394 + " `SizeReference` itself." 
395 ) 396 else: 397 assert_never(ref_axis.size) 398 399 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 400 401 @staticmethod 402 def _get_unit( 403 axis: Union[ 404 ChannelAxis, 405 IndexInputAxis, 406 IndexOutputAxis, 407 TimeInputAxis, 408 SpaceInputAxis, 409 TimeOutputAxis, 410 TimeOutputAxisWithHalo, 411 SpaceOutputAxis, 412 SpaceOutputAxisWithHalo, 413 ], 414 ): 415 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
`axis.size = reference.size * reference.scale / axis.scale + offset`

note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is `concatenable` the referencing axis is assumed to be `concatenable` as well with the same block order.

example:
An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.compute(h, w))
49
-> h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
331 def get_size( 332 self, 333 axis: Union[ 334 ChannelAxis, 335 IndexInputAxis, 336 IndexOutputAxis, 337 TimeInputAxis, 338 SpaceInputAxis, 339 TimeOutputAxis, 340 TimeOutputAxisWithHalo, 341 SpaceOutputAxis, 342 SpaceOutputAxisWithHalo, 343 ], 344 ref_axis: Union[ 345 ChannelAxis, 346 IndexInputAxis, 347 IndexOutputAxis, 348 TimeInputAxis, 349 SpaceInputAxis, 350 TimeOutputAxis, 351 TimeOutputAxisWithHalo, 352 SpaceOutputAxis, 353 SpaceOutputAxisWithHalo, 354 ], 355 n: ParameterizedSize_N = 0, 356 ref_size: Optional[int] = None, 357 ): 358 """Compute the concrete size for a given axis and its reference axis. 359 360 Args: 361 axis: The axis this `SizeReference` is the size of. 362 ref_axis: The reference axis to compute the size from. 363 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 364 and no fixed **ref_size** is given, 365 **n** is used to compute the size of the parameterized **ref_axis**. 366 ref_size: Overwrite the reference size instead of deriving it from 367 **ref_axis** 368 (**ref_axis.scale** is still used; any given **n** is ignored). 369 """ 370 assert ( 371 axis.size == self 372 ), "Given `axis.size` is not defined by this `SizeReference`" 373 374 assert ( 375 ref_axis.id == self.axis_id 376 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 377 378 assert axis.unit == ref_axis.unit, ( 379 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 380 f" but {axis.unit}!={ref_axis.unit}" 381 ) 382 if ref_size is None: 383 if isinstance(ref_axis.size, (int, float)): 384 ref_size = ref_axis.size 385 elif isinstance(ref_axis.size, ParameterizedSize): 386 ref_size = ref_axis.size.get_size(n) 387 elif isinstance(ref_axis.size, DataDependentSize): 388 raise ValueError( 389 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 390 ) 391 elif isinstance(ref_axis.size, SizeReference): 392 raise ValueError( 393 "Reference axis referenced in `SizeReference` may not be sized by a" 394 + " `SizeReference` itself." 395 ) 396 else: 397 assert_never(ref_axis.size) 398 399 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this `SizeReference` is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) and no fixed **ref_size** is given, **n** is used to compute the size of the parameterized **ref_axis**.
- ref_size: Overwrite the reference size instead of deriving it from **ref_axis** (**ref_axis.scale** is still used; any given **n** is ignored).
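A hedged sketch of the parameterized case, assuming (as the module source suggests) that `SpaceInputAxis.size` also accepts a `ParameterizedSize` or `SizeReference`; **n** is only used to resolve the reference axis:

>>> w = SpaceInputAxis(id=AxisId("w"), size=ParameterizedSize(min=64, step=16))
>>> h = SpaceInputAxis(
...     id=AxisId("h"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w")),
... )
>>> h.size.get_size(h, w, n=2)  # ref_size = 64 + 2*16 = 96; both scales are 1.0
96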
420class AxisBase(NodeWithExplicitlySetFields): 421 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"}) 422 423 id: AxisId 424 """An axis id unique across all axes of one tensor.""" 425 426 description: Annotated[str, MaxLen(128)] = ""
Subpart of a resource description
429class WithHalo(Node): 430 halo: Annotated[int, Ge(1)] 431 """The halo should be cropped from the output tensor to avoid boundary effects. 432 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 433 To document a halo that is already cropped by the model use `size.offset` instead.""" 434 435 size: Annotated[ 436 SizeReference, 437 Field( 438 examples=[ 439 10, 440 SizeReference( 441 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 442 ).model_dump(mode="json"), 443 ] 444 ), 445 ] 446 """reference to another axis with an optional offset (see `SizeReference`)"""
Subpart of a resource description
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
To document a halo that is already cropped by the model use `size.offset` instead.

reference to another axis with an optional offset (see `SizeReference`)
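A small worked example of the halo arithmetic (plain numbers only; the second line mirrors the even-integer check that `ModelDescr` applies to halos elsewhere in this module):

>>> output_size, halo = 256, 8
>>> output_size - 2 * halo  # size_after_crop = size - 2 * halo
240
>>> halo * 1.0 / 0.5        # input halo for output_scale=1.0, input_scale=0.5; must be an even integer
16.0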
452class BatchAxis(AxisBase): 453 type: Literal["batch"] = "batch" 454 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 455 size: Optional[Literal[1]] = None 456 """The batch size may be fixed to 1, 457 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 458 459 @property 460 def scale(self): 461 return 1.0 462 463 @property 464 def concatenable(self): 465 return True 466 467 @property 468 def unit(self): 469 return None
Subpart of a resource description
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
472class ChannelAxis(AxisBase): 473 type: Literal["channel"] = "channel" 474 id: NonBatchAxisId = AxisId("channel") 475 channel_names: NotEmpty[List[Identifier]] 476 477 @property 478 def size(self) -> int: 479 return len(self.channel_names) 480 481 @property 482 def concatenable(self): 483 return False 484 485 @property 486 def scale(self) -> float: 487 return 1.0 488 489 @property 490 def unit(self): 491 return None
Subpart of a resource description
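For instance (a sketch, assuming `ChannelAxis` and `Identifier` are importable from `bioimageio.spec.model.v0_5`), the axis size is derived directly from the channel names:

>>> c = ChannelAxis(channel_names=[Identifier("red"), Identifier("green"), Identifier("blue")])
>>> c.size
3
>>> c.concatenable
False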
494class IndexAxisBase(AxisBase): 495 type: Literal["index"] = "index" 496 id: NonBatchAxisId = AxisId("index") 497 498 @property 499 def scale(self) -> float: 500 return 1.0 501 502 @property 503 def unit(self): 504 return None
Subpart of a resource description
527class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 528 concatenable: bool = False 529 """If a model has a `concatenable` input axis, it can be processed blockwise, 530 splitting a longer sample axis into blocks matching its input tensor description. 531 Output axes are concatenable if they have a `SizeReference` to a concatenable 532 input axis. 533 """
Subpart of a resource description
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable
input axis.
536class IndexOutputAxis(IndexAxisBase): 537 size: Annotated[ 538 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 539 Field( 540 examples=[ 541 10, 542 SizeReference( 543 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 544 ).model_dump(mode="json"), 545 ] 546 ), 547 ] 548 """The size/length of this axis can be specified as 549 - fixed integer 550 - reference to another axis with an optional offset (`SizeReference`) 551 - data dependent size using `DataDependentSize` (size is only known after model inference) 552 """
Subpart of a resource description
The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
- data dependent size using `DataDependentSize` (size is only known after model inference)
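A hedged sketch of the three variants (fixed, referenced, data dependent):

>>> IndexOutputAxis(size=10).size
10
>>> isinstance(IndexOutputAxis(size=DataDependentSize(max=1000)).size, DataDependentSize)
True
>>> ref = SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("index"))
>>> isinstance(IndexOutputAxis(size=ref).size, SizeReference)
True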
555class TimeAxisBase(AxisBase): 556 type: Literal["time"] = "time" 557 id: NonBatchAxisId = AxisId("time") 558 unit: Optional[TimeUnit] = None 559 scale: Annotated[float, Gt(0)] = 1.0
Subpart of a resource description
562class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 563 concatenable: bool = False 564 """If a model has a `concatenable` input axis, it can be processed blockwise, 565 splitting a longer sample axis into blocks matching its input tensor description. 566 Output axes are concatenable if they have a `SizeReference` to a concatenable 567 input axis. 568 """
Subpart of a resource description
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable
input axis.
571class SpaceAxisBase(AxisBase): 572 type: Literal["space"] = "space" 573 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 574 unit: Optional[SpaceUnit] = None 575 scale: Annotated[float, Gt(0)] = 1.0
Subpart of a resource description
An axis id unique across all axes of one tensor.
578class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 579 concatenable: bool = False 580 """If a model has a `concatenable` input axis, it can be processed blockwise, 581 splitting a longer sample axis into blocks matching its input tensor description. 582 Output axes are concatenable if they have a `SizeReference` to a concatenable 583 input axis. 584 """
Subpart of a resource description
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable
input axis.
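For illustration (a sketch under the same assumptions as the `SizeReference` example above), a space axis declared with a parameterized size accepts a whole family of extents:

>>> x = SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=32))
>>> [x.size.get_size(n) for n in range(3)]
[64, 96, 128]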
682class NominalOrOrdinalDataDescr(Node): 683 values: TVs 684 """A fixed set of nominal or an ascending sequence of ordinal values. 685 In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'. 686 String `values` are interpreted as labels for tensor values 0, ..., N. 687 Note: as YAML 1.2 does not natively support a "set" datatype, 688 nominal values should be given as a sequence (aka list/array) as well. 689 """ 690 691 type: Annotated[ 692 NominalOrOrdinalDType, 693 Field( 694 examples=[ 695 "float32", 696 "uint8", 697 "uint16", 698 "int64", 699 "bool", 700 ], 701 ), 702 ] = "uint8" 703 704 @model_validator(mode="after") 705 def _validate_values_match_type( 706 self, 707 ) -> Self: 708 incompatible: List[Any] = [] 709 for v in self.values: 710 if self.type == "bool": 711 if not isinstance(v, bool): 712 incompatible.append(v) 713 elif self.type in DTYPE_LIMITS: 714 if ( 715 isinstance(v, (int, float)) 716 and ( 717 v < DTYPE_LIMITS[self.type].min 718 or v > DTYPE_LIMITS[self.type].max 719 ) 720 or (isinstance(v, str) and "uint" not in self.type) 721 or (isinstance(v, float) and "int" in self.type) 722 ): 723 incompatible.append(v) 724 else: 725 incompatible.append(v) 726 727 if len(incompatible) == 5: 728 incompatible.append("...") 729 break 730 731 if incompatible: 732 raise ValueError( 733 f"data type '{self.type}' incompatible with values {incompatible}" 734 ) 735 736 return self 737 738 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 739 740 @property 741 def range(self): 742 if isinstance(self.values[0], str): 743 return 0, len(self.values) - 1 744 else: 745 return min(self.values), max(self.values)
Subpart of a resource description
A fixed set of nominal or an ascending sequence of ordinal values. In this case `data_type` is required to be an unsigned integer type, e.g. 'uint8'. String `values` are interpreted as labels for tensor values 0, ..., N. Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.
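A hedged sketch of a nominal data description using the fields shown above, where label strings map to tensor values 0, 1, 2:

```python
from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr

# string values act as labels for the uint8 tensor values 0, 1, 2
segmentation_labels = NominalOrOrdinalDataDescr(
    values=["background", "cell", "membrane"],
    type="uint8",
)
```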
762class IntervalOrRatioDataDescr(Node): 763 type: Annotated[ # todo: rename to dtype 764 IntervalOrRatioDType, 765 Field( 766 examples=["float32", "float64", "uint8", "uint16"], 767 ), 768 ] = "float32" 769 range: Tuple[Optional[float], Optional[float]] = ( 770 None, 771 None, 772 ) 773 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 774 `None` corresponds to min/max of what can be expressed by `data_type`.""" 775 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 776 scale: float = 1.0 777 """Scale for data on an interval (or ratio) scale.""" 778 offset: Optional[float] = None 779 """Offset for data on a ratio scale."""
Subpart of a resource description
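For comparison, a minimal sketch of an interval/ratio data description restricted to the unit interval (field names as in the source above):

```python
from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr

# float32 data expected to lie in [0, 1]; None would mean "full dtype range"
probability_data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))
```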
785class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 786 """processing base class""" 787 788 # id: Literal[PreprocessingId, PostprocessingId] # make abstract field 789 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
processing base class
792class BinarizeKwargs(ProcessingKwargs): 793 """key word arguments for `BinarizeDescr`""" 794 795 threshold: float 796 """The fixed threshold"""
keyword arguments for `BinarizeDescr`
799class BinarizeAlongAxisKwargs(ProcessingKwargs): 800 """key word arguments for `BinarizeDescr`""" 801 802 threshold: NotEmpty[List[float]] 803 """The fixed threshold values along `axis`""" 804 805 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 806 """The `threshold` axis"""
keyword arguments for `BinarizeDescr`
809class BinarizeDescr(ProcessingDescrBase): 810 """Binarize the tensor with a fixed threshold. 811 812 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 813 will be set to one, values below the threshold to zero. 814 """ 815 816 id: Literal["binarize"] = "binarize" 817 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` will be set to one, values below the threshold to zero.
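A sketch of both kwargs variants (the threshold values are made up):

```python
from bioimageio.spec.model.v0_5 import (
    BinarizeAlongAxisKwargs,
    BinarizeDescr,
    BinarizeKwargs,
)

# one global threshold
binarize = BinarizeDescr(kwargs=BinarizeKwargs(threshold=0.5))

# one threshold per entry along the "channel" axis
binarize_per_channel = BinarizeDescr(
    kwargs=BinarizeAlongAxisKwargs(threshold=[0.5, 0.3], axis="channel")
)
```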
820class ClipDescr(ProcessingDescrBase): 821 """Set tensor values below min to min and above max to max.""" 822 823 id: Literal["clip"] = "clip" 824 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
827class EnsureDtypeKwargs(ProcessingKwargs): 828 """key word arguments for `EnsureDtypeDescr`""" 829 830 dtype: Literal[ 831 "float32", 832 "float64", 833 "uint8", 834 "int8", 835 "uint16", 836 "int16", 837 "uint32", 838 "int32", 839 "uint64", 840 "int64", 841 "bool", 842 ]
keyword arguments for `EnsureDtypeDescr`
845class EnsureDtypeDescr(ProcessingDescrBase): 846 """cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching)""" 847 848 id: Literal["ensure_dtype"] = "ensure_dtype" 849 kwargs: EnsureDtypeKwargs
cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching)
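A minimal sketch:

```python
from bioimageio.spec.model.v0_5 import EnsureDtypeDescr, EnsureDtypeKwargs

# cast to float32 (a no-op if the data already has that dtype)
ensure_float32 = EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
```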
852class ScaleLinearKwargs(ProcessingKwargs): 853 """key word arguments for `ScaleLinearDescr`""" 854 855 gain: float = 1.0 856 """multiplicative factor""" 857 858 offset: float = 0.0 859 """additive term""" 860 861 @model_validator(mode="after") 862 def _validate(self) -> Self: 863 if self.gain == 1.0 and self.offset == 0.0: 864 raise ValueError( 865 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 866 + " != 0.0." 867 ) 868 869 return self
keyword arguments for `ScaleLinearDescr`
872class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 873 """key word arguments for `ScaleLinearDescr`""" 874 875 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 876 """The axis of of gains/offsets values.""" 877 878 gain: Union[float, NotEmpty[List[float]]] = 1.0 879 """multiplicative factor""" 880 881 offset: Union[float, NotEmpty[List[float]]] = 0.0 882 """additive term""" 883 884 @model_validator(mode="after") 885 def _validate(self) -> Self: 886 887 if isinstance(self.gain, list): 888 if isinstance(self.offset, list): 889 if len(self.gain) != len(self.offset): 890 raise ValueError( 891 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 892 ) 893 else: 894 self.offset = [float(self.offset)] * len(self.gain) 895 elif isinstance(self.offset, list): 896 self.gain = [float(self.gain)] * len(self.offset) 897 else: 898 raise ValueError( 899 "Do not specify an `axis` for scalar gain and offset values." 900 ) 901 902 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 903 raise ValueError( 904 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 905 + " != 0.0." 906 ) 907 908 return self
keyword arguments for `ScaleLinearDescr`
The axis of gains/offsets values.
911class ScaleLinearDescr(ProcessingDescrBase): 912 """Fixed linear scaling.""" 913 914 id: Literal["scale_linear"] = "scale_linear" 915 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
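A sketch of both variants (values are made up; note that the validators above reject a redundant gain of 1.0 with offset 0.0 and broadcast a scalar `offset` to the length of a `gain` list):

```python
from bioimageio.spec.model.v0_5 import (
    ScaleLinearAlongAxisKwargs,
    ScaleLinearDescr,
    ScaleLinearKwargs,
)

# a single gain/offset applied to the whole tensor
rescale = ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=-1.0))

# per-channel gains; the scalar offset is broadcast to [0.0, 0.0, 0.0]
rescale_per_channel = ScaleLinearDescr(
    kwargs=ScaleLinearAlongAxisKwargs(axis="channel", gain=[1.0, 0.5, 2.0])
)
```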
918class SigmoidDescr(ProcessingDescrBase): 919 """The logistic sigmoid funciton, a.k.a. expit function.""" 920 921 id: Literal["sigmoid"] = "sigmoid" 922 923 @property 924 def kwargs(self) -> ProcessingKwargs: 925 """empty kwargs""" 926 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
923 @property 924 def kwargs(self) -> ProcessingKwargs: 925 """empty kwargs""" 926 return ProcessingKwargs()
empty kwargs
929class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 930 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 931 932 mean: float 933 """The mean value to normalize with.""" 934 935 std: Annotated[float, Ge(1e-6)] 936 """The standard deviation value to normalize with."""
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
939class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 940 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 941 942 mean: NotEmpty[List[float]] 943 """The mean value(s) to normalize with.""" 944 945 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 946 """The standard deviation value(s) to normalize with. 947 Size must match `mean` values.""" 948 949 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 950 """The axis of the mean/std values to normalize each entry along that dimension 951 separately.""" 952 953 @model_validator(mode="after") 954 def _mean_and_std_match(self) -> Self: 955 if len(self.mean) != len(self.std): 956 raise ValueError( 957 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 958 + " must match." 959 ) 960 961 return self
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
The standard deviation value(s) to normalize with. Size must match `mean` values.
The axis of the mean/std values to normalize each entry along that dimension separately.
964class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 965 """Subtract a given mean and divide by the standard deviation. 966 967 Normalize with fixed, precomputed values for 968 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 969 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 970 axes. 971 """ 972 973 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 974 kwargs: Union[ 975 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 976 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`. Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given axes.
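A sketch with made-up per-channel statistics for a three-channel input:

```python
from bioimageio.spec.model.v0_5 import (
    FixedZeroMeanUnitVarianceAlongAxisKwargs,
    FixedZeroMeanUnitVarianceDescr,
)

# precomputed mean/std per channel; the list lengths must match
fixed_zmuv = FixedZeroMeanUnitVarianceDescr(
    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        axis="channel",
    )
)
```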
979class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 980 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 981 982 axes: Annotated[ 983 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 984 ] = None 985 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 986 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 987 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 988 To normalize each sample independently leave out the 'batch' axis. 989 Default: Scale all axes jointly.""" 990 991 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 992 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
keyword arguments for `ZeroMeanUnitVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. To normalize each sample independently leave out the 'batch' axis. Default: Scale all axes jointly.
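A sketch following the example in the docstring (normalize per channel by reducing over batch, x and y):

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    ZeroMeanUnitVarianceDescr,
    ZeroMeanUnitVarianceKwargs,
)

# leaving out "channel" means each channel is normalized separately
zmuv = ZeroMeanUnitVarianceDescr(
    kwargs=ZeroMeanUnitVarianceKwargs(
        axes=(AxisId("batch"), AxisId("x"), AxisId("y"))
    )
)
```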
995class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 996 """Subtract mean and divide by variance.""" 997 998 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 999 kwargs: ZeroMeanUnitVarianceKwargs
Subtract mean and divide by variance.
1002class ScaleRangeKwargs(ProcessingKwargs): 1003 """key word arguments for `ScaleRangeDescr` 1004 1005 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1006 this processing step normalizes data to the [0, 1] intervall. 1007 For other percentiles the normalized values will partially be outside the [0, 1] 1008 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1009 normalized values to a range. 1010 """ 1011 1012 axes: Annotated[ 1013 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1014 ] = None 1015 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1016 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1017 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1018 To normalize samples indepdencently, leave out the "batch" axis. 1019 Default: Scale all axes jointly.""" 1020 1021 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1022 """The lower percentile used to determine the value to align with zero.""" 1023 1024 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1025 """The upper percentile used to determine the value to align with one. 1026 Has to be bigger than `min_percentile`. 1027 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1028 accepting percentiles specified in the range 0.0 to 1.0.""" 1029 1030 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1031 """Epsilon for numeric stability. 1032 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1033 with `v_lower,v_upper` values at the respective percentiles.""" 1034 1035 reference_tensor: Optional[TensorId] = None 1036 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1037 For any tensor in `inputs` only input tensor references are allowed.""" 1038 1039 @field_validator("max_percentile", mode="after") 1040 @classmethod 1041 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1042 if (min_p := info.data["min_percentile"]) >= value: 1043 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1044 1045 return value
keyword arguments for `ScaleRangeDescr`
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) this processing step normalizes data to the [0, 1] interval. For other percentiles the normalized values will partially be outside the [0, 1] interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. To normalize samples independently, leave out the "batch" axis. Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one. Has to be bigger than `min_percentile`. The range is 1 to 100 instead of 0 to 100 to avoid mistakenly accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability. `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; with `v_lower,v_upper` values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself. For any tensor in `inputs` only input tensor references are allowed.
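Following the advice above, a sketch of a robust [0, 1] normalization: percentile scaling followed by clipping. The percentile values are made up and `ClipKwargs` is assumed to take `min`/`max` fields (its definition is not shown here):

```python
from bioimageio.spec.model.v0_5 import (
    ClipDescr,
    ClipKwargs,
    ScaleRangeDescr,
    ScaleRangeKwargs,
)

robust_norm = [
    ScaleRangeDescr(kwargs=ScaleRangeKwargs(min_percentile=1.0, max_percentile=99.8)),
    ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),  # limit values to [0, 1]
]
```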
1039 @field_validator("max_percentile", mode="after") 1040 @classmethod 1041 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1042 if (min_p := info.data["min_percentile"]) >= value: 1043 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1044 1045 return value
1048class ScaleRangeDescr(ProcessingDescrBase): 1049 """Scale with percentiles.""" 1050 1051 id: Literal["scale_range"] = "scale_range" 1052 kwargs: ScaleRangeKwargs
Scale with percentiles.
1055class ScaleMeanVarianceKwargs(ProcessingKwargs): 1056 """key word arguments for `ScaleMeanVarianceKwargs`""" 1057 1058 reference_tensor: TensorId 1059 """Name of tensor to match.""" 1060 1061 axes: Annotated[ 1062 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1063 ] = None 1064 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1065 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1066 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1067 To normalize samples independently, leave out the 'batch' axis. 1068 Default: Scale all axes jointly.""" 1069 1070 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1071 """Epsilon for numeric stability: 1072 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
keyword arguments for `ScaleMeanVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. To normalize samples independently, leave out the 'batch' axis. Default: Scale all axes jointly.
1075class ScaleMeanVarianceDescr(ProcessingDescrBase): 1076 """Scale a tensor's data distribution to match another tensor's mean/std. 1077 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1078 """ 1079 1080 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1081 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
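A minimal sketch (the tensor id "raw" is a hypothetical input tensor):

```python
from bioimageio.spec.model.v0_5 import (
    ScaleMeanVarianceDescr,
    ScaleMeanVarianceKwargs,
    TensorId,
)

# match the output's mean/std to those of the "raw" input tensor
match_raw = ScaleMeanVarianceDescr(
    kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("raw"))
)
```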
1115class TensorDescrBase(Node, Generic[IO_AxisT]): 1116 id: TensorId 1117 """Tensor id. No duplicates are allowed.""" 1118 1119 description: Annotated[str, MaxLen(128)] = "" 1120 """free text description""" 1121 1122 axes: NotEmpty[Sequence[IO_AxisT]] 1123 """tensor axes""" 1124 1125 @property 1126 def shape(self): 1127 return tuple(a.size for a in self.axes) 1128 1129 @field_validator("axes", mode="after", check_fields=False) 1130 @classmethod 1131 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1132 batch_axes = [a for a in axes if a.type == "batch"] 1133 if len(batch_axes) > 1: 1134 raise ValueError( 1135 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1136 ) 1137 1138 seen_ids: Set[AxisId] = set() 1139 duplicate_axes_ids: Set[AxisId] = set() 1140 for a in axes: 1141 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1142 1143 if duplicate_axes_ids: 1144 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1145 1146 return axes 1147 1148 test_tensor: FileDescr 1149 """An example tensor to use for testing. 1150 Using the model with the test input tensors is expected to yield the test output tensors. 1151 Each test tensor has be a an ndarray in the 1152 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1153 The file extension must be '.npy'.""" 1154 1155 sample_tensor: Optional[FileDescr] = None 1156 """A sample tensor to illustrate a possible input/output for the model, 1157 The sample image primarily serves to inform a human user about an example use case 1158 and is typically stored as .hdf5, .png or .tiff. 1159 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1160 (numpy's `.npy` format is not supported). 1161 The image dimensionality has to match the number of axes specified in this tensor description. 1162 """ 1163 1164 @model_validator(mode="after") 1165 def _validate_sample_tensor(self) -> Self: 1166 if ( 1167 self.sample_tensor is None 1168 or not validation_context_var.get().perform_io_checks 1169 ): 1170 return self 1171 1172 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1173 tensor: NDArray[Any] = imread( 1174 local.path.read_bytes(), 1175 extension=PurePosixPath(local.original_file_name).suffix, 1176 ) 1177 n_dims = len(tensor.squeeze().shape) 1178 n_dims_min = n_dims_max = len(self.axes) 1179 1180 for a in self.axes: 1181 if isinstance(a, BatchAxis): 1182 n_dims_min -= 1 1183 elif isinstance(a.size, int): 1184 if a.size == 1: 1185 n_dims_min -= 1 1186 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1187 if a.size.min == 1: 1188 n_dims_min -= 1 1189 elif isinstance(a.size, SizeReference): 1190 if a.size.offset < 2: 1191 # size reference may result in singleton axis 1192 n_dims_min -= 1 1193 else: 1194 assert_never(a.size) 1195 1196 n_dims_min = max(0, n_dims_min) 1197 if n_dims < n_dims_min or n_dims > n_dims_max: 1198 raise ValueError( 1199 f"Expected sample tensor to have {n_dims_min} to" 1200 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1201 ) 1202 1203 return self 1204 1205 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1206 IntervalOrRatioDataDescr() 1207 ) 1208 """Description of the tensor's data values, optionally per channel. 
1209 If specified per channel, the data `type` needs to match across channels.""" 1210 1211 @property 1212 def dtype( 1213 self, 1214 ) -> Literal[ 1215 "float32", 1216 "float64", 1217 "uint8", 1218 "int8", 1219 "uint16", 1220 "int16", 1221 "uint32", 1222 "int32", 1223 "uint64", 1224 "int64", 1225 "bool", 1226 ]: 1227 """dtype as specified under `data.type` or `data[i].type`""" 1228 if isinstance(self.data, collections.abc.Sequence): 1229 return self.data[0].type 1230 else: 1231 return self.data.type 1232 1233 @field_validator("data", mode="after") 1234 @classmethod 1235 def _check_data_type_across_channels( 1236 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1237 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1238 if not isinstance(value, list): 1239 return value 1240 1241 dtypes = {t.type for t in value} 1242 if len(dtypes) > 1: 1243 raise ValueError( 1244 "Tensor data descriptions per channel need to agree in their data" 1245 + f" `type`, but found {dtypes}." 1246 ) 1247 1248 return value 1249 1250 @model_validator(mode="after") 1251 def _check_data_matches_channelaxis(self) -> Self: 1252 if not isinstance(self.data, (list, tuple)): 1253 return self 1254 1255 for a in self.axes: 1256 if isinstance(a, ChannelAxis): 1257 size = a.size 1258 assert isinstance(size, int) 1259 break 1260 else: 1261 return self 1262 1263 if len(self.data) != size: 1264 raise ValueError( 1265 f"Got tensor data descriptions for {len(self.data)} channels, but" 1266 + f" '{a.id}' axis has size {size}." 1267 ) 1268 1269 return self 1270 1271 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1272 if len(array.shape) != len(self.axes): 1273 raise ValueError( 1274 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1275 + f" incompatible with {len(self.axes)} axes." 1276 ) 1277 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
Subpart of a resource description
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model. The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff. It has to be readable by the imageio library (numpy's `.npy` format is not supported). The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel. If specified per channel, the data `type` needs to match across channels.
1211 @property 1212 def dtype( 1213 self, 1214 ) -> Literal[ 1215 "float32", 1216 "float64", 1217 "uint8", 1218 "int8", 1219 "uint16", 1220 "int16", 1221 "uint32", 1222 "int32", 1223 "uint64", 1224 "int64", 1225 "bool", 1226 ]: 1227 """dtype as specified under `data.type` or `data[i].type`""" 1228 if isinstance(self.data, collections.abc.Sequence): 1229 return self.data[0].type 1230 else: 1231 return self.data.type
dtype as specified under `data.type` or `data[i].type`
1271 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1272 if len(array.shape) != len(self.axes): 1273 raise ValueError( 1274 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1275 + f" incompatible with {len(self.axes)} axes." 1276 ) 1277 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1280class InputTensorDescr(TensorDescrBase[InputAxis]): 1281 id: TensorId = TensorId("input") 1282 """Input tensor id. 1283 No duplicates are allowed across all inputs and outputs.""" 1284 1285 optional: bool = False 1286 """indicates that this tensor may be `None`""" 1287 1288 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 1289 """Description of how this input should be preprocessed. 1290 1291 notes: 1292 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1293 to ensure an input tensor's data type matches the input tensor's data description. 1294 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1295 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1296 changing the data type. 1297 """ 1298 1299 @model_validator(mode="after") 1300 def _validate_preprocessing_kwargs(self) -> Self: 1301 axes_ids = [a.id for a in self.axes] 1302 for p in self.preprocessing: 1303 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1304 if kwargs_axes is None: 1305 continue 1306 1307 if not isinstance(kwargs_axes, collections.abc.Sequence): 1308 raise ValueError( 1309 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1310 ) 1311 1312 if any(a not in axes_ids for a in kwargs_axes): 1313 raise ValueError( 1314 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1315 ) 1316 1317 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1318 dtype = self.data.type 1319 else: 1320 dtype = self.data[0].type 1321 1322 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1323 if not self.preprocessing or not isinstance( 1324 self.preprocessing[0], EnsureDtypeDescr 1325 ): 1326 self.preprocessing.insert( 1327 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1328 ) 1329 1330 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1331 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1332 self.preprocessing.append( 1333 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1334 ) 1335 1336 return self
Subpart of a resource description
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
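A sketch of a complete input tensor description (file names and axis sizes are hypothetical; the referenced test tensor must exist locally, or validation has to run with io checks disabled):

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    FileDescr,
    Identifier,
    InputTensorDescr,
    ParameterizedSize,
    SpaceInputAxis,
    TensorId,
)

raw_input = InputTensorDescr(
    id=TensorId("raw"),
    description="raw intensity image",
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("dapi")]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
    ],
    test_tensor=FileDescr(source="test_input.npy"),  # hypothetical file
)
# the validator inserts 'ensure_dtype' preprocessing steps automatically
```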
1339def convert_axes( 1340 axes: str, 1341 *, 1342 shape: Union[ 1343 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1344 ], 1345 tensor_type: Literal["input", "output"], 1346 halo: Optional[Sequence[int]], 1347 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1348): 1349 ret: List[AnyAxis] = [] 1350 for i, a in enumerate(axes): 1351 axis_type = _AXIS_TYPE_MAP.get(a, a) 1352 if axis_type == "batch": 1353 ret.append(BatchAxis()) 1354 continue 1355 1356 scale = 1.0 1357 if isinstance(shape, _ParameterizedInputShape_v0_4): 1358 if shape.step[i] == 0: 1359 size = shape.min[i] 1360 else: 1361 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1362 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1363 ref_t = str(shape.reference_tensor) 1364 if ref_t.count(".") == 1: 1365 t_id, orig_a_id = ref_t.split(".") 1366 else: 1367 t_id = ref_t 1368 orig_a_id = a 1369 1370 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1371 if not (orig_scale := shape.scale[i]): 1372 # old way to insert a new axis dimension 1373 size = int(2 * shape.offset[i]) 1374 else: 1375 scale = 1 / orig_scale 1376 if axis_type in ("channel", "index"): 1377 # these axes no longer have a scale 1378 offset_from_scale = orig_scale * size_refs.get( 1379 _TensorName_v0_4(t_id), {} 1380 ).get(orig_a_id, 0) 1381 else: 1382 offset_from_scale = 0 1383 size = SizeReference( 1384 tensor_id=TensorId(t_id), 1385 axis_id=AxisId(a_id), 1386 offset=int(offset_from_scale + 2 * shape.offset[i]), 1387 ) 1388 else: 1389 size = shape[i] 1390 1391 if axis_type == "time": 1392 if tensor_type == "input": 1393 ret.append(TimeInputAxis(size=size, scale=scale)) 1394 else: 1395 assert not isinstance(size, ParameterizedSize) 1396 if halo is None: 1397 ret.append(TimeOutputAxis(size=size, scale=scale)) 1398 else: 1399 assert not isinstance(size, int) 1400 ret.append( 1401 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1402 ) 1403 1404 elif axis_type == "index": 1405 if tensor_type == "input": 1406 ret.append(IndexInputAxis(size=size)) 1407 else: 1408 if isinstance(size, ParameterizedSize): 1409 size = DataDependentSize(min=size.min) 1410 1411 ret.append(IndexOutputAxis(size=size)) 1412 elif axis_type == "channel": 1413 assert not isinstance(size, ParameterizedSize) 1414 if isinstance(size, SizeReference): 1415 warnings.warn( 1416 "Conversion of channel size from an implicit output shape may be" 1417 + " wrong" 1418 ) 1419 ret.append( 1420 ChannelAxis( 1421 channel_names=[ 1422 Identifier(f"channel{i}") for i in range(size.offset) 1423 ] 1424 ) 1425 ) 1426 else: 1427 ret.append( 1428 ChannelAxis( 1429 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1430 ) 1431 ) 1432 elif axis_type == "space": 1433 if tensor_type == "input": 1434 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1435 else: 1436 assert not isinstance(size, ParameterizedSize) 1437 if halo is None or halo[i] == 0: 1438 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1439 elif isinstance(size, int): 1440 raise NotImplementedError( 1441 f"output axis with halo and fixed size (here {size}) not allowed" 1442 ) 1443 else: 1444 ret.append( 1445 SpaceOutputAxisWithHalo( 1446 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1447 ) 1448 ) 1449 1450 return ret
1625class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1626 id: TensorId = TensorId("output") 1627 """Output tensor id. 1628 No duplicates are allowed across all inputs and outputs.""" 1629 1630 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 1631 """Description of how this output should be postprocessed. 1632 1633 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1634 If not given this is added to cast to this tensor's `data.type`. 1635 """ 1636 1637 @model_validator(mode="after") 1638 def _validate_postprocessing_kwargs(self) -> Self: 1639 axes_ids = [a.id for a in self.axes] 1640 for p in self.postprocessing: 1641 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1642 if kwargs_axes is None: 1643 continue 1644 1645 if not isinstance(kwargs_axes, collections.abc.Sequence): 1646 raise ValueError( 1647 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1648 ) 1649 1650 if any(a not in axes_ids for a in kwargs_axes): 1651 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1652 1653 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1654 dtype = self.data.type 1655 else: 1656 dtype = self.data[0].type 1657 1658 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1659 if not self.postprocessing or not isinstance( 1660 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1661 ): 1662 self.postprocessing.append( 1663 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1664 ) 1665 return self
Subpart of a resource description
Description of how this output should be postprocessed.
note: `postprocessing` always ends with an 'ensure_dtype' operation. If not given this is added to cast to this tensor's `data.type`.
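A sketch of an output tensor whose spatial size mirrors the hypothetical "raw" input via `SizeReference`:

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    FileDescr,
    Identifier,
    OutputTensorDescr,
    SigmoidDescr,
    SizeReference,
    SpaceOutputAxis,
    TensorId,
)

prob_output = OutputTensorDescr(
    id=TensorId("probability"),
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("foreground")]),
        SpaceOutputAxis(
            id=AxisId("y"),
            size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
        ),
        SpaceOutputAxis(
            id=AxisId("x"),
            size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
        ),
    ],
    test_tensor=FileDescr(source="test_output.npy"),  # hypothetical file
    postprocessing=[SigmoidDescr()],  # an 'ensure_dtype' step is appended automatically
)
```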
1715def validate_tensors( 1716 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 1717 tensor_origin: str, # for more precise error messages, e.g. 'test_tensor' 1718): 1719 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 1720 1721 def e_msg(d: TensorDescr): 1722 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 1723 1724 for descr, array in tensors.values(): 1725 try: 1726 axis_sizes = descr.get_axis_sizes_for_array(array) 1727 except ValueError as e: 1728 raise ValueError(f"{e_msg(descr)} {e}") 1729 else: 1730 all_tensor_axes[descr.id] = { 1731 a.id: (a, axis_sizes[a.id]) for a in descr.axes 1732 } 1733 1734 for descr, array in tensors.values(): 1735 if array.dtype.name != descr.dtype: 1736 raise ValueError( 1737 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 1738 + f" match described dtype '{descr.dtype}'" 1739 ) 1740 1741 for a in descr.axes: 1742 actual_size = all_tensor_axes[descr.id][a.id][1] 1743 if a.size is None: 1744 continue 1745 1746 if isinstance(a.size, int): 1747 if actual_size != a.size: 1748 raise ValueError( 1749 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 1750 + f"has incompatible size {actual_size}, expected {a.size}" 1751 ) 1752 elif isinstance(a.size, ParameterizedSize): 1753 _ = a.size.validate_size(actual_size) 1754 elif isinstance(a.size, DataDependentSize): 1755 _ = a.size.validate_size(actual_size) 1756 elif isinstance(a.size, SizeReference): 1757 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 1758 if ref_tensor_axes is None: 1759 raise ValueError( 1760 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 1761 + f" reference '{a.size.tensor_id}'" 1762 ) 1763 1764 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 1765 if ref_axis is None or ref_size is None: 1766 raise ValueError( 1767 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 1768 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 1769 ) 1770 1771 if a.unit != ref_axis.unit: 1772 raise ValueError( 1773 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 1774 + " axis and reference axis to have the same `unit`, but" 1775 + f" {a.unit}!={ref_axis.unit}" 1776 ) 1777 1778 if actual_size != ( 1779 expected_size := ( 1780 ref_size * ref_axis.scale / a.scale + a.size.offset 1781 ) 1782 ): 1783 raise ValueError( 1784 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 1785 + f" {actual_size} invalid for referenced size {ref_size};" 1786 + f" expected {expected_size}" 1787 ) 1788 else: 1789 assert_never(a.size)
1792class EnvironmentFileDescr(FileDescr): 1793 source: Annotated[ 1794 ImportantFileSource, 1795 WithSuffix((".yaml", ".yml"), case_sensitive=True), 1796 Field( 1797 examples=["environment.yaml"], 1798 ), 1799 ] 1800 """∈📦 Conda environment file. 1801 Allows to specify custom dependencies, see conda docs: 1802 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 1803 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 1804 """
Subpart of a resource description
∈📦 Conda environment file. Allows to specify custom dependencies, see the conda docs on exporting an environment file across platforms and on creating an environment file manually.
1819class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 1820 import_from: str 1821 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
Subpart of a resource description
1887class WeightsEntryDescrBase(FileDescr): 1888 type: ClassVar[WeightsFormat] 1889 weights_format_name: ClassVar[str] # human readable 1890 1891 source: ImportantFileSource 1892 """∈📦 The weights file.""" 1893 1894 authors: Optional[List[Author]] = None 1895 """Authors 1896 Either the person(s) that have trained this model resulting in the original weights file. 1897 (If this is the initial weights entry, i.e. it does not have a `parent`) 1898 Or the person(s) who have converted the weights to this weights format. 1899 (If this is a child weight, i.e. it has a `parent` field) 1900 """ 1901 1902 parent: Annotated[ 1903 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 1904 ] = None 1905 """The source weights these weights were converted from. 1906 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 1907 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 1908 All weight entries except one (the initial set of weights resulting from training the model), 1909 need to have this field.""" 1910 1911 @model_validator(mode="after") 1912 def check_parent_is_not_self(self) -> Self: 1913 if self.type == self.parent: 1914 raise ValueError("Weights entry can't be it's own parent.") 1915 1916 return self
Subpart of a resource description
∈📦 The weights file.
The source weights these weights were converted from. For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
1919class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 1920 type = "keras_hdf5" 1921 weights_format_name: ClassVar[str] = "Keras HDF5" 1922 tensorflow_version: Version 1923 """TensorFlow version used to create these weights."""
Subpart of a resource description
TensorFlow version used to create these weights.
1926class OnnxWeightsDescr(WeightsEntryDescrBase): 1927 type = "onnx" 1928 weights_format_name: ClassVar[str] = "ONNX" 1929 opset_version: Annotated[int, Ge(7)] 1930 """ONNX opset version"""
Subpart of a resource description
1933class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 1934 type = "pytorch_state_dict" 1935 weights_format_name: ClassVar[str] = "Pytorch State Dict" 1936 architecture: ArchitectureDescr 1937 pytorch_version: Version 1938 """Version of the PyTorch library used. 1939 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 1940 """ 1941 dependencies: Optional[EnvironmentFileDescr] = None 1942 """Custom depencies beyond pytorch. 1943 The conda environment file should include pytorch and any version pinning has to be compatible with 1944 `pytorch_version`. 1945 """
Subpart of a resource description
Version of the PyTorch library used.
If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
Custom dependencies beyond pytorch. The conda environment file should include pytorch and any version pinning has to be compatible with `pytorch_version`.
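A sketch of a PyTorch state-dict weights entry. The file name, callable and kwargs are hypothetical, and the `callable`/`kwargs` fields are assumed from `_ArchitectureCallableDescr`, whose definition is not shown above:

```python
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromLibraryDescr,
    Identifier,
    PytorchStateDictWeightsDescr,
    Version,
)

pytorch_weights = PytorchStateDictWeightsDescr(
    source="weights.pt",  # hypothetical file
    architecture=ArchitectureFromLibraryDescr(
        callable=Identifier("UNet2d"),       # hypothetical class name
        import_from="my_package.models",     # hypothetical module
        kwargs=dict(in_channels=1, out_channels=1),
    ),
    pytorch_version=Version("2.1.0"),
)
```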
1948class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 1949 type = "tensorflow_js" 1950 weights_format_name: ClassVar[str] = "Tensorflow.js" 1951 tensorflow_version: Version 1952 """Version of the TensorFlow library used.""" 1953 1954 source: ImportantFileSource 1955 """∈📦 The multi-file weights. 1956 All required files/folders should be a zip archive."""
Subpart of a resource description
Version of the TensorFlow library used.
∈📦 The multi-file weights. All required files/folders should be a zip archive.
1959class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 1960 type = "tensorflow_saved_model_bundle" 1961 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 1962 tensorflow_version: Version 1963 """Version of the TensorFlow library used.""" 1964 1965 dependencies: Optional[EnvironmentFileDescr] = None 1966 """Custom dependencies beyond tensorflow. 1967 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 1968 1969 source: ImportantFileSource 1970 """∈📦 The multi-file weights. 1971 All required files/folders should be a zip archive."""
Subpart of a resource description
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.
∈📦 The multi-file weights. All required files/folders should be a zip archive.
1974class TorchscriptWeightsDescr(WeightsEntryDescrBase): 1975 type = "torchscript" 1976 weights_format_name: ClassVar[str] = "TorchScript" 1977 pytorch_version: Version 1978 """Version of the PyTorch library used."""
Subpart of a resource description
Version of the PyTorch library used.
1981class WeightsDescr(Node): 1982 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 1983 onnx: Optional[OnnxWeightsDescr] = None 1984 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 1985 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 1986 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 1987 None 1988 ) 1989 torchscript: Optional[TorchscriptWeightsDescr] = None 1990 1991 @model_validator(mode="after") 1992 def check_entries(self) -> Self: 1993 entries = {wtype for wtype, entry in self if entry is not None} 1994 1995 if not entries: 1996 raise ValueError("Missing weights entry") 1997 1998 entries_wo_parent = { 1999 wtype 2000 for wtype, entry in self 2001 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2002 } 2003 if len(entries_wo_parent) != 1: 2004 issue_warning( 2005 "Exactly one weights entry may not specify the `parent` field (got" 2006 + " {value}). That entry is considered the original set of model weights." 2007 + " Other weight formats are created through conversion of the orignal or" 2008 + " already converted weights. They have to reference the weights format" 2009 + " they were converted from as their `parent`.", 2010 value=len(entries_wo_parent), 2011 field="weights", 2012 ) 2013 2014 for wtype, entry in self: 2015 if entry is None: 2016 continue 2017 2018 assert hasattr(entry, "type") 2019 assert hasattr(entry, "parent") 2020 assert wtype == entry.type 2021 if ( 2022 entry.parent is not None and entry.parent not in entries 2023 ): # self reference checked for `parent` field 2024 raise ValueError( 2025 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2026 + f" formats: {entries}" 2027 ) 2028 2029 return self
Subpart of a resource description
1991 @model_validator(mode="after") 1992 def check_entries(self) -> Self: 1993 entries = {wtype for wtype, entry in self if entry is not None} 1994 1995 if not entries: 1996 raise ValueError("Missing weights entry") 1997 1998 entries_wo_parent = { 1999 wtype 2000 for wtype, entry in self 2001 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2002 } 2003 if len(entries_wo_parent) != 1: 2004 issue_warning( 2005 "Exactly one weights entry may not specify the `parent` field (got" 2006 + " {value}). That entry is considered the original set of model weights." 2007 + " Other weight formats are created through conversion of the orignal or" 2008 + " already converted weights. They have to reference the weights format" 2009 + " they were converted from as their `parent`.", 2010 value=len(entries_wo_parent), 2011 field="weights", 2012 ) 2013 2014 for wtype, entry in self: 2015 if entry is None: 2016 continue 2017 2018 assert hasattr(entry, "type") 2019 assert hasattr(entry, "parent") 2020 assert wtype == entry.type 2021 if ( 2022 entry.parent is not None and entry.parent not in entries 2023 ): # self reference checked for `parent` field 2024 raise ValueError( 2025 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2026 + f" formats: {entries}" 2027 ) 2028 2029 return self
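A sketch of a weights collection with an original and a converted entry; the converted `torchscript` entry points back to `pytorch_state_dict` via `parent` (reusing the hypothetical `pytorch_weights` from the sketch above):

```python
from bioimageio.spec.model.v0_5 import TorchscriptWeightsDescr, Version, WeightsDescr

weights = WeightsDescr(
    pytorch_state_dict=pytorch_weights,  # original weights, no `parent`
    torchscript=TorchscriptWeightsDescr(
        source="weights_torchscript.pt",  # hypothetical file
        pytorch_version=Version("2.1.0"),
        parent="pytorch_state_dict",      # converted from the entry above
    ),
)
```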
2036class LinkedModel(LinkedResourceNode): 2037 """Reference to a bioimage.io model.""" 2038 2039 id: ModelId 2040 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2062class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"): 2063 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2064 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2065 """ 2066 2067 format_version: Literal["0.5.3"] = "0.5.3" 2068 """Version of the bioimage.io model description specification used. 2069 When creating a new model always use the latest micro/patch version described here. 2070 The `format_version` is important for any consumer software to understand how to parse the fields. 2071 """ 2072 2073 type: Literal["model"] = "model" 2074 """Specialized resource type 'model'""" 2075 2076 id: Optional[ModelId] = None 2077 """bioimage.io-wide unique resource identifier 2078 assigned by bioimage.io; version **un**specific.""" 2079 2080 authors: NotEmpty[List[Author]] 2081 """The authors are the creators of the model RDF and the primary points of contact.""" 2082 2083 documentation: Annotated[ 2084 DocumentationSource, 2085 Field( 2086 examples=[ 2087 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 2088 "README.md", 2089 ], 2090 ), 2091 ] 2092 """∈📦 URL or relative path to a markdown file with additional documentation. 2093 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2094 The documentation should include a '#[#] Validation' (sub)section 2095 with details on how to quantitatively validate the model on unseen data.""" 2096 2097 @field_validator("documentation", mode="after") 2098 @classmethod 2099 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 2100 if not validation_context_var.get().perform_io_checks: 2101 return value 2102 2103 doc_path = download(value).path 2104 doc_content = doc_path.read_text(encoding="utf-8") 2105 assert isinstance(doc_content, str) 2106 if not re.match("#.*[vV]alidation", doc_content): 2107 issue_warning( 2108 "No '# Validation' (sub)section found in {value}.", 2109 value=value, 2110 field="documentation", 2111 ) 2112 2113 return value 2114 2115 inputs: NotEmpty[Sequence[InputTensorDescr]] 2116 """Describes the input tensors expected by this model.""" 2117 2118 @field_validator("inputs", mode="after") 2119 @classmethod 2120 def _validate_input_axes( 2121 cls, inputs: Sequence[InputTensorDescr] 2122 ) -> Sequence[InputTensorDescr]: 2123 input_size_refs = cls._get_axes_with_independent_size(inputs) 2124 2125 for i, ipt in enumerate(inputs): 2126 valid_independent_refs: Dict[ 2127 Tuple[TensorId, AxisId], 2128 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2129 ] = { 2130 **{ 2131 (ipt.id, a.id): (ipt, a, a.size) 2132 for a in ipt.axes 2133 if not isinstance(a, BatchAxis) 2134 and isinstance(a.size, (int, ParameterizedSize)) 2135 }, 2136 **input_size_refs, 2137 } 2138 for a, ax in enumerate(ipt.axes): 2139 cls._validate_axis( 2140 "inputs", 2141 i=i, 2142 tensor_id=ipt.id, 2143 a=a, 2144 axis=ax, 2145 valid_independent_refs=valid_independent_refs, 2146 ) 2147 return inputs 2148 2149 @staticmethod 2150 def _validate_axis( 2151 field_name: str, 2152 i: int, 2153 tensor_id: TensorId, 2154 a: int, 2155 axis: AnyAxis, 2156 valid_independent_refs: Dict[ 2157 Tuple[TensorId, AxisId], 2158 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2159 ], 2160 ): 2161 if isinstance(axis, BatchAxis) or isinstance( 2162 axis.size, (int, ParameterizedSize, 
DataDependentSize) 2163 ): 2164 return 2165 elif not isinstance(axis.size, SizeReference): 2166 assert_never(axis.size) 2167 2168 # validate axis.size SizeReference 2169 ref = (axis.size.tensor_id, axis.size.axis_id) 2170 if ref not in valid_independent_refs: 2171 raise ValueError( 2172 "Invalid tensor axis reference at" 2173 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2174 ) 2175 if ref == (tensor_id, axis.id): 2176 raise ValueError( 2177 "Self-referencing not allowed for" 2178 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2179 ) 2180 if axis.type == "channel": 2181 if valid_independent_refs[ref][1].type != "channel": 2182 raise ValueError( 2183 "A channel axis' size may only reference another fixed size" 2184 + " channel axis." 2185 ) 2186 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2187 ref_size = valid_independent_refs[ref][2] 2188 assert isinstance(ref_size, int), ( 2189 "channel axis ref (another channel axis) has to specify fixed" 2190 + " size" 2191 ) 2192 generated_channel_names = [ 2193 Identifier(axis.channel_names.format(i=i)) 2194 for i in range(1, ref_size + 1) 2195 ] 2196 axis.channel_names = generated_channel_names 2197 2198 if (ax_unit := getattr(axis, "unit", None)) != ( 2199 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2200 ): 2201 raise ValueError( 2202 "The units of an axis and its reference axis need to match, but" 2203 + f" '{ax_unit}' != '{ref_unit}'." 2204 ) 2205 ref_axis = valid_independent_refs[ref][1] 2206 if isinstance(ref_axis, BatchAxis): 2207 raise ValueError( 2208 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2209 + " (a batch axis is not allowed as reference)." 2210 ) 2211 2212 if isinstance(axis, WithHalo): 2213 min_size = axis.size.get_size(axis, ref_axis, n=0) 2214 if (min_size - 2 * axis.halo) < 1: 2215 raise ValueError( 2216 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2217 + f" {axis.halo}." 2218 ) 2219 2220 input_halo = axis.halo * axis.scale / ref_axis.scale 2221 if input_halo != int(input_halo) or input_halo % 2 == 1: 2222 raise ValueError( 2223 f"input_halo {input_halo} (output_halo {axis.halo} *" 2224 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2225 + f" is not an even integer for {tensor_id}.{axis.id}." 2226 ) 2227 2228 @model_validator(mode="after") 2229 def _validate_test_tensors(self) -> Self: 2230 if not validation_context_var.get().perform_io_checks: 2231 return self 2232 2233 test_arrays = [ 2234 load_array(descr.test_tensor.download().path) 2235 for descr in chain(self.inputs, self.outputs) 2236 ] 2237 tensors = { 2238 descr.id: (descr, array) 2239 for descr, array in zip(chain(self.inputs, self.outputs), test_arrays) 2240 } 2241 validate_tensors(tensors, tensor_origin="test_tensor") 2242 return self 2243 2244 @model_validator(mode="after") 2245 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2246 ipt_refs = {t.id for t in self.inputs} 2247 out_refs = {t.id for t in self.outputs} 2248 for ipt in self.inputs: 2249 for p in ipt.preprocessing: 2250 ref = p.kwargs.get("reference_tensor") 2251 if ref is None: 2252 continue 2253 if ref not in ipt_refs: 2254 raise ValueError( 2255 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2256 + f" references are: {ipt_refs}." 
2257 ) 2258 2259 for out in self.outputs: 2260 for p in out.postprocessing: 2261 ref = p.kwargs.get("reference_tensor") 2262 if ref is None: 2263 continue 2264 2265 if ref not in ipt_refs and ref not in out_refs: 2266 raise ValueError( 2267 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2268 + f" are: {ipt_refs | out_refs}." 2269 ) 2270 2271 return self 2272 2273 # TODO: use validate funcs in validate_test_tensors 2274 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2275 2276 name: Annotated[ 2277 Annotated[ 2278 str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()") 2279 ], 2280 MinLen(5), 2281 MaxLen(128), 2282 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2283 ] 2284 """A human-readable name of this model. 2285 It should be no longer than 64 characters 2286 and may only contain letter, number, underscore, minus, parentheses and spaces. 2287 We recommend to chose a name that refers to the model's task and image modality. 2288 """ 2289 2290 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2291 """Describes the output tensors.""" 2292 2293 @field_validator("outputs", mode="after") 2294 @classmethod 2295 def _validate_tensor_ids( 2296 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2297 ) -> Sequence[OutputTensorDescr]: 2298 tensor_ids = [ 2299 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2300 ] 2301 duplicate_tensor_ids: List[str] = [] 2302 seen: Set[str] = set() 2303 for t in tensor_ids: 2304 if t in seen: 2305 duplicate_tensor_ids.append(t) 2306 2307 seen.add(t) 2308 2309 if duplicate_tensor_ids: 2310 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2311 2312 return outputs 2313 2314 @staticmethod 2315 def _get_axes_with_parameterized_size( 2316 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2317 ): 2318 return { 2319 f"{t.id}.{a.id}": (t, a, a.size) 2320 for t in io 2321 for a in t.axes 2322 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2323 } 2324 2325 @staticmethod 2326 def _get_axes_with_independent_size( 2327 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2328 ): 2329 return { 2330 (t.id, a.id): (t, a, a.size) 2331 for t in io 2332 for a in t.axes 2333 if not isinstance(a, BatchAxis) 2334 and isinstance(a.size, (int, ParameterizedSize)) 2335 } 2336 2337 @field_validator("outputs", mode="after") 2338 @classmethod 2339 def _validate_output_axes( 2340 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2341 ) -> List[OutputTensorDescr]: 2342 input_size_refs = cls._get_axes_with_independent_size( 2343 info.data.get("inputs", []) 2344 ) 2345 output_size_refs = cls._get_axes_with_independent_size(outputs) 2346 2347 for i, out in enumerate(outputs): 2348 valid_independent_refs: Dict[ 2349 Tuple[TensorId, AxisId], 2350 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2351 ] = { 2352 **{ 2353 (out.id, a.id): (out, a, a.size) 2354 for a in out.axes 2355 if not isinstance(a, BatchAxis) 2356 and isinstance(a.size, (int, ParameterizedSize)) 2357 }, 2358 **input_size_refs, 2359 **output_size_refs, 2360 } 2361 for a, ax in enumerate(out.axes): 2362 cls._validate_axis( 2363 "outputs", 2364 i, 2365 out.id, 2366 a, 2367 ax, 2368 valid_independent_refs=valid_independent_refs, 2369 ) 2370 2371 return outputs 2372 2373 packaged_by: List[Author] = Field(default_factory=list) 2374 """The persons that have packaged and uploaded this model. 
2375 Only required if those persons differ from the `authors`.""" 2376 2377 parent: Optional[LinkedModel] = None 2378 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2379 2380 # todo: add parent self check once we have `id` 2381 # @model_validator(mode="after") 2382 # def validate_parent_is_not_self(self) -> Self: 2383 # if self.parent is not None and self.parent == self.id: 2384 # raise ValueError("The model may not reference itself as parent model") 2385 2386 # return self 2387 2388 run_mode: Annotated[ 2389 Optional[RunMode], 2390 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2391 ] = None 2392 """Custom run mode for this model: for more complex prediction procedures like test time 2393 data augmentation that currently cannot be expressed in the specification. 2394 No standard run modes are defined yet.""" 2395 2396 timestamp: Datetime = Datetime(datetime.now()) 2397 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2398 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2399 (In Python a datetime object is valid, too).""" 2400 2401 training_data: Annotated[ 2402 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2403 Field(union_mode="left_to_right"), 2404 ] = None 2405 """The dataset used to train this model""" 2406 2407 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2408 """The weights for this model. 2409 Weights can be given for different formats, but should otherwise be equivalent. 2410 The available weight formats determine which consumers can use this model.""" 2411 2412 @model_validator(mode="after") 2413 def _add_default_cover(self) -> Self: 2414 if not validation_context_var.get().perform_io_checks or self.covers: 2415 return self 2416 2417 try: 2418 generated_covers = generate_covers( 2419 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 2420 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 2421 ) 2422 except Exception as e: 2423 issue_warning( 2424 "Failed to generate cover image(s): {e}", 2425 value=self.covers, 2426 msg_context=dict(e=e), 2427 field="covers", 2428 ) 2429 else: 2430 self.covers.extend(generated_covers) 2431 2432 return self 2433 2434 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2435 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 2436 assert all(isinstance(d, np.ndarray) for d in data) 2437 return data 2438 2439 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2440 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 2441 assert all(isinstance(d, np.ndarray) for d in data) 2442 return data 2443 2444 @staticmethod 2445 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2446 batch_size = 1 2447 tensor_with_batchsize: Optional[TensorId] = None 2448 for tid in tensor_sizes: 2449 for aid, s in tensor_sizes[tid].items(): 2450 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2451 continue 2452 2453 if batch_size != 1: 2454 assert tensor_with_batchsize is not None 2455 raise ValueError( 2456 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2457 ) 2458 2459 batch_size = s 2460 tensor_with_batchsize = tid 2461 2462 return batch_size 2463 2464 def get_output_tensor_sizes( 2465 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2466 ) -> Dict[TensorId, Dict[AxisId, Union[int, 
_DataDepSize]]]: 2467 """Returns the tensor output sizes for given **input_sizes**. 2468 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2469 Otherwise it might be larger than the actual (valid) output""" 2470 batch_size = self.get_batch_size(input_sizes) 2471 ns = self.get_ns(input_sizes) 2472 2473 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2474 return tensor_sizes.outputs 2475 2476 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2477 """get parameter `n` for each parameterized axis 2478 such that the valid input size is >= the given input size""" 2479 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2480 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2481 for tid in input_sizes: 2482 for aid, s in input_sizes[tid].items(): 2483 size_descr = axes[tid][aid].size 2484 if isinstance(size_descr, ParameterizedSize): 2485 ret[(tid, aid)] = size_descr.get_n(s) 2486 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2487 pass 2488 else: 2489 assert_never(size_descr) 2490 2491 return ret 2492 2493 def get_tensor_sizes( 2494 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2495 ) -> _TensorSizes: 2496 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2497 return _TensorSizes( 2498 { 2499 t: { 2500 aa: axis_sizes.inputs[(tt, aa)] 2501 for tt, aa in axis_sizes.inputs 2502 if tt == t 2503 } 2504 for t in {tt for tt, _ in axis_sizes.inputs} 2505 }, 2506 { 2507 t: { 2508 aa: axis_sizes.outputs[(tt, aa)] 2509 for tt, aa in axis_sizes.outputs 2510 if tt == t 2511 } 2512 for t in {tt for tt, _ in axis_sizes.outputs} 2513 }, 2514 ) 2515 2516 def get_axis_sizes( 2517 self, 2518 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 2519 batch_size: Optional[int] = None, 2520 *, 2521 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 2522 ) -> _AxisSizes: 2523 """Determine input and output block shape for scale factors **ns** 2524 of parameterized input sizes. 2525 2526 Args: 2527 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 2528 that is parameterized as `size = min + n * step`. 2529 batch_size: The desired size of the batch dimension. 2530 If given **batch_size** overwrites any batch size present in 2531 **max_input_shape**. Default 1. 2532 max_input_shape: Limits the derived block shapes. 2533 Each axis for which the input size, parameterized by `n`, is larger 2534 than **max_input_shape** is set to the minimal value `n_min` for which 2535 this is still true. 2536 Use this for small input samples or large values of **ns**. 2537 Or simply whenever you know the full input shape. 2538 2539 Returns: 2540 Resolved axis sizes for model inputs and outputs. 
2541 """ 2542 max_input_shape = max_input_shape or {} 2543 if batch_size is None: 2544 for (_t_id, a_id), s in max_input_shape.items(): 2545 if a_id == BATCH_AXIS_ID: 2546 batch_size = s 2547 break 2548 else: 2549 batch_size = 1 2550 2551 all_axes = { 2552 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 2553 } 2554 2555 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 2556 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 2557 2558 def get_axis_size(a: Union[InputAxis, OutputAxis]): 2559 if isinstance(a, BatchAxis): 2560 if (t_descr.id, a.id) in ns: 2561 logger.warning( 2562 "Ignoring unexpected size increment factor (n) for batch axis" 2563 + " of tensor '{}'.", 2564 t_descr.id, 2565 ) 2566 return batch_size 2567 elif isinstance(a.size, int): 2568 if (t_descr.id, a.id) in ns: 2569 logger.warning( 2570 "Ignoring unexpected size increment factor (n) for fixed size" 2571 + " axis '{}' of tensor '{}'.", 2572 a.id, 2573 t_descr.id, 2574 ) 2575 return a.size 2576 elif isinstance(a.size, ParameterizedSize): 2577 if (t_descr.id, a.id) not in ns: 2578 raise ValueError( 2579 "Size increment factor (n) missing for parametrized axis" 2580 + f" '{a.id}' of tensor '{t_descr.id}'." 2581 ) 2582 n = ns[(t_descr.id, a.id)] 2583 s_max = max_input_shape.get((t_descr.id, a.id)) 2584 if s_max is not None: 2585 n = min(n, a.size.get_n(s_max)) 2586 2587 return a.size.get_size(n) 2588 2589 elif isinstance(a.size, SizeReference): 2590 if (t_descr.id, a.id) in ns: 2591 logger.warning( 2592 "Ignoring unexpected size increment factor (n) for axis '{}'" 2593 + " of tensor '{}' with size reference.", 2594 a.id, 2595 t_descr.id, 2596 ) 2597 assert not isinstance(a, BatchAxis) 2598 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 2599 assert not isinstance(ref_axis, BatchAxis) 2600 ref_key = (a.size.tensor_id, a.size.axis_id) 2601 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 2602 assert ref_size is not None, ref_key 2603 assert not isinstance(ref_size, _DataDepSize), ref_key 2604 return a.size.get_size( 2605 axis=a, 2606 ref_axis=ref_axis, 2607 ref_size=ref_size, 2608 ) 2609 elif isinstance(a.size, DataDependentSize): 2610 if (t_descr.id, a.id) in ns: 2611 logger.warning( 2612 "Ignoring unexpected increment factor (n) for data dependent" 2613 + " size axis '{}' of tensor '{}'.", 2614 a.id, 2615 t_descr.id, 2616 ) 2617 return _DataDepSize(a.size.min, a.size.max) 2618 else: 2619 assert_never(a.size) 2620 2621 # first resolve all , but the `SizeReference` input sizes 2622 for t_descr in self.inputs: 2623 for a in t_descr.axes: 2624 if not isinstance(a.size, SizeReference): 2625 s = get_axis_size(a) 2626 assert not isinstance(s, _DataDepSize) 2627 inputs[t_descr.id, a.id] = s 2628 2629 # resolve all other input axis sizes 2630 for t_descr in self.inputs: 2631 for a in t_descr.axes: 2632 if isinstance(a.size, SizeReference): 2633 s = get_axis_size(a) 2634 assert not isinstance(s, _DataDepSize) 2635 inputs[t_descr.id, a.id] = s 2636 2637 # resolve all output axis sizes 2638 for t_descr in self.outputs: 2639 for a in t_descr.axes: 2640 assert not isinstance(a.size, ParameterizedSize) 2641 s = get_axis_size(a) 2642 outputs[t_descr.id, a.id] = s 2643 2644 return _AxisSizes(inputs=inputs, outputs=outputs) 2645 2646 @model_validator(mode="before") 2647 @classmethod 2648 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 2649 if ( 2650 data.get("type") == "model" 2651 and isinstance(fv := data.get("format_version"), str) 2652 and fv.count(".") == 2 2653 ): 2654 
fv_parts = fv.split(".") 2655 if any(not p.isdigit() for p in fv_parts): 2656 return data 2657 2658 fv_tuple = tuple(map(int, fv_parts)) 2659 2660 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 2661 if fv_tuple[:2] in ((0, 3), (0, 4)): 2662 m04 = _ModelDescr_v0_4.load(data) 2663 if not isinstance(m04, InvalidDescr): 2664 return _model_conv.convert_as_dict(m04) 2665 elif fv_tuple[:2] == (0, 5): 2666 # bump patch version 2667 data["format_version"] = cls.implemented_format_version 2668 2669 return data
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
Version of the bioimage.io model description specification used. When creating a new model always use the latest micro/patch version described here. The format_version is important for any consumer software to understand how to parse the fields.
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
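For illustration only, a minimal sketch of checking these naming constraints; the helper name and regular expression below are assumptions for demonstration, not the validator the spec itself uses:

import re

def is_valid_model_name(name: str) -> bool:
    # hypothetical check: max. 64 characters; letters, numbers, underscore,
    # minus, parentheses and spaces only
    return len(name) <= 64 and re.fullmatch(r"[A-Za-z0-9_\-() ]+", name) is not None

assert is_valid_model_name("StarDist (2D) nuclei segmentation")
assert not is_valid_model_name("invalid/name")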
The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
2444 @staticmethod 2445 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2446 batch_size = 1 2447 tensor_with_batchsize: Optional[TensorId] = None 2448 for tid in tensor_sizes: 2449 for aid, s in tensor_sizes[tid].items(): 2450 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2451 continue 2452 2453 if batch_size != 1: 2454 assert tensor_with_batchsize is not None 2455 raise ValueError( 2456 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2457 ) 2458 2459 batch_size = s 2460 tensor_with_batchsize = tid 2461 2462 return batch_size
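A minimal usage sketch (tensor and axis ids are made up; it assumes TensorId and AxisId accept plain strings, as they do elsewhere in this module): both tensors below agree on a batch size of 4, so that value is returned; differing batch sizes would raise a ValueError.

from bioimageio.spec.model.v0_5 import ModelDescr, TensorId, AxisId

sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("y"): 256, AxisId("x"): 256},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("y"): 256, AxisId("x"): 256},
}
assert ModelDescr.get_batch_size(sizes) == 4  # both tensors agree on batch size 4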
2464 def get_output_tensor_sizes( 2465 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2466 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2467 """Returns the tensor output sizes for given **input_sizes**. 2468 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2469 Otherwise it might be larger than the actual (valid) output""" 2470 batch_size = self.get_batch_size(input_sizes) 2471 ns = self.get_ns(input_sizes) 2472 2473 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2474 return tensor_sizes.outputs
Returns the tensor output sizes for the given input_sizes. The returned sizes are exact only if input_sizes describes a valid input shape; otherwise they may be larger than the actual (valid) output.
2476 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2477 """get parameter `n` for each parameterized axis 2478 such that the valid input size is >= the given input size""" 2479 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2480 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2481 for tid in input_sizes: 2482 for aid, s in input_sizes[tid].items(): 2483 size_descr = axes[tid][aid].size 2484 if isinstance(size_descr, ParameterizedSize): 2485 ret[(tid, aid)] = size_descr.get_n(s) 2486 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2487 pass 2488 else: 2489 assert_never(size_descr) 2490 2491 return ret
Get the parameter n for each parameterized axis such that the valid input size is >= the given input size.
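For intuition, consider a hypothetical parameterized axis with min=32 and step=16: a requested extent of 100 maps to n=5, because 32 + 5*16 = 112 is the smallest valid size that is >= 100. A plain-Python sketch of that rounding (independent of the spec's ParameterizedSize class):

from math import ceil

def smallest_valid_n(requested: int, size_min: int, step: int) -> int:
    # smallest n such that size_min + n * step >= requested (illustrative only)
    return max(0, ceil((requested - size_min) / step))

assert smallest_valid_n(100, size_min=32, step=16) == 5  # 32 + 5*16 = 112 >= 100
assert smallest_valid_n(32, size_min=32, step=16) == 0   # already a valid size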
2493 def get_tensor_sizes( 2494 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2495 ) -> _TensorSizes: 2496 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2497 return _TensorSizes( 2498 { 2499 t: { 2500 aa: axis_sizes.inputs[(tt, aa)] 2501 for tt, aa in axis_sizes.inputs 2502 if tt == t 2503 } 2504 for t in {tt for tt, _ in axis_sizes.inputs} 2505 }, 2506 { 2507 t: { 2508 aa: axis_sizes.outputs[(tt, aa)] 2509 for tt, aa in axis_sizes.outputs 2510 if tt == t 2511 } 2512 for t in {tt for tt, _ in axis_sizes.outputs} 2513 }, 2514 )
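get_tensor_sizes only regroups the flat (tensor_id, axis_id)-keyed result of get_axis_sizes into nested per-tensor dictionaries. The reshaping itself is plain dictionary work, as in this standalone sketch with made-up string keys:

from typing import Dict, Tuple

flat: Dict[Tuple[str, str], int] = {
    ("raw", "batch"): 1,
    ("raw", "x"): 256,
    ("mask", "x"): 256,
}
nested = {
    t: {a: s for (tt, a), s in flat.items() if tt == t}
    for t in {tt for tt, _ in flat}
}
assert nested == {"raw": {"batch": 1, "x": 256}, "mask": {"x": 256}}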
2516 def get_axis_sizes( 2517 self, 2518 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 2519 batch_size: Optional[int] = None, 2520 *, 2521 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 2522 ) -> _AxisSizes: 2523 """Determine input and output block shape for scale factors **ns** 2524 of parameterized input sizes. 2525 2526 Args: 2527 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 2528 that is parameterized as `size = min + n * step`. 2529 batch_size: The desired size of the batch dimension. 2530 If given **batch_size** overwrites any batch size present in 2531 **max_input_shape**. Default 1. 2532 max_input_shape: Limits the derived block shapes. 2533 Each axis for which the input size, parameterized by `n`, is larger 2534 than **max_input_shape** is set to the minimal value `n_min` for which 2535 this is still true. 2536 Use this for small input samples or large values of **ns**. 2537 Or simply whenever you know the full input shape. 2538 2539 Returns: 2540 Resolved axis sizes for model inputs and outputs. 2541 """ 2542 max_input_shape = max_input_shape or {} 2543 if batch_size is None: 2544 for (_t_id, a_id), s in max_input_shape.items(): 2545 if a_id == BATCH_AXIS_ID: 2546 batch_size = s 2547 break 2548 else: 2549 batch_size = 1 2550 2551 all_axes = { 2552 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 2553 } 2554 2555 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 2556 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 2557 2558 def get_axis_size(a: Union[InputAxis, OutputAxis]): 2559 if isinstance(a, BatchAxis): 2560 if (t_descr.id, a.id) in ns: 2561 logger.warning( 2562 "Ignoring unexpected size increment factor (n) for batch axis" 2563 + " of tensor '{}'.", 2564 t_descr.id, 2565 ) 2566 return batch_size 2567 elif isinstance(a.size, int): 2568 if (t_descr.id, a.id) in ns: 2569 logger.warning( 2570 "Ignoring unexpected size increment factor (n) for fixed size" 2571 + " axis '{}' of tensor '{}'.", 2572 a.id, 2573 t_descr.id, 2574 ) 2575 return a.size 2576 elif isinstance(a.size, ParameterizedSize): 2577 if (t_descr.id, a.id) not in ns: 2578 raise ValueError( 2579 "Size increment factor (n) missing for parametrized axis" 2580 + f" '{a.id}' of tensor '{t_descr.id}'." 
2581 ) 2582 n = ns[(t_descr.id, a.id)] 2583 s_max = max_input_shape.get((t_descr.id, a.id)) 2584 if s_max is not None: 2585 n = min(n, a.size.get_n(s_max)) 2586 2587 return a.size.get_size(n) 2588 2589 elif isinstance(a.size, SizeReference): 2590 if (t_descr.id, a.id) in ns: 2591 logger.warning( 2592 "Ignoring unexpected size increment factor (n) for axis '{}'" 2593 + " of tensor '{}' with size reference.", 2594 a.id, 2595 t_descr.id, 2596 ) 2597 assert not isinstance(a, BatchAxis) 2598 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 2599 assert not isinstance(ref_axis, BatchAxis) 2600 ref_key = (a.size.tensor_id, a.size.axis_id) 2601 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 2602 assert ref_size is not None, ref_key 2603 assert not isinstance(ref_size, _DataDepSize), ref_key 2604 return a.size.get_size( 2605 axis=a, 2606 ref_axis=ref_axis, 2607 ref_size=ref_size, 2608 ) 2609 elif isinstance(a.size, DataDependentSize): 2610 if (t_descr.id, a.id) in ns: 2611 logger.warning( 2612 "Ignoring unexpected increment factor (n) for data dependent" 2613 + " size axis '{}' of tensor '{}'.", 2614 a.id, 2615 t_descr.id, 2616 ) 2617 return _DataDepSize(a.size.min, a.size.max) 2618 else: 2619 assert_never(a.size) 2620 2621 # first resolve all , but the `SizeReference` input sizes 2622 for t_descr in self.inputs: 2623 for a in t_descr.axes: 2624 if not isinstance(a.size, SizeReference): 2625 s = get_axis_size(a) 2626 assert not isinstance(s, _DataDepSize) 2627 inputs[t_descr.id, a.id] = s 2628 2629 # resolve all other input axis sizes 2630 for t_descr in self.inputs: 2631 for a in t_descr.axes: 2632 if isinstance(a.size, SizeReference): 2633 s = get_axis_size(a) 2634 assert not isinstance(s, _DataDepSize) 2635 inputs[t_descr.id, a.id] = s 2636 2637 # resolve all output axis sizes 2638 for t_descr in self.outputs: 2639 for a in t_descr.axes: 2640 assert not isinstance(a.size, ParameterizedSize) 2641 s = get_axis_size(a) 2642 outputs[t_descr.id, a.id] = s 2643 2644 return _AxisSizes(inputs=inputs, outputs=outputs)
Determine input and output block shape for scale factors ns of parameterized input sizes.
Arguments:
- ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
- batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
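A worked example of the max_input_shape clamping with hypothetical numbers: for an axis parameterized with min=32 and step=16, a requested n of 10 would imply an input extent of 192; capping that axis at 128 clamps n to 6, since 32 + 6*16 = 128 is the smallest parameterized size that reaches the cap, and the smaller of the two values is used.

from math import ceil

size_min, step = 32, 16    # hypothetical ParameterizedSize
n_requested = 10           # would give 32 + 10*16 = 192
s_max = 128                # cap for this axis from max_input_shape
n_cap = ceil((s_max - size_min) / step)  # smallest n reaching the cap: 6
n = min(n_requested, n_cap)              # -> 6, i.e. a derived size of 128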
2894def generate_covers( 2895 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 2896 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 2897) -> List[Path]: 2898 def squeeze( 2899 data: NDArray[Any], axes: Sequence[AnyAxis] 2900 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 2901 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 2902 if data.ndim != len(axes): 2903 raise ValueError( 2904 f"tensor shape {data.shape} does not match described axes" 2905 + f" {[a.id for a in axes]}" 2906 ) 2907 2908 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 2909 return data.squeeze(), axes 2910 2911 def normalize( 2912 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 2913 ) -> NDArray[np.float32]: 2914 data = data.astype("float32") 2915 data -= data.min(axis=axis, keepdims=True) 2916 data /= data.max(axis=axis, keepdims=True) + eps 2917 return data 2918 2919 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 2920 original_shape = data.shape 2921 data, axes = squeeze(data, axes) 2922 2923 # take slice fom any batch or index axis if needed 2924 # and convert the first channel axis and take a slice from any additional channel axes 2925 slices: Tuple[slice, ...] = () 2926 ndim = data.ndim 2927 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 2928 has_c_axis = False 2929 for i, a in enumerate(axes): 2930 s = data.shape[i] 2931 assert s > 1 2932 if ( 2933 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 2934 and ndim > ndim_need 2935 ): 2936 data = data[slices + (slice(s // 2 - 1, s // 2),)] 2937 ndim -= 1 2938 elif isinstance(a, ChannelAxis): 2939 if has_c_axis: 2940 # second channel axis 2941 data = data[slices + (slice(0, 1),)] 2942 ndim -= 1 2943 else: 2944 has_c_axis = True 2945 if s == 2: 2946 # visualize two channels with cyan and magenta 2947 data = np.concatenate( 2948 [ 2949 data[slices + (slice(1, 2),)], 2950 data[slices + (slice(0, 1),)], 2951 ( 2952 data[slices + (slice(0, 1),)] 2953 + data[slices + (slice(1, 2),)] 2954 ) 2955 / 2, # TODO: take maximum instead? 
2956 ], 2957 axis=i, 2958 ) 2959 elif data.shape[i] == 3: 2960 pass # visualize 3 channels as RGB 2961 else: 2962 # visualize first 3 channels as RGB 2963 data = data[slices + (slice(3),)] 2964 2965 assert data.shape[i] == 3 2966 2967 slices += (slice(None),) 2968 2969 data, axes = squeeze(data, axes) 2970 assert len(axes) == ndim 2971 # take slice from z axis if needed 2972 slices = () 2973 if ndim > ndim_need: 2974 for i, a in enumerate(axes): 2975 s = data.shape[i] 2976 if a.id == AxisId("z"): 2977 data = data[slices + (slice(s // 2 - 1, s // 2),)] 2978 data, axes = squeeze(data, axes) 2979 ndim -= 1 2980 break 2981 2982 slices += (slice(None),) 2983 2984 # take slice from any space or time axis 2985 slices = () 2986 2987 for i, a in enumerate(axes): 2988 if ndim <= ndim_need: 2989 break 2990 2991 s = data.shape[i] 2992 assert s > 1 2993 if isinstance( 2994 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 2995 ): 2996 data = data[slices + (slice(s // 2 - 1, s // 2),)] 2997 ndim -= 1 2998 2999 slices += (slice(None),) 3000 3001 del slices 3002 data, axes = squeeze(data, axes) 3003 assert len(axes) == ndim 3004 3005 if (has_c_axis and ndim != 3) or ndim != 2: 3006 raise ValueError( 3007 f"Failed to construct cover image from shape {original_shape}" 3008 ) 3009 3010 if not has_c_axis: 3011 assert ndim == 2 3012 data = np.repeat(data[:, :, None], 3, axis=2) 3013 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3014 ndim += 1 3015 3016 assert ndim == 3 3017 3018 # transpose axis order such that longest axis comes first... 3019 axis_order = list(np.argsort(list(data.shape))) 3020 axis_order.reverse() 3021 # ... and channel axis is last 3022 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3023 axis_order.append(axis_order.pop(c)) 3024 axes = [axes[ao] for ao in axis_order] 3025 data = data.transpose(axis_order) 3026 3027 # h, w = data.shape[:2] 3028 # if h / w in (1.0 or 2.0): 3029 # pass 3030 # elif h / w < 2: 3031 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3032 3033 norm_along = ( 3034 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3035 ) 3036 # normalize the data and map to 8 bit 3037 data = normalize(data, norm_along) 3038 data = (data * 255).astype("uint8") 3039 3040 return data 3041 3042 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3043 assert im0.dtype == im1.dtype == np.uint8 3044 assert im0.shape == im1.shape 3045 assert im0.ndim == 3 3046 N, M, C = im0.shape 3047 assert C == 3 3048 out = np.ones((N, M, C), dtype="uint8") 3049 for c in range(C): 3050 outc = np.tril(im0[..., c]) 3051 mask = outc == 0 3052 outc[mask] = np.triu(im1[..., c])[mask] 3053 out[..., c] = outc 3054 3055 return out 3056 3057 ipt_descr, ipt = inputs[0] 3058 out_descr, out = outputs[0] 3059 3060 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3061 out_img = to_2d_image(out, out_descr.axes) 3062 3063 cover_folder = Path(mkdtemp()) 3064 if ipt_img.shape == out_img.shape: 3065 covers = [cover_folder / "cover.png"] 3066 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3067 else: 3068 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3069 imwrite(covers[0], ipt_img) 3070 imwrite(covers[1], out_img) 3071 3072 return covers
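When input and output render to the same 2D shape, create_diagonal_split_image composes the cover from the input's lower triangle and the output's upper triangle via np.tril/np.triu. A standalone numpy sketch of that idea with toy uint8 images (not the spec's test tensors):

import numpy as np

def diagonal_split(im0: np.ndarray, im1: np.ndarray) -> np.ndarray:
    # lower triangle (incl. diagonal) from im0, the rest from im1; both HxWx3 uint8
    assert im0.shape == im1.shape and im0.ndim == 3 and im0.dtype == im1.dtype == np.uint8
    out = np.empty_like(im0)
    for c in range(im0.shape[2]):
        lower = np.tril(im0[..., c])
        mask = lower == 0
        lower[mask] = np.triu(im1[..., c])[mask]
        out[..., c] = lower
    return out

a = np.full((4, 4, 3), 200, dtype=np.uint8)  # toy "input" image
b = np.full((4, 4, 3), 50, dtype=np.uint8)   # toy "output" image
print(diagonal_split(a, b)[..., 0])  # 200 on/below the diagonal, 50 above it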