bioimageio.spec.model.v0_5
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from itertools import chain 10from math import ceil 11from pathlib import Path, PurePosixPath 12from tempfile import mkdtemp 13from typing import ( 14 TYPE_CHECKING, 15 Any, 16 Callable, 17 ClassVar, 18 Dict, 19 Generic, 20 List, 21 Literal, 22 Mapping, 23 NamedTuple, 24 Optional, 25 Sequence, 26 Set, 27 Tuple, 28 Type, 29 TypeVar, 30 Union, 31 cast, 32) 33 34import numpy as np 35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 37from loguru import logger 38from numpy.typing import NDArray 39from pydantic import ( 40 AfterValidator, 41 Discriminator, 42 Field, 43 RootModel, 44 SerializationInfo, 45 SerializerFunctionWrapHandler, 46 StrictInt, 47 Tag, 48 ValidationInfo, 49 WrapSerializer, 50 field_validator, 51 model_serializer, 52 model_validator, 53) 54from typing_extensions import Annotated, Self, assert_never, get_args 55 56from .._internal.common_nodes import ( 57 InvalidDescr, 58 Node, 59 NodeWithExplicitlySetFields, 60) 61from .._internal.constants import DTYPE_LIMITS 62from .._internal.field_warning import issue_warning, warn 63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 64from .._internal.io import FileDescr as FileDescr 65from .._internal.io import ( 66 FileSource, 67 WithSuffix, 68 YamlValue, 69 get_reader, 70 wo_special_file_name, 71) 72from .._internal.io_basics import Sha256 as Sha256 73from .._internal.io_packaging import ( 74 FileDescr_, 75 FileSource_, 76 package_file_descr_serializer, 77) 78from .._internal.io_utils import load_array 79from .._internal.node_converter import Converter 80from .._internal.type_guards import is_dict, is_sequence 81from .._internal.types import ( 82 FAIR, 83 AbsoluteTolerance, 84 LowerCaseIdentifier, 85 LowerCaseIdentifierAnno, 86 MismatchedElementsPerMillion, 87 RelativeTolerance, 88) 89from .._internal.types import Datetime as Datetime 90from .._internal.types import Identifier as Identifier 91from .._internal.types import NotEmpty as NotEmpty 92from .._internal.types import SiUnit as SiUnit 93from .._internal.url import HttpUrl as HttpUrl 94from .._internal.validation_context import get_validation_context 95from .._internal.validator_annotations import RestrictCharacters 96from .._internal.version_type import Version as Version 97from .._internal.warning_levels import INFO 98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 100from ..dataset.v0_3 import DatasetDescr as DatasetDescr 101from ..dataset.v0_3 import DatasetId as DatasetId 102from ..dataset.v0_3 import LinkedDataset as LinkedDataset 103from ..dataset.v0_3 import Uploader as Uploader 104from ..generic.v0_3 import ( 105 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 106) 107from ..generic.v0_3 import Author as Author 108from ..generic.v0_3 import BadgeDescr as BadgeDescr 109from ..generic.v0_3 import CiteEntry as CiteEntry 110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 111from ..generic.v0_3 import Doi as Doi 112from ..generic.v0_3 import ( 113 FileSource_documentation, 114 GenericModelDescrBase, 115 LinkedResourceBase, 116 _author_conv, # pyright: ignore[reportPrivateUsage] 117 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 118) 119from ..generic.v0_3 import LicenseId as 
LicenseId 120from ..generic.v0_3 import LinkedResource as LinkedResource 121from ..generic.v0_3 import Maintainer as Maintainer 122from ..generic.v0_3 import OrcidId as OrcidId 123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 124from ..generic.v0_3 import ResourceId as ResourceId 125from .v0_4 import Author as _Author_v0_4 126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 127from .v0_4 import CallableFromDepencency as CallableFromDepencency 128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 130from .v0_4 import ClipDescr as _ClipDescr_v0_4 131from .v0_4 import ClipKwargs as ClipKwargs 132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 134from .v0_4 import KnownRunMode as KnownRunMode 135from .v0_4 import ModelDescr as _ModelDescr_v0_4 136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 140from .v0_4 import ProcessingKwargs as ProcessingKwargs 141from .v0_4 import RunMode as RunMode 142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 146from .v0_4 import TensorName as _TensorName_v0_4 147from .v0_4 import WeightsFormat as WeightsFormat 148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 149from .v0_4 import package_weights 150 151SpaceUnit = Literal[ 152 "attometer", 153 "angstrom", 154 "centimeter", 155 "decimeter", 156 "exameter", 157 "femtometer", 158 "foot", 159 "gigameter", 160 "hectometer", 161 "inch", 162 "kilometer", 163 "megameter", 164 "meter", 165 "micrometer", 166 "mile", 167 "millimeter", 168 "nanometer", 169 "parsec", 170 "petameter", 171 "picometer", 172 "terameter", 173 "yard", 174 "yoctometer", 175 "yottameter", 176 "zeptometer", 177 "zettameter", 178] 179"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 180 181TimeUnit = Literal[ 182 "attosecond", 183 "centisecond", 184 "day", 185 "decisecond", 186 "exasecond", 187 "femtosecond", 188 "gigasecond", 189 "hectosecond", 190 "hour", 191 "kilosecond", 192 "megasecond", 193 "microsecond", 194 "millisecond", 195 "minute", 196 "nanosecond", 197 "petasecond", 198 "picosecond", 199 "second", 200 "terasecond", 201 "yoctosecond", 202 "yottasecond", 203 "zeptosecond", 204 "zettasecond", 205] 206"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 207 208AxisType = Literal["batch", "channel", "index", "time", "space"] 209 210_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 211 "b": "batch", 212 "t": "time", 213 "i": "index", 214 "c": "channel", 215 "x": "space", 216 "y": "space", 217 "z": "space", 218} 219 220_AXIS_ID_MAP = { 221 "b": "batch", 222 "t": "time", 223 "i": "index", 224 "c": "channel", 225} 226 227 228class TensorId(LowerCaseIdentifier): 229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 230 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 231 ] 232 233 234def _normalize_axis_id(a: str): 235 a = str(a) 236 normalized = _AXIS_ID_MAP.get(a, a) 237 if a 
!= normalized: 238 logger.opt(depth=3).warning( 239 "Normalized axis id from '{}' to '{}'.", a, normalized 240 ) 241 return normalized 242 243 244class AxisId(LowerCaseIdentifier): 245 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 246 Annotated[ 247 LowerCaseIdentifierAnno, 248 MaxLen(16), 249 AfterValidator(_normalize_axis_id), 250 ] 251 ] 252 253 254def _is_batch(a: str) -> bool: 255 return str(a) == "batch" 256 257 258def _is_not_batch(a: str) -> bool: 259 return not _is_batch(a) 260 261 262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 263 264PreprocessingId = Literal[ 265 "binarize", 266 "clip", 267 "ensure_dtype", 268 "fixed_zero_mean_unit_variance", 269 "scale_linear", 270 "scale_range", 271 "sigmoid", 272 "softmax", 273] 274PostprocessingId = Literal[ 275 "binarize", 276 "clip", 277 "ensure_dtype", 278 "fixed_zero_mean_unit_variance", 279 "scale_linear", 280 "scale_mean_variance", 281 "scale_range", 282 "sigmoid", 283 "softmax", 284 "zero_mean_unit_variance", 285] 286 287 288SAME_AS_TYPE = "<same as type>" 289 290 291ParameterizedSize_N = int 292""" 293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 294""" 295 296 297class ParameterizedSize(Node): 298 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 299 300 - **min** and **step** are given by the model description. 301 - All blocksize parameters n = 0,1,2,... yield a valid `size`. 302 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**. 303 This allows adjusting the axis size more generically. 304 """ 305 306 N: ClassVar[Type[int]] = ParameterizedSize_N 307 """Positive integer to parameterize this axis""" 308 309 min: Annotated[int, Gt(0)] 310 step: Annotated[int, Gt(0)] 311 312 def validate_size(self, size: int) -> int: 313 if size < self.min: 314 raise ValueError(f"size {size} < {self.min}") 315 if (size - self.min) % self.step != 0: 316 raise ValueError( 317 f"axis of size {size} is not parameterized by `min + n*step` =" 318 + f" `{self.min} + n*{self.step}`" 319 ) 320 321 return size 322 323 def get_size(self, n: ParameterizedSize_N) -> int: 324 return self.min + self.step * n 325 326 def get_n(self, s: int) -> ParameterizedSize_N: 327 """return the smallest n parameterizing a size greater than or equal to `s`""" 328 return ceil((s - self.min) / self.step) 329 330 331class DataDependentSize(Node): 332 min: Annotated[int, Gt(0)] = 1 333 max: Annotated[Optional[int], Gt(1)] = None 334 335 @model_validator(mode="after") 336 def _validate_max_gt_min(self): 337 if self.max is not None and self.min >= self.max: 338 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 339 340 return self 341 342 def validate_size(self, size: int) -> int: 343 if size < self.min: 344 raise ValueError(f"size {size} < {self.min}") 345 346 if self.max is not None and size > self.max: 347 raise ValueError(f"size {size} > {self.max}") 348 349 return size 350 351 352class SizeReference(Node): 353 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 354 355 `axis.size = reference.size * reference.scale / axis.scale + offset` 356 357 Note: 358 1. The axis and the referenced axis need to have the same unit (or no unit). 359 2. Batch axes may not be referenced. 360 3. Fractions are rounded down. 361 4. If the reference axis is `concatenable` the referencing axis is assumed to be 362 `concatenable` as well with the same block order.
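# --- Editor's note: hedged usage sketch, not part of the original module source. ---
# `ParameterizedSize` and `DataDependentSize` (defined above) carry no usage example in
# their docstrings. A minimal sketch of the arithmetic they implement, with made-up values:
#
#     size = ParameterizedSize(min=32, step=16)
#     size.get_size(3)         # 32 + 3 * 16 == 80
#     size.get_n(70)           # smallest n with 32 + n * 16 >= 70: ceil((70 - 32) / 16) == 3
#     size.validate_size(64)   # returns 64, since 64 >= 32 and (64 - 32) % 16 == 0
#     size.validate_size(65)   # raises ValueError: 65 is not expressible as 32 + n * 16
#
#     out_len = DataDependentSize(min=1, max=None)  # actual size only known after inference
#     out_len.validate_size(7)                      # returns 7 (>= min; no max to check)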
363 364 Example: 365 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm². 366 Let's assume that we want to express the image height h in relation to its width w 367 instead of only accepting input images of exactly 100*49 pixels 368 (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`). 369 370 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 371 >>> h = SpaceInputAxis( 372 ... id=AxisId("h"), 373 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 374 ... unit="millimeter", 375 ... scale=4, 376 ... ) 377 >>> print(h.size.get_size(h, w)) 378 49 379 380 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 381 """ 382 383 tensor_id: TensorId 384 """tensor id of the reference axis""" 385 386 axis_id: AxisId 387 """axis id of the reference axis""" 388 389 offset: StrictInt = 0 390 391 def get_size( 392 self, 393 axis: Union[ 394 ChannelAxis, 395 IndexInputAxis, 396 IndexOutputAxis, 397 TimeInputAxis, 398 SpaceInputAxis, 399 TimeOutputAxis, 400 TimeOutputAxisWithHalo, 401 SpaceOutputAxis, 402 SpaceOutputAxisWithHalo, 403 ], 404 ref_axis: Union[ 405 ChannelAxis, 406 IndexInputAxis, 407 IndexOutputAxis, 408 TimeInputAxis, 409 SpaceInputAxis, 410 TimeOutputAxis, 411 TimeOutputAxisWithHalo, 412 SpaceOutputAxis, 413 SpaceOutputAxisWithHalo, 414 ], 415 n: ParameterizedSize_N = 0, 416 ref_size: Optional[int] = None, 417 ): 418 """Compute the concrete size for a given axis and its reference axis. 419 420 Args: 421 axis: The axis this `SizeReference` is the size of. 422 ref_axis: The reference axis to compute the size from. 423 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 424 and no fixed **ref_size** is given, 425 **n** is used to compute the size of the parameterized **ref_axis**. 426 ref_size: Overwrite the reference size instead of deriving it from 427 **ref_axis** 428 (**ref_axis.scale** is still used; any given **n** is ignored). 429 """ 430 assert axis.size == self, ( 431 "Given `axis.size` is not defined by this `SizeReference`" 432 ) 433 434 assert ref_axis.id == self.axis_id, ( 435 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 436 ) 437 438 assert axis.unit == ref_axis.unit, ( 439 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 440 f" but {axis.unit}!={ref_axis.unit}" 441 ) 442 if ref_size is None: 443 if isinstance(ref_axis.size, (int, float)): 444 ref_size = ref_axis.size 445 elif isinstance(ref_axis.size, ParameterizedSize): 446 ref_size = ref_axis.size.get_size(n) 447 elif isinstance(ref_axis.size, DataDependentSize): 448 raise ValueError( 449 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 450 ) 451 elif isinstance(ref_axis.size, SizeReference): 452 raise ValueError( 453 "Reference axis referenced in `SizeReference` may not be sized by a" 454 + " `SizeReference` itself.
455 ) 456 else: 457 assert_never(ref_axis.size) 458 459 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 460 461 @staticmethod 462 def _get_unit( 463 axis: Union[ 464 ChannelAxis, 465 IndexInputAxis, 466 IndexOutputAxis, 467 TimeInputAxis, 468 SpaceInputAxis, 469 TimeOutputAxis, 470 TimeOutputAxisWithHalo, 471 SpaceOutputAxis, 472 SpaceOutputAxisWithHalo, 473 ], 474 ): 475 return axis.unit 476 477 478class AxisBase(NodeWithExplicitlySetFields): 479 id: AxisId 480 """An axis id unique across all axes of one tensor.""" 481 482 description: Annotated[str, MaxLen(128)] = "" 483 """A short description of this axis beyond its type and id.""" 484 485 486class WithHalo(Node): 487 halo: Annotated[int, Ge(1)] 488 """The halo should be cropped from the output tensor to avoid boundary effects. 489 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 490 To document a halo that is already cropped by the model use `size.offset` instead.""" 491 492 size: Annotated[ 493 SizeReference, 494 Field( 495 examples=[ 496 10, 497 SizeReference( 498 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 499 ).model_dump(mode="json"), 500 ] 501 ), 502 ] 503 """reference to another axis with an optional offset (see `SizeReference`)""" 504 505 506BATCH_AXIS_ID = AxisId("batch") 507 508 509class BatchAxis(AxisBase): 510 implemented_type: ClassVar[Literal["batch"]] = "batch" 511 if TYPE_CHECKING: 512 type: Literal["batch"] = "batch" 513 else: 514 type: Literal["batch"] 515 516 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 517 size: Optional[Literal[1]] = None 518 """The batch size may be fixed to 1, 519 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 520 521 @property 522 def scale(self): 523 return 1.0 524 525 @property 526 def concatenable(self): 527 return True 528 529 @property 530 def unit(self): 531 return None 532 533 534class ChannelAxis(AxisBase): 535 implemented_type: ClassVar[Literal["channel"]] = "channel" 536 if TYPE_CHECKING: 537 type: Literal["channel"] = "channel" 538 else: 539 type: Literal["channel"] 540 541 id: NonBatchAxisId = AxisId("channel") 542 543 channel_names: NotEmpty[List[Identifier]] 544 545 @property 546 def size(self) -> int: 547 return len(self.channel_names) 548 549 @property 550 def concatenable(self): 551 return False 552 553 @property 554 def scale(self) -> float: 555 return 1.0 556 557 @property 558 def unit(self): 559 return None 560 561 562class IndexAxisBase(AxisBase): 563 implemented_type: ClassVar[Literal["index"]] = "index" 564 if TYPE_CHECKING: 565 type: Literal["index"] = "index" 566 else: 567 type: Literal["index"] 568 569 id: NonBatchAxisId = AxisId("index") 570 571 @property 572 def scale(self) -> float: 573 return 1.0 574 575 @property 576 def unit(self): 577 return None 578 579 580class _WithInputAxisSize(Node): 581 size: Annotated[ 582 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 583 Field( 584 examples=[ 585 10, 586 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 587 SizeReference( 588 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 589 ).model_dump(mode="json"), 590 ] 591 ), 592 ] 593 """The size/length of this axis can be specified as 594 - fixed integer 595 - parameterized series of valid sizes (`ParameterizedSize`) 596 - reference to another axis with an optional offset (`SizeReference`) 597 """ 598 599 600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 601 concatenable: bool = False 602 """If a model has a `concatenable` 
input axis, it can be processed blockwise, 603 splitting a longer sample axis into blocks matching its input tensor description. 604 Output axes are concatenable if they have a `SizeReference` to a concatenable 605 input axis. 606 """ 607 608 609class IndexOutputAxis(IndexAxisBase): 610 size: Annotated[ 611 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 612 Field( 613 examples=[ 614 10, 615 SizeReference( 616 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 617 ).model_dump(mode="json"), 618 ] 619 ), 620 ] 621 """The size/length of this axis can be specified as 622 - fixed integer 623 - reference to another axis with an optional offset (`SizeReference`) 624 - data dependent size using `DataDependentSize` (size is only known after model inference) 625 """ 626 627 628class TimeAxisBase(AxisBase): 629 implemented_type: ClassVar[Literal["time"]] = "time" 630 if TYPE_CHECKING: 631 type: Literal["time"] = "time" 632 else: 633 type: Literal["time"] 634 635 id: NonBatchAxisId = AxisId("time") 636 unit: Optional[TimeUnit] = None 637 scale: Annotated[float, Gt(0)] = 1.0 638 639 640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 641 concatenable: bool = False 642 """If a model has a `concatenable` input axis, it can be processed blockwise, 643 splitting a longer sample axis into blocks matching its input tensor description. 644 Output axes are concatenable if they have a `SizeReference` to a concatenable 645 input axis. 646 """ 647 648 649class SpaceAxisBase(AxisBase): 650 implemented_type: ClassVar[Literal["space"]] = "space" 651 if TYPE_CHECKING: 652 type: Literal["space"] = "space" 653 else: 654 type: Literal["space"] 655 656 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 657 unit: Optional[SpaceUnit] = None 658 scale: Annotated[float, Gt(0)] = 1.0 659 660 661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 662 concatenable: bool = False 663 """If a model has a `concatenable` input axis, it can be processed blockwise, 664 splitting a longer sample axis into blocks matching its input tensor description. 665 Output axes are concatenable if they have a `SizeReference` to a concatenable 666 input axis. 
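# --- Editor's note: hedged usage sketch, not part of the original module source. ---
# How the input axis classes defined above are typically combined to declare the axes of
# a 2D image input tensor (identifiers and sizes are made up):
#
#     axes = [
#         BatchAxis(),
#         ChannelAxis(channel_names=[Identifier("intensity")]),
#         SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
#         SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
#     ]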
667 """ 668 669 670INPUT_AXIS_TYPES = ( 671 BatchAxis, 672 ChannelAxis, 673 IndexInputAxis, 674 TimeInputAxis, 675 SpaceInputAxis, 676) 677"""intended for isinstance comparisons in py<3.10""" 678 679_InputAxisUnion = Union[ 680 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 681] 682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 683 684 685class _WithOutputAxisSize(Node): 686 size: Annotated[ 687 Union[Annotated[int, Gt(0)], SizeReference], 688 Field( 689 examples=[ 690 10, 691 SizeReference( 692 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 693 ).model_dump(mode="json"), 694 ] 695 ), 696 ] 697 """The size/length of this axis can be specified as 698 - fixed integer 699 - reference to another axis with an optional offset (see `SizeReference`) 700 """ 701 702 703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 704 pass 705 706 707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 708 pass 709 710 711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 712 if isinstance(v, dict): 713 return "with_halo" if "halo" in v else "wo_halo" 714 else: 715 return "with_halo" if hasattr(v, "halo") else "wo_halo" 716 717 718_TimeOutputAxisUnion = Annotated[ 719 Union[ 720 Annotated[TimeOutputAxis, Tag("wo_halo")], 721 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 722 ], 723 Discriminator(_get_halo_axis_discriminator_value), 724] 725 726 727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 728 pass 729 730 731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 732 pass 733 734 735_SpaceOutputAxisUnion = Annotated[ 736 Union[ 737 Annotated[SpaceOutputAxis, Tag("wo_halo")], 738 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 739 ], 740 Discriminator(_get_halo_axis_discriminator_value), 741] 742 743 744_OutputAxisUnion = Union[ 745 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 746] 747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 748 749OUTPUT_AXIS_TYPES = ( 750 BatchAxis, 751 ChannelAxis, 752 IndexOutputAxis, 753 TimeOutputAxis, 754 TimeOutputAxisWithHalo, 755 SpaceOutputAxis, 756 SpaceOutputAxisWithHalo, 757) 758"""intended for isinstance comparisons in py<3.10""" 759 760 761AnyAxis = Union[InputAxis, OutputAxis] 762 763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 764"""intended for isinstance comparisons in py<3.10""" 765 766TVs = Union[ 767 NotEmpty[List[int]], 768 NotEmpty[List[float]], 769 NotEmpty[List[bool]], 770 NotEmpty[List[str]], 771] 772 773 774NominalOrOrdinalDType = Literal[ 775 "float32", 776 "float64", 777 "uint8", 778 "int8", 779 "uint16", 780 "int16", 781 "uint32", 782 "int32", 783 "uint64", 784 "int64", 785 "bool", 786] 787 788 789class NominalOrOrdinalDataDescr(Node): 790 values: TVs 791 """A fixed set of nominal or an ascending sequence of ordinal values. 792 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 793 String `values` are interpreted as labels for tensor values 0, ..., N. 794 Note: as YAML 1.2 does not natively support a "set" datatype, 795 nominal values should be given as a sequence (aka list/array) as well. 
796 """ 797 798 type: Annotated[ 799 NominalOrOrdinalDType, 800 Field( 801 examples=[ 802 "float32", 803 "uint8", 804 "uint16", 805 "int64", 806 "bool", 807 ], 808 ), 809 ] = "uint8" 810 811 @model_validator(mode="after") 812 def _validate_values_match_type( 813 self, 814 ) -> Self: 815 incompatible: List[Any] = [] 816 for v in self.values: 817 if self.type == "bool": 818 if not isinstance(v, bool): 819 incompatible.append(v) 820 elif self.type in DTYPE_LIMITS: 821 if ( 822 isinstance(v, (int, float)) 823 and ( 824 v < DTYPE_LIMITS[self.type].min 825 or v > DTYPE_LIMITS[self.type].max 826 ) 827 or (isinstance(v, str) and "uint" not in self.type) 828 or (isinstance(v, float) and "int" in self.type) 829 ): 830 incompatible.append(v) 831 else: 832 incompatible.append(v) 833 834 if len(incompatible) == 5: 835 incompatible.append("...") 836 break 837 838 if incompatible: 839 raise ValueError( 840 f"data type '{self.type}' incompatible with values {incompatible}" 841 ) 842 843 return self 844 845 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 846 847 @property 848 def range(self): 849 if isinstance(self.values[0], str): 850 return 0, len(self.values) - 1 851 else: 852 return min(self.values), max(self.values) 853 854 855IntervalOrRatioDType = Literal[ 856 "float32", 857 "float64", 858 "uint8", 859 "int8", 860 "uint16", 861 "int16", 862 "uint32", 863 "int32", 864 "uint64", 865 "int64", 866] 867 868 869class IntervalOrRatioDataDescr(Node): 870 type: Annotated[ # TODO: rename to dtype 871 IntervalOrRatioDType, 872 Field( 873 examples=["float32", "float64", "uint8", "uint16"], 874 ), 875 ] = "float32" 876 range: Tuple[Optional[float], Optional[float]] = ( 877 None, 878 None, 879 ) 880 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 881 `None` corresponds to min/max of what can be expressed by **type**.""" 882 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 883 scale: float = 1.0 884 """Scale for data on an interval (or ratio) scale.""" 885 offset: Optional[float] = None 886 """Offset for data on a ratio scale.""" 887 888 @model_validator(mode="before") 889 def _replace_inf(cls, data: Any): 890 if is_dict(data): 891 if "range" in data and is_sequence(data["range"]): 892 forbidden = ( 893 "inf", 894 "-inf", 895 ".inf", 896 "-.inf", 897 float("inf"), 898 float("-inf"), 899 ) 900 if any(v in forbidden for v in data["range"]): 901 issue_warning("replaced 'inf' value", value=data["range"]) 902 903 data["range"] = tuple( 904 (None if v in forbidden else v) for v in data["range"] 905 ) 906 907 return data 908 909 910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 911 912 913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 914 """processing base class""" 915 916 917class BinarizeKwargs(ProcessingKwargs): 918 """key word arguments for `BinarizeDescr`""" 919 920 threshold: float 921 """The fixed threshold""" 922 923 924class BinarizeAlongAxisKwargs(ProcessingKwargs): 925 """key word arguments for `BinarizeDescr`""" 926 927 threshold: NotEmpty[List[float]] 928 """The fixed threshold values along `axis`""" 929 930 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 931 """The `threshold` axis""" 932 933 934class BinarizeDescr(ProcessingDescrBase): 935 """Binarize the tensor with a fixed threshold. 936 937 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 938 will be set to one, values below the threshold to zero. 
939 940 Examples: 941 - in YAML 942 ```yaml 943 postprocessing: 944 - id: binarize 945 kwargs: 946 axis: 'channel' 947 threshold: [0.25, 0.5, 0.75] 948 ``` 949 - in Python: 950 >>> postprocessing = [BinarizeDescr( 951 ... kwargs=BinarizeAlongAxisKwargs( 952 ... axis=AxisId('channel'), 953 ... threshold=[0.25, 0.5, 0.75], 954 ... ) 955 ... )] 956 """ 957 958 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 959 if TYPE_CHECKING: 960 id: Literal["binarize"] = "binarize" 961 else: 962 id: Literal["binarize"] 963 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 964 965 966class ClipDescr(ProcessingDescrBase): 967 """Set tensor values below min to min and above max to max. 968 969 See `ScaleRangeDescr` for examples. 970 """ 971 972 implemented_id: ClassVar[Literal["clip"]] = "clip" 973 if TYPE_CHECKING: 974 id: Literal["clip"] = "clip" 975 else: 976 id: Literal["clip"] 977 978 kwargs: ClipKwargs 979 980 981class EnsureDtypeKwargs(ProcessingKwargs): 982 """key word arguments for `EnsureDtypeDescr`""" 983 984 dtype: Literal[ 985 "float32", 986 "float64", 987 "uint8", 988 "int8", 989 "uint16", 990 "int16", 991 "uint32", 992 "int32", 993 "uint64", 994 "int64", 995 "bool", 996 ] 997 998 999class EnsureDtypeDescr(ProcessingDescrBase): 1000 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 1001 1002 This can for example be used to ensure the inner neural network model gets a 1003 different input tensor data type than the fully described bioimage.io model does. 1004 1005 Examples: 1006 The described bioimage.io model (incl. preprocessing) accepts any 1007 float32-compatible tensor, normalizes it with percentiles and clipping and then 1008 casts it to uint8, which is what the neural network in this example expects. 1009 - in YAML 1010 ```yaml 1011 inputs: 1012 - data: 1013 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1014 preprocessing: 1015 - id: scale_range 1016 kwargs: 1017 axes: ['y', 'x'] 1018 max_percentile: 99.8 1019 min_percentile: 5.0 1020 - id: clip 1021 kwargs: 1022 min: 0.0 1023 max: 1.0 1024 - id: ensure_dtype # the neural network of the model requires uint8 1025 kwargs: 1026 dtype: uint8 1027 ``` 1028 - in Python: 1029 >>> preprocessing = [ 1030 ... ScaleRangeDescr( 1031 ... kwargs=ScaleRangeKwargs( 1032 ... axes= (AxisId('y'), AxisId('x')), 1033 ... max_percentile= 99.8, 1034 ... min_percentile= 5.0, 1035 ... ) 1036 ... ), 1037 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1038 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1039 ... ] 1040 """ 1041 1042 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1043 if TYPE_CHECKING: 1044 id: Literal["ensure_dtype"] = "ensure_dtype" 1045 else: 1046 id: Literal["ensure_dtype"] 1047 1048 kwargs: EnsureDtypeKwargs 1049 1050 1051class ScaleLinearKwargs(ProcessingKwargs): 1052 """Key word arguments for `ScaleLinearDescr`""" 1053 1054 gain: float = 1.0 1055 """multiplicative factor""" 1056 1057 offset: float = 0.0 1058 """additive term""" 1059 1060 @model_validator(mode="after") 1061 def _validate(self) -> Self: 1062 if self.gain == 1.0 and self.offset == 0.0: 1063 raise ValueError( 1064 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1065 + " != 0.0." 
1066 ) 1067 1068 return self 1069 1070 1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1072 """Key word arguments for `ScaleLinearDescr`""" 1073 1074 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1075 """The axis of gain and offset values.""" 1076 1077 gain: Union[float, NotEmpty[List[float]]] = 1.0 1078 """multiplicative factor""" 1079 1080 offset: Union[float, NotEmpty[List[float]]] = 0.0 1081 """additive term""" 1082 1083 @model_validator(mode="after") 1084 def _validate(self) -> Self: 1085 if isinstance(self.gain, list): 1086 if isinstance(self.offset, list): 1087 if len(self.gain) != len(self.offset): 1088 raise ValueError( 1089 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1090 ) 1091 else: 1092 self.offset = [float(self.offset)] * len(self.gain) 1093 elif isinstance(self.offset, list): 1094 self.gain = [float(self.gain)] * len(self.offset) 1095 else: 1096 raise ValueError( 1097 "Do not specify an `axis` for scalar gain and offset values." 1098 ) 1099 1100 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1101 raise ValueError( 1102 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1103 + " != 0.0." 1104 ) 1105 1106 return self 1107 1108 1109class ScaleLinearDescr(ProcessingDescrBase): 1110 """Fixed linear scaling. 1111 1112 Examples: 1113 1. Scale with scalar gain and offset 1114 - in YAML 1115 ```yaml 1116 preprocessing: 1117 - id: scale_linear 1118 kwargs: 1119 gain: 2.0 1120 offset: 3.0 1121 ``` 1122 - in Python: 1123 >>> preprocessing = [ 1124 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1125 ... ] 1126 1127 2. Independent scaling along an axis 1128 - in YAML 1129 ```yaml 1130 preprocessing: 1131 - id: scale_linear 1132 kwargs: 1133 axis: 'channel' 1134 gain: [1.0, 2.0, 3.0] 1135 ``` 1136 - in Python: 1137 >>> preprocessing = [ 1138 ... ScaleLinearDescr( 1139 ... kwargs=ScaleLinearAlongAxisKwargs( 1140 ... axis=AxisId("channel"), 1141 ... gain=[1.0, 2.0, 3.0], 1142 ... ) 1143 ... ) 1144 ... ] 1145 1146 """ 1147 1148 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1149 if TYPE_CHECKING: 1150 id: Literal["scale_linear"] = "scale_linear" 1151 else: 1152 id: Literal["scale_linear"] 1153 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 1154 1155 1156class SigmoidDescr(ProcessingDescrBase): 1157 """The logistic sigmoid function, a.k.a. expit function. 1158 1159 Examples: 1160 - in YAML 1161 ```yaml 1162 postprocessing: 1163 - id: sigmoid 1164 ``` 1165 - in Python: 1166 >>> postprocessing = [SigmoidDescr()] 1167 """ 1168 1169 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1170 if TYPE_CHECKING: 1171 id: Literal["sigmoid"] = "sigmoid" 1172 else: 1173 id: Literal["sigmoid"] 1174 1175 @property 1176 def kwargs(self) -> ProcessingKwargs: 1177 """empty kwargs""" 1178 return ProcessingKwargs() 1179 1180 1181class SoftmaxKwargs(ProcessingKwargs): 1182 """key word arguments for `SoftmaxDescr`""" 1183 1184 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 1185 """The axis to apply the softmax function along. 1186 Note: 1187 Defaults to 'channel' axis 1188 (which may not exist, in which case 1189 a different axis id has to be specified). 1190 """ 1191 1192 1193class SoftmaxDescr(ProcessingDescrBase): 1194 """The softmax function. 
1195 1196 Examples: 1197 - in YAML 1198 ```yaml 1199 postprocessing: 1200 - id: softmax 1201 kwargs: 1202 axis: channel 1203 ``` 1204 - in Python: 1205 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 1206 """ 1207 1208 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 1209 if TYPE_CHECKING: 1210 id: Literal["softmax"] = "softmax" 1211 else: 1212 id: Literal["softmax"] 1213 1214 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 1215 1216 1217class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1218 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1219 1220 mean: float 1221 """The mean value to normalize with.""" 1222 1223 std: Annotated[float, Ge(1e-6)] 1224 """The standard deviation value to normalize with.""" 1225 1226 1227class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1228 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1229 1230 mean: NotEmpty[List[float]] 1231 """The mean value(s) to normalize with.""" 1232 1233 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1234 """The standard deviation value(s) to normalize with. 1235 Size must match `mean` values.""" 1236 1237 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1238 """The axis of the mean/std values to normalize each entry along that dimension 1239 separately.""" 1240 1241 @model_validator(mode="after") 1242 def _mean_and_std_match(self) -> Self: 1243 if len(self.mean) != len(self.std): 1244 raise ValueError( 1245 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1246 + " must match." 1247 ) 1248 1249 return self 1250 1251 1252class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1253 """Subtract a given mean and divide by the standard deviation. 1254 1255 Normalize with fixed, precomputed values for 1256 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1257 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1258 axes. 1259 1260 Examples: 1261 1. scalar value for whole tensor 1262 - in YAML 1263 ```yaml 1264 preprocessing: 1265 - id: fixed_zero_mean_unit_variance 1266 kwargs: 1267 mean: 103.5 1268 std: 13.7 1269 ``` 1270 - in Python 1271 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1272 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1273 ... )] 1274 1275 2. independently along an axis 1276 - in YAML 1277 ```yaml 1278 preprocessing: 1279 - id: fixed_zero_mean_unit_variance 1280 kwargs: 1281 axis: channel 1282 mean: [101.5, 102.5, 103.5] 1283 std: [11.7, 12.7, 13.7] 1284 ``` 1285 - in Python 1286 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1287 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1288 ... axis=AxisId("channel"), 1289 ... mean=[101.5, 102.5, 103.5], 1290 ... std=[11.7, 12.7, 13.7], 1291 ... ) 1292 ... 
)] 1293 """ 1294 1295 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1296 "fixed_zero_mean_unit_variance" 1297 ) 1298 if TYPE_CHECKING: 1299 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1300 else: 1301 id: Literal["fixed_zero_mean_unit_variance"] 1302 1303 kwargs: Union[ 1304 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1305 ] 1306 1307 1308class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1309 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1310 1311 axes: Annotated[ 1312 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1313 ] = None 1314 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1315 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1316 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1317 To normalize each sample independently, leave out the 'batch' axis. 1318 Default: Scale all axes jointly.""" 1319 1320 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1321 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 1322 1323 1324class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1325 """Subtract mean and divide by standard deviation. 1326 1327 Examples: 1328 Subtract the tensor mean and divide by the tensor standard deviation 1329 - in YAML 1330 ```yaml 1331 preprocessing: 1332 - id: zero_mean_unit_variance 1333 ``` 1334 - in Python 1335 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1336 """ 1337 1338 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1339 "zero_mean_unit_variance" 1340 ) 1341 if TYPE_CHECKING: 1342 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1343 else: 1344 id: Literal["zero_mean_unit_variance"] 1345 1346 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1347 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 1348 ) 1349 1350 1351class ScaleRangeKwargs(ProcessingKwargs): 1352 """key word arguments for `ScaleRangeDescr` 1353 1354 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1355 this processing step normalizes data to the [0, 1] interval. 1356 For other percentiles the normalized values will partially be outside the [0, 1] 1357 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the 1358 normalized values to a range. 1359 """ 1360 1361 axes: Annotated[ 1362 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1363 ] = None 1364 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1365 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1366 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1367 To normalize samples independently, leave out the "batch" axis. 1368 Default: Scale all axes jointly.""" 1369 1370 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1371 """The lower percentile used to determine the value to align with zero.""" 1372 1373 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1374 """The upper percentile used to determine the value to align with one. 1375 Has to be bigger than `min_percentile`. 1376 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1377 accepting percentiles specified in the range 0.0 to 1.0.""" 1378 1379 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1380 """Epsilon for numeric stability.
1381 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1382 with `v_lower,v_upper` values at the respective percentiles.""" 1383 1384 reference_tensor: Optional[TensorId] = None 1385 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1386 For any tensor in `inputs` only input tensor references are allowed.""" 1387 1388 @field_validator("max_percentile", mode="after") 1389 @classmethod 1390 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1391 if (min_p := info.data["min_percentile"]) >= value: 1392 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1393 1394 return value 1395 1396 1397class ScaleRangeDescr(ProcessingDescrBase): 1398 """Scale with percentiles. 1399 1400 Examples: 1401 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1402 - in YAML 1403 ```yaml 1404 preprocessing: 1405 - id: scale_range 1406 kwargs: 1407 axes: ['y', 'x'] 1408 max_percentile: 99.8 1409 min_percentile: 5.0 1410 ``` 1411 - in Python 1412 >>> preprocessing = [ 1413 ... ScaleRangeDescr( 1414 ... kwargs=ScaleRangeKwargs( 1415 ... axes= (AxisId('y'), AxisId('x')), 1416 ... max_percentile= 99.8, 1417 ... min_percentile= 5.0, 1418 ... ) 1419 ... ), 1420 ... ClipDescr( 1421 ... kwargs=ClipKwargs( 1422 ... min=0.0, 1423 ... max=1.0, 1424 ... ) 1425 ... ), 1426 ... ] 1427 1428 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1429 - in YAML 1430 ```yaml 1431 preprocessing: 1432 - id: scale_range 1433 kwargs: 1434 axes: ['y', 'x'] 1435 max_percentile: 99.8 1436 min_percentile: 5.0 1437 - id: scale_range 1438 - id: clip 1439 kwargs: 1440 min: 0.0 1441 max: 1.0 1442 ``` 1443 - in Python 1444 >>> preprocessing = [ScaleRangeDescr( 1445 ... kwargs=ScaleRangeKwargs( 1446 ... axes= (AxisId('y'), AxisId('x')), 1447 ... max_percentile= 99.8, 1448 ... min_percentile= 5.0, 1449 ... ) 1450 ... )] 1451 1452 """ 1453 1454 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1455 if TYPE_CHECKING: 1456 id: Literal["scale_range"] = "scale_range" 1457 else: 1458 id: Literal["scale_range"] 1459 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct) 1460 1461 1462class ScaleMeanVarianceKwargs(ProcessingKwargs): 1463 """key word arguments for `ScaleMeanVarianceKwargs`""" 1464 1465 reference_tensor: TensorId 1466 """Name of tensor to match.""" 1467 1468 axes: Annotated[ 1469 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1470 ] = None 1471 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1472 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1473 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1474 To normalize samples independently, leave out the 'batch' axis. 1475 Default: Scale all axes jointly.""" 1476 1477 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1478 """Epsilon for numeric stability: 1479 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 1480 1481 1482class ScaleMeanVarianceDescr(ProcessingDescrBase): 1483 """Scale a tensor's data distribution to match another tensor's mean/std. 
1484 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1485 """ 1486 1487 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1488 if TYPE_CHECKING: 1489 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1490 else: 1491 id: Literal["scale_mean_variance"] 1492 kwargs: ScaleMeanVarianceKwargs 1493 1494 1495PreprocessingDescr = Annotated[ 1496 Union[ 1497 BinarizeDescr, 1498 ClipDescr, 1499 EnsureDtypeDescr, 1500 FixedZeroMeanUnitVarianceDescr, 1501 ScaleLinearDescr, 1502 ScaleRangeDescr, 1503 SigmoidDescr, 1504 SoftmaxDescr, 1505 ZeroMeanUnitVarianceDescr, 1506 ], 1507 Discriminator("id"), 1508] 1509PostprocessingDescr = Annotated[ 1510 Union[ 1511 BinarizeDescr, 1512 ClipDescr, 1513 EnsureDtypeDescr, 1514 FixedZeroMeanUnitVarianceDescr, 1515 ScaleLinearDescr, 1516 ScaleMeanVarianceDescr, 1517 ScaleRangeDescr, 1518 SigmoidDescr, 1519 SoftmaxDescr, 1520 ZeroMeanUnitVarianceDescr, 1521 ], 1522 Discriminator("id"), 1523] 1524 1525IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 1526 1527 1528class TensorDescrBase(Node, Generic[IO_AxisT]): 1529 id: TensorId 1530 """Tensor id. No duplicates are allowed.""" 1531 1532 description: Annotated[str, MaxLen(128)] = "" 1533 """free text description""" 1534 1535 axes: NotEmpty[Sequence[IO_AxisT]] 1536 """tensor axes""" 1537 1538 @property 1539 def shape(self): 1540 return tuple(a.size for a in self.axes) 1541 1542 @field_validator("axes", mode="after", check_fields=False) 1543 @classmethod 1544 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1545 batch_axes = [a for a in axes if a.type == "batch"] 1546 if len(batch_axes) > 1: 1547 raise ValueError( 1548 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1549 ) 1550 1551 seen_ids: Set[AxisId] = set() 1552 duplicate_axes_ids: Set[AxisId] = set() 1553 for a in axes: 1554 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1555 1556 if duplicate_axes_ids: 1557 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1558 1559 return axes 1560 1561 test_tensor: FAIR[Optional[FileDescr_]] = None 1562 """An example tensor to use for testing. 1563 Using the model with the test input tensors is expected to yield the test output tensors. 1564 Each test tensor has be a an ndarray in the 1565 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1566 The file extension must be '.npy'.""" 1567 1568 sample_tensor: FAIR[Optional[FileDescr_]] = None 1569 """A sample tensor to illustrate a possible input/output for the model, 1570 The sample image primarily serves to inform a human user about an example use case 1571 and is typically stored as .hdf5, .png or .tiff. 1572 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1573 (numpy's `.npy` format is not supported). 1574 The image dimensionality has to match the number of axes specified in this tensor description. 
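# --- Editor's note: hedged usage sketch, not part of the original module source. ---
# `ScaleMeanVarianceDescr` (defined above) has no usage example in its docstring. A minimal
# sketch in the style of the other processing examples; tensor and axis ids are made up:
#
#     postprocessing = [
#         ScaleMeanVarianceDescr(
#             kwargs=ScaleMeanVarianceKwargs(
#                 reference_tensor=TensorId("raw"),
#                 axes=(AxisId("y"), AxisId("x")),
#             )
#         )
#     ]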
1575 """ 1576 1577 @model_validator(mode="after") 1578 def _validate_sample_tensor(self) -> Self: 1579 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1580 return self 1581 1582 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1583 tensor: NDArray[Any] = imread( 1584 reader.read(), 1585 extension=PurePosixPath(reader.original_file_name).suffix, 1586 ) 1587 n_dims = len(tensor.squeeze().shape) 1588 n_dims_min = n_dims_max = len(self.axes) 1589 1590 for a in self.axes: 1591 if isinstance(a, BatchAxis): 1592 n_dims_min -= 1 1593 elif isinstance(a.size, int): 1594 if a.size == 1: 1595 n_dims_min -= 1 1596 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1597 if a.size.min == 1: 1598 n_dims_min -= 1 1599 elif isinstance(a.size, SizeReference): 1600 if a.size.offset < 2: 1601 # size reference may result in singleton axis 1602 n_dims_min -= 1 1603 else: 1604 assert_never(a.size) 1605 1606 n_dims_min = max(0, n_dims_min) 1607 if n_dims < n_dims_min or n_dims > n_dims_max: 1608 raise ValueError( 1609 f"Expected sample tensor to have {n_dims_min} to" 1610 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1611 ) 1612 1613 return self 1614 1615 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1616 IntervalOrRatioDataDescr() 1617 ) 1618 """Description of the tensor's data values, optionally per channel. 1619 If specified per channel, the data `type` needs to match across channels.""" 1620 1621 @property 1622 def dtype( 1623 self, 1624 ) -> Literal[ 1625 "float32", 1626 "float64", 1627 "uint8", 1628 "int8", 1629 "uint16", 1630 "int16", 1631 "uint32", 1632 "int32", 1633 "uint64", 1634 "int64", 1635 "bool", 1636 ]: 1637 """dtype as specified under `data.type` or `data[i].type`""" 1638 if isinstance(self.data, collections.abc.Sequence): 1639 return self.data[0].type 1640 else: 1641 return self.data.type 1642 1643 @field_validator("data", mode="after") 1644 @classmethod 1645 def _check_data_type_across_channels( 1646 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1647 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1648 if not isinstance(value, list): 1649 return value 1650 1651 dtypes = {t.type for t in value} 1652 if len(dtypes) > 1: 1653 raise ValueError( 1654 "Tensor data descriptions per channel need to agree in their data" 1655 + f" `type`, but found {dtypes}." 1656 ) 1657 1658 return value 1659 1660 @model_validator(mode="after") 1661 def _check_data_matches_channelaxis(self) -> Self: 1662 if not isinstance(self.data, (list, tuple)): 1663 return self 1664 1665 for a in self.axes: 1666 if isinstance(a, ChannelAxis): 1667 size = a.size 1668 assert isinstance(size, int) 1669 break 1670 else: 1671 return self 1672 1673 if len(self.data) != size: 1674 raise ValueError( 1675 f"Got tensor data descriptions for {len(self.data)} channels, but" 1676 + f" '{a.id}' axis has size {size}." 1677 ) 1678 1679 return self 1680 1681 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1682 if len(array.shape) != len(self.axes): 1683 raise ValueError( 1684 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1685 + f" incompatible with {len(self.axes)} axes." 1686 ) 1687 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1688 1689 1690class InputTensorDescr(TensorDescrBase[InputAxis]): 1691 id: TensorId = TensorId("input") 1692 """Input tensor id. 
1693 No duplicates are allowed across all inputs and outputs.""" 1694 1695 optional: bool = False 1696 """indicates that this tensor may be `None`""" 1697 1698 preprocessing: List[PreprocessingDescr] = Field( 1699 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1700 ) 1701 1702 """Description of how this input should be preprocessed. 1703 1704 notes: 1705 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1706 to ensure an input tensor's data type matches the input tensor's data description. 1707 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1708 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1709 changing the data type. 1710 """ 1711 1712 @model_validator(mode="after") 1713 def _validate_preprocessing_kwargs(self) -> Self: 1714 axes_ids = [a.id for a in self.axes] 1715 for p in self.preprocessing: 1716 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1717 if kwargs_axes is None: 1718 continue 1719 1720 if not isinstance(kwargs_axes, collections.abc.Sequence): 1721 raise ValueError( 1722 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1723 ) 1724 1725 if any(a not in axes_ids for a in kwargs_axes): 1726 raise ValueError( 1727 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1728 ) 1729 1730 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1731 dtype = self.data.type 1732 else: 1733 dtype = self.data[0].type 1734 1735 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1736 if not self.preprocessing or not isinstance( 1737 self.preprocessing[0], EnsureDtypeDescr 1738 ): 1739 self.preprocessing.insert( 1740 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1741 ) 1742 1743 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1744 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1745 self.preprocessing.append( 1746 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1747 ) 1748 1749 return self 1750 1751 1752def convert_axes( 1753 axes: str, 1754 *, 1755 shape: Union[ 1756 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1757 ], 1758 tensor_type: Literal["input", "output"], 1759 halo: Optional[Sequence[int]], 1760 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1761): 1762 ret: List[AnyAxis] = [] 1763 for i, a in enumerate(axes): 1764 axis_type = _AXIS_TYPE_MAP.get(a, a) 1765 if axis_type == "batch": 1766 ret.append(BatchAxis()) 1767 continue 1768 1769 scale = 1.0 1770 if isinstance(shape, _ParameterizedInputShape_v0_4): 1771 if shape.step[i] == 0: 1772 size = shape.min[i] 1773 else: 1774 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1775 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1776 ref_t = str(shape.reference_tensor) 1777 if ref_t.count(".") == 1: 1778 t_id, orig_a_id = ref_t.split(".") 1779 else: 1780 t_id = ref_t 1781 orig_a_id = a 1782 1783 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1784 if not (orig_scale := shape.scale[i]): 1785 # old way to insert a new axis dimension 1786 size = int(2 * shape.offset[i]) 1787 else: 1788 scale = 1 / orig_scale 1789 if axis_type in ("channel", "index"): 1790 # these axes no longer have a scale 1791 offset_from_scale = orig_scale * size_refs.get( 1792 _TensorName_v0_4(t_id), {} 1793 ).get(orig_a_id, 0) 1794 else: 1795 offset_from_scale = 0 1796 size = SizeReference( 1797 tensor_id=TensorId(t_id), 1798 axis_id=AxisId(a_id), 1799 
offset=int(offset_from_scale + 2 * shape.offset[i]), 1800 ) 1801 else: 1802 size = shape[i] 1803 1804 if axis_type == "time": 1805 if tensor_type == "input": 1806 ret.append(TimeInputAxis(size=size, scale=scale)) 1807 else: 1808 assert not isinstance(size, ParameterizedSize) 1809 if halo is None: 1810 ret.append(TimeOutputAxis(size=size, scale=scale)) 1811 else: 1812 assert not isinstance(size, int) 1813 ret.append( 1814 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1815 ) 1816 1817 elif axis_type == "index": 1818 if tensor_type == "input": 1819 ret.append(IndexInputAxis(size=size)) 1820 else: 1821 if isinstance(size, ParameterizedSize): 1822 size = DataDependentSize(min=size.min) 1823 1824 ret.append(IndexOutputAxis(size=size)) 1825 elif axis_type == "channel": 1826 assert not isinstance(size, ParameterizedSize) 1827 if isinstance(size, SizeReference): 1828 warnings.warn( 1829 "Conversion of channel size from an implicit output shape may be" 1830 + " wrong" 1831 ) 1832 ret.append( 1833 ChannelAxis( 1834 channel_names=[ 1835 Identifier(f"channel{i}") for i in range(size.offset) 1836 ] 1837 ) 1838 ) 1839 else: 1840 ret.append( 1841 ChannelAxis( 1842 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1843 ) 1844 ) 1845 elif axis_type == "space": 1846 if tensor_type == "input": 1847 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1848 else: 1849 assert not isinstance(size, ParameterizedSize) 1850 if halo is None or halo[i] == 0: 1851 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1852 elif isinstance(size, int): 1853 raise NotImplementedError( 1854 f"output axis with halo and fixed size (here {size}) not allowed" 1855 ) 1856 else: 1857 ret.append( 1858 SpaceOutputAxisWithHalo( 1859 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1860 ) 1861 ) 1862 1863 return ret 1864 1865 1866def _axes_letters_to_ids( 1867 axes: Optional[str], 1868) -> Optional[List[AxisId]]: 1869 if axes is None: 1870 return None 1871 1872 return [AxisId(a) for a in axes] 1873 1874 1875def _get_complement_v04_axis( 1876 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1877) -> Optional[AxisId]: 1878 if axes is None: 1879 return None 1880 1881 non_complement_axes = set(axes) | {"b"} 1882 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1883 if len(complement_axes) > 1: 1884 raise ValueError( 1885 f"Expected none or a single complement axis, but axes '{axes}' " 1886 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1887 ) 1888 1889 return None if not complement_axes else AxisId(complement_axes[0]) 1890 1891 1892def _convert_proc( 1893 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1894 tensor_axes: Sequence[str], 1895) -> Union[PreprocessingDescr, PostprocessingDescr]: 1896 if isinstance(p, _BinarizeDescr_v0_4): 1897 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1898 elif isinstance(p, _ClipDescr_v0_4): 1899 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1900 elif isinstance(p, _SigmoidDescr_v0_4): 1901 return SigmoidDescr() 1902 elif isinstance(p, _ScaleLinearDescr_v0_4): 1903 axes = _axes_letters_to_ids(p.kwargs.axes) 1904 if p.kwargs.axes is None: 1905 axis = None 1906 else: 1907 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1908 1909 if axis is None: 1910 assert not isinstance(p.kwargs.gain, list) 1911 assert not isinstance(p.kwargs.offset, list) 1912 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1913 else: 1914 kwargs = ScaleLinearAlongAxisKwargs( 1915 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1916 ) 1917 return ScaleLinearDescr(kwargs=kwargs) 1918 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1919 return ScaleMeanVarianceDescr( 1920 kwargs=ScaleMeanVarianceKwargs( 1921 axes=_axes_letters_to_ids(p.kwargs.axes), 1922 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1923 eps=p.kwargs.eps, 1924 ) 1925 ) 1926 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1927 if p.kwargs.mode == "fixed": 1928 mean = p.kwargs.mean 1929 std = p.kwargs.std 1930 assert mean is not None 1931 assert std is not None 1932 1933 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1934 1935 if axis is None: 1936 if isinstance(mean, list): 1937 raise ValueError("Expected single float value for mean, not <list>") 1938 if isinstance(std, list): 1939 raise ValueError("Expected single float value for std, not <list>") 1940 return FixedZeroMeanUnitVarianceDescr( 1941 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct( 1942 mean=mean, 1943 std=std, 1944 ) 1945 ) 1946 else: 1947 if not isinstance(mean, list): 1948 mean = [float(mean)] 1949 if not isinstance(std, list): 1950 std = [float(std)] 1951 1952 return FixedZeroMeanUnitVarianceDescr( 1953 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1954 axis=axis, mean=mean, std=std 1955 ) 1956 ) 1957 1958 else: 1959 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1960 if p.kwargs.mode == "per_dataset": 1961 axes = [AxisId("batch")] + axes 1962 if not axes: 1963 axes = None 1964 return ZeroMeanUnitVarianceDescr( 1965 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1966 ) 1967 1968 elif isinstance(p, _ScaleRangeDescr_v0_4): 1969 return ScaleRangeDescr( 1970 kwargs=ScaleRangeKwargs( 1971 axes=_axes_letters_to_ids(p.kwargs.axes), 1972 min_percentile=p.kwargs.min_percentile, 1973 max_percentile=p.kwargs.max_percentile, 1974 eps=p.kwargs.eps, 1975 ) 1976 ) 1977 else: 1978 assert_never(p) 1979 1980 1981class _InputTensorConv( 1982 Converter[ 1983 _InputTensorDescr_v0_4, 1984 InputTensorDescr, 1985 FileSource_, 1986 Optional[FileSource_], 1987 Mapping[_TensorName_v0_4, Mapping[str, int]], 1988 ] 1989): 1990 def _convert( 1991 self, 1992 src: _InputTensorDescr_v0_4, 1993 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1994 test_tensor: FileSource_, 1995 sample_tensor: Optional[FileSource_], 1996 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1997 ) -> "InputTensorDescr | dict[str, Any]": 1998 axes: List[InputAxis] = 
convert_axes( # pyright: ignore[reportAssignmentType] 1999 src.axes, 2000 shape=src.shape, 2001 tensor_type="input", 2002 halo=None, 2003 size_refs=size_refs, 2004 ) 2005 prep: List[PreprocessingDescr] = [] 2006 for p in src.preprocessing: 2007 cp = _convert_proc(p, src.axes) 2008 assert not isinstance(cp, ScaleMeanVarianceDescr) 2009 prep.append(cp) 2010 2011 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 2012 2013 return tgt( 2014 axes=axes, 2015 id=TensorId(str(src.name)), 2016 test_tensor=FileDescr(source=test_tensor), 2017 sample_tensor=( 2018 None if sample_tensor is None else FileDescr(source=sample_tensor) 2019 ), 2020 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 2021 preprocessing=prep, 2022 ) 2023 2024 2025_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 2026 2027 2028class OutputTensorDescr(TensorDescrBase[OutputAxis]): 2029 id: TensorId = TensorId("output") 2030 """Output tensor id. 2031 No duplicates are allowed across all inputs and outputs.""" 2032 2033 postprocessing: List[PostprocessingDescr] = Field( 2034 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 2035 ) 2036 """Description of how this output should be postprocessed. 2037 2038 note: `postprocessing` always ends with an 'ensure_dtype' operation. 2039 If not given this is added to cast to this tensor's `data.type`. 2040 """ 2041 2042 @model_validator(mode="after") 2043 def _validate_postprocessing_kwargs(self) -> Self: 2044 axes_ids = [a.id for a in self.axes] 2045 for p in self.postprocessing: 2046 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2047 if kwargs_axes is None: 2048 continue 2049 2050 if not isinstance(kwargs_axes, collections.abc.Sequence): 2051 raise ValueError( 2052 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2053 ) 2054 2055 if any(a not in axes_ids for a in kwargs_axes): 2056 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2057 2058 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2059 dtype = self.data.type 2060 else: 2061 dtype = self.data[0].type 2062 2063 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2064 if not self.postprocessing or not isinstance( 2065 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2066 ): 2067 self.postprocessing.append( 2068 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2069 ) 2070 return self 2071 2072 2073class _OutputTensorConv( 2074 Converter[ 2075 _OutputTensorDescr_v0_4, 2076 OutputTensorDescr, 2077 FileSource_, 2078 Optional[FileSource_], 2079 Mapping[_TensorName_v0_4, Mapping[str, int]], 2080 ] 2081): 2082 def _convert( 2083 self, 2084 src: _OutputTensorDescr_v0_4, 2085 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 2086 test_tensor: FileSource_, 2087 sample_tensor: Optional[FileSource_], 2088 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 2089 ) -> "OutputTensorDescr | dict[str, Any]": 2090 # TODO: split convert_axes into convert_output_axes and convert_input_axes 2091 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 2092 src.axes, 2093 shape=src.shape, 2094 tensor_type="output", 2095 halo=src.halo, 2096 size_refs=size_refs, 2097 ) 2098 data_descr: Dict[str, Any] = dict(type=src.data_type) 2099 if data_descr["type"] == "bool": 2100 data_descr["values"] = [False, True] 2101 2102 return tgt( 2103 axes=axes, 2104 id=TensorId(str(src.name)), 2105 test_tensor=FileDescr(source=test_tensor), 2106 
sample_tensor=( 2107 None if sample_tensor is None else FileDescr(source=sample_tensor) 2108 ), 2109 data=data_descr, # pyright: ignore[reportArgumentType] 2110 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 2111 ) 2112 2113 2114_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 2115 2116 2117TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 2118 2119 2120def validate_tensors( 2121 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 2122 tensor_origin: Literal[ 2123 "test_tensor" 2124 ], # for more precise error messages, e.g. 'test_tensor' 2125): 2126 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 2127 2128 def e_msg(d: TensorDescr): 2129 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2130 2131 for descr, array in tensors.values(): 2132 if array is None: 2133 axis_sizes = {a.id: None for a in descr.axes} 2134 else: 2135 try: 2136 axis_sizes = descr.get_axis_sizes_for_array(array) 2137 except ValueError as e: 2138 raise ValueError(f"{e_msg(descr)} {e}") 2139 2140 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 2141 2142 for descr, array in tensors.values(): 2143 if array is None: 2144 continue 2145 2146 if descr.dtype in ("float32", "float64"): 2147 invalid_test_tensor_dtype = array.dtype.name not in ( 2148 "float32", 2149 "float64", 2150 "uint8", 2151 "int8", 2152 "uint16", 2153 "int16", 2154 "uint32", 2155 "int32", 2156 "uint64", 2157 "int64", 2158 ) 2159 else: 2160 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2161 2162 if invalid_test_tensor_dtype: 2163 raise ValueError( 2164 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2165 + f" match described dtype '{descr.dtype}'" 2166 ) 2167 2168 if array.min() > -1e-4 and array.max() < 1e-4: 2169 raise ValueError( 2170 "Output values are too small for reliable testing." 
2171 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2172 ) 2173 2174 for a in descr.axes: 2175 actual_size = all_tensor_axes[descr.id][a.id][1] 2176 if actual_size is None: 2177 continue 2178 2179 if a.size is None: 2180 continue 2181 2182 if isinstance(a.size, int): 2183 if actual_size != a.size: 2184 raise ValueError( 2185 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2186 + f"has incompatible size {actual_size}, expected {a.size}" 2187 ) 2188 elif isinstance(a.size, ParameterizedSize): 2189 _ = a.size.validate_size(actual_size) 2190 elif isinstance(a.size, DataDependentSize): 2191 _ = a.size.validate_size(actual_size) 2192 elif isinstance(a.size, SizeReference): 2193 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2194 if ref_tensor_axes is None: 2195 raise ValueError( 2196 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2197 + f" reference '{a.size.tensor_id}'" 2198 ) 2199 2200 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2201 if ref_axis is None or ref_size is None: 2202 raise ValueError( 2203 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2204 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2205 ) 2206 2207 if a.unit != ref_axis.unit: 2208 raise ValueError( 2209 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2210 + " axis and reference axis to have the same `unit`, but" 2211 + f" {a.unit}!={ref_axis.unit}" 2212 ) 2213 2214 if actual_size != ( 2215 expected_size := ( 2216 ref_size * ref_axis.scale / a.scale + a.size.offset 2217 ) 2218 ): 2219 raise ValueError( 2220 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2221 + f" {actual_size} invalid for referenced size {ref_size};" 2222 + f" expected {expected_size}" 2223 ) 2224 else: 2225 assert_never(a.size) 2226 2227 2228FileDescr_dependencies = Annotated[ 2229 FileDescr_, 2230 WithSuffix((".yaml", ".yml"), case_sensitive=True), 2231 Field(examples=[dict(source="environment.yaml")]), 2232] 2233 2234 2235class _ArchitectureCallableDescr(Node): 2236 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2237 """Identifier of the callable that returns a torch.nn.Module instance.""" 2238 2239 kwargs: Dict[str, YamlValue] = Field( 2240 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 2241 ) 2242 """key word arguments for the `callable`""" 2243 2244 2245class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2246 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2247 """Architecture source file""" 2248 2249 @model_serializer(mode="wrap", when_used="unless-none") 2250 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2251 return package_file_descr_serializer(self, nxt, info) 2252 2253 2254class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2255 import_from: str 2256 """Where to import the callable from, i.e. 
`from <import_from> import <callable>`""" 2257 2258 2259class _ArchFileConv( 2260 Converter[ 2261 _CallableFromFile_v0_4, 2262 ArchitectureFromFileDescr, 2263 Optional[Sha256], 2264 Dict[str, Any], 2265 ] 2266): 2267 def _convert( 2268 self, 2269 src: _CallableFromFile_v0_4, 2270 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2271 sha256: Optional[Sha256], 2272 kwargs: Dict[str, Any], 2273 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2274 if src.startswith("http") and src.count(":") == 2: 2275 http, source, callable_ = src.split(":") 2276 source = ":".join((http, source)) 2277 elif not src.startswith("http") and src.count(":") == 1: 2278 source, callable_ = src.split(":") 2279 else: 2280 source = str(src) 2281 callable_ = str(src) 2282 return tgt( 2283 callable=Identifier(callable_), 2284 source=cast(FileSource_, source), 2285 sha256=sha256, 2286 kwargs=kwargs, 2287 ) 2288 2289 2290_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2291 2292 2293class _ArchLibConv( 2294 Converter[ 2295 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2296 ] 2297): 2298 def _convert( 2299 self, 2300 src: _CallableFromDepencency_v0_4, 2301 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2302 kwargs: Dict[str, Any], 2303 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2304 *mods, callable_ = src.split(".") 2305 import_from = ".".join(mods) 2306 return tgt( 2307 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2308 ) 2309 2310 2311_arch_lib_conv = _ArchLibConv( 2312 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2313) 2314 2315 2316class WeightsEntryDescrBase(FileDescr): 2317 type: ClassVar[WeightsFormat] 2318 weights_format_name: ClassVar[str] # human readable 2319 2320 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2321 """Source of the weights file.""" 2322 2323 authors: Optional[List[Author]] = None 2324 """Authors 2325 Either the person(s) that have trained this model resulting in the original weights file. 2326 (If this is the initial weights entry, i.e. it does not have a `parent`) 2327 Or the person(s) who have converted the weights to this weights format. 2328 (If this is a child weight, i.e. it has a `parent` field) 2329 """ 2330 2331 parent: Annotated[ 2332 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2333 ] = None 2334 """The source weights these weights were converted from. 2335 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2336 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 
2337 All weight entries except one (the initial set of weights resulting from training the model), 2338 need to have this field.""" 2339 2340 comment: str = "" 2341 """A comment about this weights entry, for example how these weights were created.""" 2342 2343 @model_validator(mode="after") 2344 def _validate(self) -> Self: 2345 if self.type == self.parent: 2346 raise ValueError("Weights entry can't be it's own parent.") 2347 2348 return self 2349 2350 @model_serializer(mode="wrap", when_used="unless-none") 2351 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2352 return package_file_descr_serializer(self, nxt, info) 2353 2354 2355class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2356 type = "keras_hdf5" 2357 weights_format_name: ClassVar[str] = "Keras HDF5" 2358 tensorflow_version: Version 2359 """TensorFlow version used to create these weights.""" 2360 2361 2362class OnnxWeightsDescr(WeightsEntryDescrBase): 2363 type = "onnx" 2364 weights_format_name: ClassVar[str] = "ONNX" 2365 opset_version: Annotated[int, Ge(7)] 2366 """ONNX opset version""" 2367 2368 2369class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2370 type = "pytorch_state_dict" 2371 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2372 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2373 pytorch_version: Version 2374 """Version of the PyTorch library used. 2375 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2376 """ 2377 dependencies: Optional[FileDescr_dependencies] = None 2378 """Custom depencies beyond pytorch described in a Conda environment file. 2379 Allows to specify custom dependencies, see conda docs: 2380 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2381 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2382 2383 The conda environment file should include pytorch and any version pinning has to be compatible with 2384 **pytorch_version**. 2385 """ 2386 2387 2388class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2389 type = "tensorflow_js" 2390 weights_format_name: ClassVar[str] = "Tensorflow.js" 2391 tensorflow_version: Version 2392 """Version of the TensorFlow library used.""" 2393 2394 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2395 """The multi-file weights. 2396 All required files/folders should be a zip archive.""" 2397 2398 2399class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2400 type = "tensorflow_saved_model_bundle" 2401 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2402 tensorflow_version: Version 2403 """Version of the TensorFlow library used.""" 2404 2405 dependencies: Optional[FileDescr_dependencies] = None 2406 """Custom dependencies beyond tensorflow. 2407 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2408 2409 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2410 """The multi-file weights. 
2411 All required files/folders should be a zip archive.""" 2412 2413 2414class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2415 type = "torchscript" 2416 weights_format_name: ClassVar[str] = "TorchScript" 2417 pytorch_version: Version 2418 """Version of the PyTorch library used.""" 2419 2420 2421class WeightsDescr(Node): 2422 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2423 onnx: Optional[OnnxWeightsDescr] = None 2424 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2425 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2426 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2427 None 2428 ) 2429 torchscript: Optional[TorchscriptWeightsDescr] = None 2430 2431 @model_validator(mode="after") 2432 def check_entries(self) -> Self: 2433 entries = {wtype for wtype, entry in self if entry is not None} 2434 2435 if not entries: 2436 raise ValueError("Missing weights entry") 2437 2438 entries_wo_parent = { 2439 wtype 2440 for wtype, entry in self 2441 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2442 } 2443 if len(entries_wo_parent) != 1: 2444 issue_warning( 2445 "Exactly one weights entry may not specify the `parent` field (got" 2446 + " {value}). That entry is considered the original set of model weights." 2447 + " Other weight formats are created through conversion of the orignal or" 2448 + " already converted weights. They have to reference the weights format" 2449 + " they were converted from as their `parent`.", 2450 value=len(entries_wo_parent), 2451 field="weights", 2452 ) 2453 2454 for wtype, entry in self: 2455 if entry is None: 2456 continue 2457 2458 assert hasattr(entry, "type") 2459 assert hasattr(entry, "parent") 2460 assert wtype == entry.type 2461 if ( 2462 entry.parent is not None and entry.parent not in entries 2463 ): # self reference checked for `parent` field 2464 raise ValueError( 2465 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2466 + f" formats: {entries}" 2467 ) 2468 2469 return self 2470 2471 def __getitem__( 2472 self, 2473 key: Literal[ 2474 "keras_hdf5", 2475 "onnx", 2476 "pytorch_state_dict", 2477 "tensorflow_js", 2478 "tensorflow_saved_model_bundle", 2479 "torchscript", 2480 ], 2481 ): 2482 if key == "keras_hdf5": 2483 ret = self.keras_hdf5 2484 elif key == "onnx": 2485 ret = self.onnx 2486 elif key == "pytorch_state_dict": 2487 ret = self.pytorch_state_dict 2488 elif key == "tensorflow_js": 2489 ret = self.tensorflow_js 2490 elif key == "tensorflow_saved_model_bundle": 2491 ret = self.tensorflow_saved_model_bundle 2492 elif key == "torchscript": 2493 ret = self.torchscript 2494 else: 2495 raise KeyError(key) 2496 2497 if ret is None: 2498 raise KeyError(key) 2499 2500 return ret 2501 2502 @property 2503 def available_formats(self): 2504 return { 2505 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2506 **({} if self.onnx is None else {"onnx": self.onnx}), 2507 **( 2508 {} 2509 if self.pytorch_state_dict is None 2510 else {"pytorch_state_dict": self.pytorch_state_dict} 2511 ), 2512 **( 2513 {} 2514 if self.tensorflow_js is None 2515 else {"tensorflow_js": self.tensorflow_js} 2516 ), 2517 **( 2518 {} 2519 if self.tensorflow_saved_model_bundle is None 2520 else { 2521 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2522 } 2523 ), 2524 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2525 } 2526 2527 @property 2528 def missing_formats(self): 2529 return { 2530 wf for wf in 
get_args(WeightsFormat) if wf not in self.available_formats 2531 } 2532 2533 2534class ModelId(ResourceId): 2535 pass 2536 2537 2538class LinkedModel(LinkedResourceBase): 2539 """Reference to a bioimage.io model.""" 2540 2541 id: ModelId 2542 """A valid model `id` from the bioimage.io collection.""" 2543 2544 2545class _DataDepSize(NamedTuple): 2546 min: StrictInt 2547 max: Optional[StrictInt] 2548 2549 2550class _AxisSizes(NamedTuple): 2551 """the lenghts of all axes of model inputs and outputs""" 2552 2553 inputs: Dict[Tuple[TensorId, AxisId], int] 2554 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2555 2556 2557class _TensorSizes(NamedTuple): 2558 """_AxisSizes as nested dicts""" 2559 2560 inputs: Dict[TensorId, Dict[AxisId, int]] 2561 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2562 2563 2564class ReproducibilityTolerance(Node, extra="allow"): 2565 """Describes what small numerical differences -- if any -- may be tolerated 2566 in the generated output when executing in different environments. 2567 2568 A tensor element *output* is considered mismatched to the **test_tensor** if 2569 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2570 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2571 2572 Motivation: 2573 For testing we can request the respective deep learning frameworks to be as 2574 reproducible as possible by setting seeds and chosing deterministic algorithms, 2575 but differences in operating systems, available hardware and installed drivers 2576 may still lead to numerical differences. 2577 """ 2578 2579 relative_tolerance: RelativeTolerance = 1e-3 2580 """Maximum relative tolerance of reproduced test tensor.""" 2581 2582 absolute_tolerance: AbsoluteTolerance = 1e-4 2583 """Maximum absolute tolerance of reproduced test tensor.""" 2584 2585 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2586 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2587 2588 output_ids: Sequence[TensorId] = () 2589 """Limits the output tensor IDs these reproducibility details apply to.""" 2590 2591 weights_formats: Sequence[WeightsFormat] = () 2592 """Limits the weights formats these details apply to.""" 2593 2594 2595class BioimageioConfig(Node, extra="allow"): 2596 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2597 """Tolerances to allow when reproducing the model's test outputs 2598 from the model's test inputs. 2599 Only the first entry matching tensor id and weights format is considered. 2600 """ 2601 2602 2603class Config(Node, extra="allow"): 2604 bioimageio: BioimageioConfig = Field( 2605 default_factory=BioimageioConfig.model_construct 2606 ) 2607 2608 2609class ModelDescr(GenericModelDescrBase): 2610 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2611 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2612 """ 2613 2614 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5" 2615 if TYPE_CHECKING: 2616 format_version: Literal["0.5.5"] = "0.5.5" 2617 else: 2618 format_version: Literal["0.5.5"] 2619 """Version of the bioimage.io model description specification used. 2620 When creating a new model always use the latest micro/patch version described here. 
2621 The `format_version` is important for any consumer software to understand how to parse the fields. 2622 """ 2623 2624 implemented_type: ClassVar[Literal["model"]] = "model" 2625 if TYPE_CHECKING: 2626 type: Literal["model"] = "model" 2627 else: 2628 type: Literal["model"] 2629 """Specialized resource type 'model'""" 2630 2631 id: Optional[ModelId] = None 2632 """bioimage.io-wide unique resource identifier 2633 assigned by bioimage.io; version **un**specific.""" 2634 2635 authors: FAIR[List[Author]] = Field( 2636 default_factory=cast(Callable[[], List[Author]], list) 2637 ) 2638 """The authors are the creators of the model RDF and the primary points of contact.""" 2639 2640 documentation: FAIR[Optional[FileSource_documentation]] = None 2641 """URL or relative path to a markdown file with additional documentation. 2642 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2643 The documentation should include a '#[#] Validation' (sub)section 2644 with details on how to quantitatively validate the model on unseen data.""" 2645 2646 @field_validator("documentation", mode="after") 2647 @classmethod 2648 def _validate_documentation( 2649 cls, value: Optional[FileSource_documentation] 2650 ) -> Optional[FileSource_documentation]: 2651 if not get_validation_context().perform_io_checks or value is None: 2652 return value 2653 2654 doc_reader = get_reader(value) 2655 doc_content = doc_reader.read().decode(encoding="utf-8") 2656 if not re.search("#.*[vV]alidation", doc_content): 2657 issue_warning( 2658 "No '# Validation' (sub)section found in {value}.", 2659 value=value, 2660 field="documentation", 2661 ) 2662 2663 return value 2664 2665 inputs: NotEmpty[Sequence[InputTensorDescr]] 2666 """Describes the input tensors expected by this model.""" 2667 2668 @field_validator("inputs", mode="after") 2669 @classmethod 2670 def _validate_input_axes( 2671 cls, inputs: Sequence[InputTensorDescr] 2672 ) -> Sequence[InputTensorDescr]: 2673 input_size_refs = cls._get_axes_with_independent_size(inputs) 2674 2675 for i, ipt in enumerate(inputs): 2676 valid_independent_refs: Dict[ 2677 Tuple[TensorId, AxisId], 2678 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2679 ] = { 2680 **{ 2681 (ipt.id, a.id): (ipt, a, a.size) 2682 for a in ipt.axes 2683 if not isinstance(a, BatchAxis) 2684 and isinstance(a.size, (int, ParameterizedSize)) 2685 }, 2686 **input_size_refs, 2687 } 2688 for a, ax in enumerate(ipt.axes): 2689 cls._validate_axis( 2690 "inputs", 2691 i=i, 2692 tensor_id=ipt.id, 2693 a=a, 2694 axis=ax, 2695 valid_independent_refs=valid_independent_refs, 2696 ) 2697 return inputs 2698 2699 @staticmethod 2700 def _validate_axis( 2701 field_name: str, 2702 i: int, 2703 tensor_id: TensorId, 2704 a: int, 2705 axis: AnyAxis, 2706 valid_independent_refs: Dict[ 2707 Tuple[TensorId, AxisId], 2708 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2709 ], 2710 ): 2711 if isinstance(axis, BatchAxis) or isinstance( 2712 axis.size, (int, ParameterizedSize, DataDependentSize) 2713 ): 2714 return 2715 elif not isinstance(axis.size, SizeReference): 2716 assert_never(axis.size) 2717 2718 # validate axis.size SizeReference 2719 ref = (axis.size.tensor_id, axis.size.axis_id) 2720 if ref not in valid_independent_refs: 2721 raise ValueError( 2722 "Invalid tensor axis reference at" 2723 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 
2724 ) 2725 if ref == (tensor_id, axis.id): 2726 raise ValueError( 2727 "Self-referencing not allowed for" 2728 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2729 ) 2730 if axis.type == "channel": 2731 if valid_independent_refs[ref][1].type != "channel": 2732 raise ValueError( 2733 "A channel axis' size may only reference another fixed size" 2734 + " channel axis." 2735 ) 2736 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2737 ref_size = valid_independent_refs[ref][2] 2738 assert isinstance(ref_size, int), ( 2739 "channel axis ref (another channel axis) has to specify fixed" 2740 + " size" 2741 ) 2742 generated_channel_names = [ 2743 Identifier(axis.channel_names.format(i=i)) 2744 for i in range(1, ref_size + 1) 2745 ] 2746 axis.channel_names = generated_channel_names 2747 2748 if (ax_unit := getattr(axis, "unit", None)) != ( 2749 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2750 ): 2751 raise ValueError( 2752 "The units of an axis and its reference axis need to match, but" 2753 + f" '{ax_unit}' != '{ref_unit}'." 2754 ) 2755 ref_axis = valid_independent_refs[ref][1] 2756 if isinstance(ref_axis, BatchAxis): 2757 raise ValueError( 2758 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2759 + " (a batch axis is not allowed as reference)." 2760 ) 2761 2762 if isinstance(axis, WithHalo): 2763 min_size = axis.size.get_size(axis, ref_axis, n=0) 2764 if (min_size - 2 * axis.halo) < 1: 2765 raise ValueError( 2766 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2767 + f" {axis.halo}." 2768 ) 2769 2770 input_halo = axis.halo * axis.scale / ref_axis.scale 2771 if input_halo != int(input_halo) or input_halo % 2 == 1: 2772 raise ValueError( 2773 f"input_halo {input_halo} (output_halo {axis.halo} *" 2774 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2775 + f" {tensor_id}.{axis.id}." 
2776 ) 2777 2778 @model_validator(mode="after") 2779 def _validate_test_tensors(self) -> Self: 2780 if not get_validation_context().perform_io_checks: 2781 return self 2782 2783 test_output_arrays = [ 2784 None if descr.test_tensor is None else load_array(descr.test_tensor) 2785 for descr in self.outputs 2786 ] 2787 test_input_arrays = [ 2788 None if descr.test_tensor is None else load_array(descr.test_tensor) 2789 for descr in self.inputs 2790 ] 2791 2792 tensors = { 2793 descr.id: (descr, array) 2794 for descr, array in zip( 2795 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2796 ) 2797 } 2798 validate_tensors(tensors, tensor_origin="test_tensor") 2799 2800 output_arrays = { 2801 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2802 } 2803 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2804 if not rep_tol.absolute_tolerance: 2805 continue 2806 2807 if rep_tol.output_ids: 2808 out_arrays = { 2809 oid: a 2810 for oid, a in output_arrays.items() 2811 if oid in rep_tol.output_ids 2812 } 2813 else: 2814 out_arrays = output_arrays 2815 2816 for out_id, array in out_arrays.items(): 2817 if array is None: 2818 continue 2819 2820 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2821 raise ValueError( 2822 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2823 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2824 + f" (1% of the maximum value of the test tensor '{out_id}')" 2825 ) 2826 2827 return self 2828 2829 @model_validator(mode="after") 2830 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2831 ipt_refs = {t.id for t in self.inputs} 2832 out_refs = {t.id for t in self.outputs} 2833 for ipt in self.inputs: 2834 for p in ipt.preprocessing: 2835 ref = p.kwargs.get("reference_tensor") 2836 if ref is None: 2837 continue 2838 if ref not in ipt_refs: 2839 raise ValueError( 2840 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2841 + f" references are: {ipt_refs}." 2842 ) 2843 2844 for out in self.outputs: 2845 for p in out.postprocessing: 2846 ref = p.kwargs.get("reference_tensor") 2847 if ref is None: 2848 continue 2849 2850 if ref not in ipt_refs and ref not in out_refs: 2851 raise ValueError( 2852 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2853 + f" are: {ipt_refs | out_refs}." 2854 ) 2855 2856 return self 2857 2858 # TODO: use validate funcs in validate_test_tensors 2859 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2860 2861 name: Annotated[ 2862 str, 2863 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 2864 MinLen(5), 2865 MaxLen(128), 2866 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2867 ] 2868 """A human-readable name of this model. 2869 It should be no longer than 64 characters 2870 and may only contain letter, number, underscore, minus, parentheses and spaces. 2871 We recommend to chose a name that refers to the model's task and image modality. 
2872 """ 2873 2874 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2875 """Describes the output tensors.""" 2876 2877 @field_validator("outputs", mode="after") 2878 @classmethod 2879 def _validate_tensor_ids( 2880 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2881 ) -> Sequence[OutputTensorDescr]: 2882 tensor_ids = [ 2883 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2884 ] 2885 duplicate_tensor_ids: List[str] = [] 2886 seen: Set[str] = set() 2887 for t in tensor_ids: 2888 if t in seen: 2889 duplicate_tensor_ids.append(t) 2890 2891 seen.add(t) 2892 2893 if duplicate_tensor_ids: 2894 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2895 2896 return outputs 2897 2898 @staticmethod 2899 def _get_axes_with_parameterized_size( 2900 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2901 ): 2902 return { 2903 f"{t.id}.{a.id}": (t, a, a.size) 2904 for t in io 2905 for a in t.axes 2906 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2907 } 2908 2909 @staticmethod 2910 def _get_axes_with_independent_size( 2911 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2912 ): 2913 return { 2914 (t.id, a.id): (t, a, a.size) 2915 for t in io 2916 for a in t.axes 2917 if not isinstance(a, BatchAxis) 2918 and isinstance(a.size, (int, ParameterizedSize)) 2919 } 2920 2921 @field_validator("outputs", mode="after") 2922 @classmethod 2923 def _validate_output_axes( 2924 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2925 ) -> List[OutputTensorDescr]: 2926 input_size_refs = cls._get_axes_with_independent_size( 2927 info.data.get("inputs", []) 2928 ) 2929 output_size_refs = cls._get_axes_with_independent_size(outputs) 2930 2931 for i, out in enumerate(outputs): 2932 valid_independent_refs: Dict[ 2933 Tuple[TensorId, AxisId], 2934 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2935 ] = { 2936 **{ 2937 (out.id, a.id): (out, a, a.size) 2938 for a in out.axes 2939 if not isinstance(a, BatchAxis) 2940 and isinstance(a.size, (int, ParameterizedSize)) 2941 }, 2942 **input_size_refs, 2943 **output_size_refs, 2944 } 2945 for a, ax in enumerate(out.axes): 2946 cls._validate_axis( 2947 "outputs", 2948 i, 2949 out.id, 2950 a, 2951 ax, 2952 valid_independent_refs=valid_independent_refs, 2953 ) 2954 2955 return outputs 2956 2957 packaged_by: List[Author] = Field( 2958 default_factory=cast(Callable[[], List[Author]], list) 2959 ) 2960 """The persons that have packaged and uploaded this model. 2961 Only required if those persons differ from the `authors`.""" 2962 2963 parent: Optional[LinkedModel] = None 2964 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2965 2966 @model_validator(mode="after") 2967 def _validate_parent_is_not_self(self) -> Self: 2968 if self.parent is not None and self.parent.id == self.id: 2969 raise ValueError("A model description may not reference itself as parent.") 2970 2971 return self 2972 2973 run_mode: Annotated[ 2974 Optional[RunMode], 2975 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2976 ] = None 2977 """Custom run mode for this model: for more complex prediction procedures like test time 2978 data augmentation that currently cannot be expressed in the specification. 
2979 No standard run modes are defined yet.""" 2980 2981 timestamp: Datetime = Field(default_factory=Datetime.now) 2982 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2983 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2984 (In Python a datetime object is valid, too).""" 2985 2986 training_data: Annotated[ 2987 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2988 Field(union_mode="left_to_right"), 2989 ] = None 2990 """The dataset used to train this model""" 2991 2992 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2993 """The weights for this model. 2994 Weights can be given for different formats, but should otherwise be equivalent. 2995 The available weight formats determine which consumers can use this model.""" 2996 2997 config: Config = Field(default_factory=Config.model_construct) 2998 2999 @model_validator(mode="after") 3000 def _add_default_cover(self) -> Self: 3001 if not get_validation_context().perform_io_checks or self.covers: 3002 return self 3003 3004 try: 3005 generated_covers = generate_covers( 3006 [ 3007 (t, load_array(t.test_tensor)) 3008 for t in self.inputs 3009 if t.test_tensor is not None 3010 ], 3011 [ 3012 (t, load_array(t.test_tensor)) 3013 for t in self.outputs 3014 if t.test_tensor is not None 3015 ], 3016 ) 3017 except Exception as e: 3018 issue_warning( 3019 "Failed to generate cover image(s): {e}", 3020 value=self.covers, 3021 msg_context=dict(e=e), 3022 field="covers", 3023 ) 3024 else: 3025 self.covers.extend(generated_covers) 3026 3027 return self 3028 3029 def get_input_test_arrays(self) -> List[NDArray[Any]]: 3030 return self._get_test_arrays(self.inputs) 3031 3032 def get_output_test_arrays(self) -> List[NDArray[Any]]: 3033 return self._get_test_arrays(self.outputs) 3034 3035 @staticmethod 3036 def _get_test_arrays( 3037 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 3038 ): 3039 ts: List[FileDescr] = [] 3040 for d in io_descr: 3041 if d.test_tensor is None: 3042 raise ValueError( 3043 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 3044 ) 3045 ts.append(d.test_tensor) 3046 3047 data = [load_array(t) for t in ts] 3048 assert all(isinstance(d, np.ndarray) for d in data) 3049 return data 3050 3051 @staticmethod 3052 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3053 batch_size = 1 3054 tensor_with_batchsize: Optional[TensorId] = None 3055 for tid in tensor_sizes: 3056 for aid, s in tensor_sizes[tid].items(): 3057 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3058 continue 3059 3060 if batch_size != 1: 3061 assert tensor_with_batchsize is not None 3062 raise ValueError( 3063 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3064 ) 3065 3066 batch_size = s 3067 tensor_with_batchsize = tid 3068 3069 return batch_size 3070 3071 def get_output_tensor_sizes( 3072 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3073 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3074 """Returns the tensor output sizes for given **input_sizes**. 3075 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
3076 Otherwise it might be larger than the actual (valid) output""" 3077 batch_size = self.get_batch_size(input_sizes) 3078 ns = self.get_ns(input_sizes) 3079 3080 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3081 return tensor_sizes.outputs 3082 3083 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3084 """get parameter `n` for each parameterized axis 3085 such that the valid input size is >= the given input size""" 3086 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3087 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3088 for tid in input_sizes: 3089 for aid, s in input_sizes[tid].items(): 3090 size_descr = axes[tid][aid].size 3091 if isinstance(size_descr, ParameterizedSize): 3092 ret[(tid, aid)] = size_descr.get_n(s) 3093 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3094 pass 3095 else: 3096 assert_never(size_descr) 3097 3098 return ret 3099 3100 def get_tensor_sizes( 3101 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3102 ) -> _TensorSizes: 3103 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3104 return _TensorSizes( 3105 { 3106 t: { 3107 aa: axis_sizes.inputs[(tt, aa)] 3108 for tt, aa in axis_sizes.inputs 3109 if tt == t 3110 } 3111 for t in {tt for tt, _ in axis_sizes.inputs} 3112 }, 3113 { 3114 t: { 3115 aa: axis_sizes.outputs[(tt, aa)] 3116 for tt, aa in axis_sizes.outputs 3117 if tt == t 3118 } 3119 for t in {tt for tt, _ in axis_sizes.outputs} 3120 }, 3121 ) 3122 3123 def get_axis_sizes( 3124 self, 3125 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3126 batch_size: Optional[int] = None, 3127 *, 3128 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3129 ) -> _AxisSizes: 3130 """Determine input and output block shape for scale factors **ns** 3131 of parameterized input sizes. 3132 3133 Args: 3134 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3135 that is parameterized as `size = min + n * step`. 3136 batch_size: The desired size of the batch dimension. 3137 If given **batch_size** overwrites any batch size present in 3138 **max_input_shape**. Default 1. 3139 max_input_shape: Limits the derived block shapes. 3140 Each axis for which the input size, parameterized by `n`, is larger 3141 than **max_input_shape** is set to the minimal value `n_min` for which 3142 this is still true. 3143 Use this for small input samples or large values of **ns**. 3144 Or simply whenever you know the full input shape. 3145 3146 Returns: 3147 Resolved axis sizes for model inputs and outputs. 
3148 """ 3149 max_input_shape = max_input_shape or {} 3150 if batch_size is None: 3151 for (_t_id, a_id), s in max_input_shape.items(): 3152 if a_id == BATCH_AXIS_ID: 3153 batch_size = s 3154 break 3155 else: 3156 batch_size = 1 3157 3158 all_axes = { 3159 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3160 } 3161 3162 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3163 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3164 3165 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3166 if isinstance(a, BatchAxis): 3167 if (t_descr.id, a.id) in ns: 3168 logger.warning( 3169 "Ignoring unexpected size increment factor (n) for batch axis" 3170 + " of tensor '{}'.", 3171 t_descr.id, 3172 ) 3173 return batch_size 3174 elif isinstance(a.size, int): 3175 if (t_descr.id, a.id) in ns: 3176 logger.warning( 3177 "Ignoring unexpected size increment factor (n) for fixed size" 3178 + " axis '{}' of tensor '{}'.", 3179 a.id, 3180 t_descr.id, 3181 ) 3182 return a.size 3183 elif isinstance(a.size, ParameterizedSize): 3184 if (t_descr.id, a.id) not in ns: 3185 raise ValueError( 3186 "Size increment factor (n) missing for parametrized axis" 3187 + f" '{a.id}' of tensor '{t_descr.id}'." 3188 ) 3189 n = ns[(t_descr.id, a.id)] 3190 s_max = max_input_shape.get((t_descr.id, a.id)) 3191 if s_max is not None: 3192 n = min(n, a.size.get_n(s_max)) 3193 3194 return a.size.get_size(n) 3195 3196 elif isinstance(a.size, SizeReference): 3197 if (t_descr.id, a.id) in ns: 3198 logger.warning( 3199 "Ignoring unexpected size increment factor (n) for axis '{}'" 3200 + " of tensor '{}' with size reference.", 3201 a.id, 3202 t_descr.id, 3203 ) 3204 assert not isinstance(a, BatchAxis) 3205 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3206 assert not isinstance(ref_axis, BatchAxis) 3207 ref_key = (a.size.tensor_id, a.size.axis_id) 3208 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3209 assert ref_size is not None, ref_key 3210 assert not isinstance(ref_size, _DataDepSize), ref_key 3211 return a.size.get_size( 3212 axis=a, 3213 ref_axis=ref_axis, 3214 ref_size=ref_size, 3215 ) 3216 elif isinstance(a.size, DataDependentSize): 3217 if (t_descr.id, a.id) in ns: 3218 logger.warning( 3219 "Ignoring unexpected increment factor (n) for data dependent" 3220 + " size axis '{}' of tensor '{}'.", 3221 a.id, 3222 t_descr.id, 3223 ) 3224 return _DataDepSize(a.size.min, a.size.max) 3225 else: 3226 assert_never(a.size) 3227 3228 # first resolve all , but the `SizeReference` input sizes 3229 for t_descr in self.inputs: 3230 for a in t_descr.axes: 3231 if not isinstance(a.size, SizeReference): 3232 s = get_axis_size(a) 3233 assert not isinstance(s, _DataDepSize) 3234 inputs[t_descr.id, a.id] = s 3235 3236 # resolve all other input axis sizes 3237 for t_descr in self.inputs: 3238 for a in t_descr.axes: 3239 if isinstance(a.size, SizeReference): 3240 s = get_axis_size(a) 3241 assert not isinstance(s, _DataDepSize) 3242 inputs[t_descr.id, a.id] = s 3243 3244 # resolve all output axis sizes 3245 for t_descr in self.outputs: 3246 for a in t_descr.axes: 3247 assert not isinstance(a.size, ParameterizedSize) 3248 s = get_axis_size(a) 3249 outputs[t_descr.id, a.id] = s 3250 3251 return _AxisSizes(inputs=inputs, outputs=outputs) 3252 3253 @model_validator(mode="before") 3254 @classmethod 3255 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3256 cls.convert_from_old_format_wo_validation(data) 3257 return data 3258 3259 @classmethod 3260 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3261 """Convert metadata following an older format version to this classes' format 3262 without validating the result. 3263 """ 3264 if ( 3265 data.get("type") == "model" 3266 and isinstance(fv := data.get("format_version"), str) 3267 and fv.count(".") == 2 3268 ): 3269 fv_parts = fv.split(".") 3270 if any(not p.isdigit() for p in fv_parts): 3271 return 3272 3273 fv_tuple = tuple(map(int, fv_parts)) 3274 3275 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3276 if fv_tuple[:2] in ((0, 3), (0, 4)): 3277 m04 = _ModelDescr_v0_4.load(data) 3278 if isinstance(m04, InvalidDescr): 3279 try: 3280 updated = _model_conv.convert_as_dict( 3281 m04 # pyright: ignore[reportArgumentType] 3282 ) 3283 except Exception as e: 3284 logger.error( 3285 "Failed to convert from invalid model 0.4 description." 3286 + f"\nerror: {e}" 3287 + "\nProceeding with model 0.5 validation without conversion." 3288 ) 3289 updated = None 3290 else: 3291 updated = _model_conv.convert_as_dict(m04) 3292 3293 if updated is not None: 3294 data.clear() 3295 data.update(updated) 3296 3297 elif fv_tuple[:2] == (0, 5): 3298 # bump patch version 3299 data["format_version"] = cls.implemented_format_version 3300 3301 3302class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 3303 def _convert( 3304 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 3305 ) -> "ModelDescr | dict[str, Any]": 3306 name = "".join( 3307 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 3308 for c in src.name 3309 ) 3310 3311 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 3312 conv = ( 3313 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 3314 ) 3315 return None if auths is None else [conv(a) for a in auths] 3316 3317 if TYPE_CHECKING: 3318 arch_file_conv = _arch_file_conv.convert 3319 arch_lib_conv = _arch_lib_conv.convert 3320 else: 3321 arch_file_conv = _arch_file_conv.convert_as_dict 3322 arch_lib_conv = _arch_lib_conv.convert_as_dict 3323 3324 input_size_refs = { 3325 ipt.name: { 3326 a: s 3327 for a, s in zip( 3328 ipt.axes, 3329 ( 3330 ipt.shape.min 3331 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 3332 else ipt.shape 3333 ), 3334 ) 3335 } 3336 for ipt in src.inputs 3337 if ipt.shape 3338 } 3339 output_size_refs = { 3340 **{ 3341 out.name: {a: s for a, s in zip(out.axes, out.shape)} 3342 for out in src.outputs 3343 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 3344 }, 3345 **input_size_refs, 3346 } 3347 3348 return tgt( 3349 attachments=( 3350 [] 3351 if src.attachments is None 3352 else [FileDescr(source=f) for f in src.attachments.files] 3353 ), 3354 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType] 3355 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType] 3356 config=src.config, # pyright: ignore[reportArgumentType] 3357 covers=src.covers, 3358 description=src.description, 3359 documentation=src.documentation, 3360 format_version="0.5.5", 3361 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3362 icon=src.icon, 3363 id=None if src.id is None else ModelId(src.id), 3364 id_emoji=src.id_emoji, 3365 license=src.license, # type: ignore 3366 links=src.links, 3367 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType] 3368 name=name, 3369 tags=src.tags, 3370 type=src.type, 3371 uploader=src.uploader, 3372 version=src.version, 3373 inputs=[ # 
pyright: ignore[reportArgumentType] 3374 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3375 for ipt, tt, st in zip( 3376 src.inputs, 3377 src.test_inputs, 3378 src.sample_inputs or [None] * len(src.test_inputs), 3379 ) 3380 ], 3381 outputs=[ # pyright: ignore[reportArgumentType] 3382 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3383 for out, tt, st in zip( 3384 src.outputs, 3385 src.test_outputs, 3386 src.sample_outputs or [None] * len(src.test_outputs), 3387 ) 3388 ], 3389 parent=( 3390 None 3391 if src.parent is None 3392 else LinkedModel( 3393 id=ModelId( 3394 str(src.parent.id) 3395 + ( 3396 "" 3397 if src.parent.version_number is None 3398 else f"/{src.parent.version_number}" 3399 ) 3400 ) 3401 ) 3402 ), 3403 training_data=( 3404 None 3405 if src.training_data is None 3406 else ( 3407 LinkedDataset( 3408 id=DatasetId( 3409 str(src.training_data.id) 3410 + ( 3411 "" 3412 if src.training_data.version_number is None 3413 else f"/{src.training_data.version_number}" 3414 ) 3415 ) 3416 ) 3417 if isinstance(src.training_data, LinkedDataset02) 3418 else src.training_data 3419 ) 3420 ), 3421 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType] 3422 run_mode=src.run_mode, 3423 timestamp=src.timestamp, 3424 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3425 keras_hdf5=(w := src.weights.keras_hdf5) 3426 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3427 authors=conv_authors(w.authors), 3428 source=w.source, 3429 tensorflow_version=w.tensorflow_version or Version("1.15"), 3430 parent=w.parent, 3431 ), 3432 onnx=(w := src.weights.onnx) 3433 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3434 source=w.source, 3435 authors=conv_authors(w.authors), 3436 parent=w.parent, 3437 opset_version=w.opset_version or 15, 3438 ), 3439 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3440 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3441 source=w.source, 3442 authors=conv_authors(w.authors), 3443 parent=w.parent, 3444 architecture=( 3445 arch_file_conv( 3446 w.architecture, 3447 w.architecture_sha256, 3448 w.kwargs, 3449 ) 3450 if isinstance(w.architecture, _CallableFromFile_v0_4) 3451 else arch_lib_conv(w.architecture, w.kwargs) 3452 ), 3453 pytorch_version=w.pytorch_version or Version("1.10"), 3454 dependencies=( 3455 None 3456 if w.dependencies is None 3457 else (FileDescr if TYPE_CHECKING else dict)( 3458 source=cast( 3459 FileSource, 3460 str(deps := w.dependencies)[ 3461 ( 3462 len("conda:") 3463 if str(deps).startswith("conda:") 3464 else 0 3465 ) : 3466 ], 3467 ) 3468 ) 3469 ), 3470 ), 3471 tensorflow_js=(w := src.weights.tensorflow_js) 3472 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3473 source=w.source, 3474 authors=conv_authors(w.authors), 3475 parent=w.parent, 3476 tensorflow_version=w.tensorflow_version or Version("1.15"), 3477 ), 3478 tensorflow_saved_model_bundle=( 3479 w := src.weights.tensorflow_saved_model_bundle 3480 ) 3481 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3482 authors=conv_authors(w.authors), 3483 parent=w.parent, 3484 source=w.source, 3485 tensorflow_version=w.tensorflow_version or Version("1.15"), 3486 dependencies=( 3487 None 3488 if w.dependencies is None 3489 else (FileDescr if TYPE_CHECKING else dict)( 3490 source=cast( 3491 FileSource, 3492 ( 3493 str(w.dependencies)[len("conda:") :] 3494 if str(w.dependencies).startswith("conda:") 3495 else str(w.dependencies) 3496 ), 3497 ) 3498 ) 3499 ), 
3500 ), 3501 torchscript=(w := src.weights.torchscript) 3502 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3503 source=w.source, 3504 authors=conv_authors(w.authors), 3505 parent=w.parent, 3506 pytorch_version=w.pytorch_version or Version("1.10"), 3507 ), 3508 ), 3509 ) 3510 3511 3512_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3513 3514 3515# create better cover images for 3d data and non-image outputs 3516def generate_covers( 3517 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3518 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3519) -> List[Path]: 3520 def squeeze( 3521 data: NDArray[Any], axes: Sequence[AnyAxis] 3522 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3523 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3524 if data.ndim != len(axes): 3525 raise ValueError( 3526 f"tensor shape {data.shape} does not match described axes" 3527 + f" {[a.id for a in axes]}" 3528 ) 3529 3530 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3531 return data.squeeze(), axes 3532 3533 def normalize( 3534 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3535 ) -> NDArray[np.float32]: 3536 data = data.astype("float32") 3537 data -= data.min(axis=axis, keepdims=True) 3538 data /= data.max(axis=axis, keepdims=True) + eps 3539 return data 3540 3541 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3542 original_shape = data.shape 3543 data, axes = squeeze(data, axes) 3544 3545 # take slice fom any batch or index axis if needed 3546 # and convert the first channel axis and take a slice from any additional channel axes 3547 slices: Tuple[slice, ...] = () 3548 ndim = data.ndim 3549 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3550 has_c_axis = False 3551 for i, a in enumerate(axes): 3552 s = data.shape[i] 3553 assert s > 1 3554 if ( 3555 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3556 and ndim > ndim_need 3557 ): 3558 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3559 ndim -= 1 3560 elif isinstance(a, ChannelAxis): 3561 if has_c_axis: 3562 # second channel axis 3563 data = data[slices + (slice(0, 1),)] 3564 ndim -= 1 3565 else: 3566 has_c_axis = True 3567 if s == 2: 3568 # visualize two channels with cyan and magenta 3569 data = np.concatenate( 3570 [ 3571 data[slices + (slice(1, 2),)], 3572 data[slices + (slice(0, 1),)], 3573 ( 3574 data[slices + (slice(0, 1),)] 3575 + data[slices + (slice(1, 2),)] 3576 ) 3577 / 2, # TODO: take maximum instead? 
3578 ], 3579 axis=i, 3580 ) 3581 elif data.shape[i] == 3: 3582 pass # visualize 3 channels as RGB 3583 else: 3584 # visualize first 3 channels as RGB 3585 data = data[slices + (slice(3),)] 3586 3587 assert data.shape[i] == 3 3588 3589 slices += (slice(None),) 3590 3591 data, axes = squeeze(data, axes) 3592 assert len(axes) == ndim 3593 # take slice from z axis if needed 3594 slices = () 3595 if ndim > ndim_need: 3596 for i, a in enumerate(axes): 3597 s = data.shape[i] 3598 if a.id == AxisId("z"): 3599 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3600 data, axes = squeeze(data, axes) 3601 ndim -= 1 3602 break 3603 3604 slices += (slice(None),) 3605 3606 # take slice from any space or time axis 3607 slices = () 3608 3609 for i, a in enumerate(axes): 3610 if ndim <= ndim_need: 3611 break 3612 3613 s = data.shape[i] 3614 assert s > 1 3615 if isinstance( 3616 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3617 ): 3618 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3619 ndim -= 1 3620 3621 slices += (slice(None),) 3622 3623 del slices 3624 data, axes = squeeze(data, axes) 3625 assert len(axes) == ndim 3626 3627 if (has_c_axis and ndim != 3) or ndim != 2: 3628 raise ValueError( 3629 f"Failed to construct cover image from shape {original_shape}" 3630 ) 3631 3632 if not has_c_axis: 3633 assert ndim == 2 3634 data = np.repeat(data[:, :, None], 3, axis=2) 3635 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3636 ndim += 1 3637 3638 assert ndim == 3 3639 3640 # transpose axis order such that longest axis comes first... 3641 axis_order: List[int] = list(np.argsort(list(data.shape))) 3642 axis_order.reverse() 3643 # ... and channel axis is last 3644 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3645 axis_order.append(axis_order.pop(c)) 3646 axes = [axes[ao] for ao in axis_order] 3647 data = data.transpose(axis_order) 3648 3649 # h, w = data.shape[:2] 3650 # if h / w in (1.0 or 2.0): 3651 # pass 3652 # elif h / w < 2: 3653 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3654 3655 norm_along = ( 3656 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3657 ) 3658 # normalize the data and map to 8 bit 3659 data = normalize(data, norm_along) 3660 data = (data * 255).astype("uint8") 3661 3662 return data 3663 3664 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3665 assert im0.dtype == im1.dtype == np.uint8 3666 assert im0.shape == im1.shape 3667 assert im0.ndim == 3 3668 N, M, C = im0.shape 3669 assert C == 3 3670 out = np.ones((N, M, C), dtype="uint8") 3671 for c in range(C): 3672 outc = np.tril(im0[..., c]) 3673 mask = outc == 0 3674 outc[mask] = np.triu(im1[..., c])[mask] 3675 out[..., c] = outc 3676 3677 return out 3678 3679 if not inputs: 3680 raise ValueError("Missing test input tensor for cover generation.") 3681 3682 if not outputs: 3683 raise ValueError("Missing test output tensor for cover generation.") 3684 3685 ipt_descr, ipt = inputs[0] 3686 out_descr, out = outputs[0] 3687 3688 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3689 out_img = to_2d_image(out, out_descr.axes) 3690 3691 cover_folder = Path(mkdtemp()) 3692 if ipt_img.shape == out_img.shape: 3693 covers = [cover_folder / "cover.png"] 3694 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3695 else: 3696 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3697 imwrite(covers[0], ipt_img) 3698 imwrite(covers[1], out_img) 3699 3700 return covers
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
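Both units are plain string literals from the OME-Zarr vocabulary and are typically attached to space or time axes. A minimal sketch (the axis id, size, and scale below are illustrative assumptions):

    from bioimageio.spec.model.v0_5 import AxisId, SpaceInputAxis

    # a 256 pixel wide x-axis where one pixel covers 0.5 micrometer
    x = SpaceInputAxis(id=AxisId("x"), size=256, unit="micrometer", scale=0.5)
    print(x.unit, x.scale)  # micrometer 0.5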
229class TensorId(LowerCaseIdentifier): 230 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 231 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 232 ]
245class AxisId(LowerCaseIdentifier): 246 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 247 Annotated[ 248 LowerCaseIdentifierAnno, 249 MaxLen(16), 250 AfterValidator(_normalize_axis_id), 251 ] 252 ]
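`TensorId` and `AxisId` are validated string types: lower-case identifiers with a maximum length of 32 and 16 characters, respectively, and axis ids are additionally normalized. A minimal sketch, assuming construction validates eagerly:

    from bioimageio.spec.model.v0_5 import AxisId, TensorId

    t = TensorId("raw")  # lower-case identifier, at most 32 characters
    a = AxisId("x")      # lower-case identifier, at most 16 characters
    # values violating these constraints are expected to raise a validation error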
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
298class ParameterizedSize(Node): 299 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 300 301 - **min** and **step** are given by the model description. 302 - All blocksize paramters n = 0,1,2,... yield a valid `size`. 303 - A greater blocksize paramter n = 0,1,2,... results in a greater **size**. 304 This allows to adjust the axis size more generically. 305 """ 306 307 N: ClassVar[Type[int]] = ParameterizedSize_N 308 """Positive integer to parameterize this axis""" 309 310 min: Annotated[int, Gt(0)] 311 step: Annotated[int, Gt(0)] 312 313 def validate_size(self, size: int) -> int: 314 if size < self.min: 315 raise ValueError(f"size {size} < {self.min}") 316 if (size - self.min) % self.step != 0: 317 raise ValueError( 318 f"axis of size {size} is not parameterized by `min + n*step` =" 319 + f" `{self.min} + n*{self.step}`" 320 ) 321 322 return size 323 324 def get_size(self, n: ParameterizedSize_N) -> int: 325 return self.min + self.step * n 326 327 def get_n(self, s: int) -> ParameterizedSize_N: 328 """return smallest n parameterizing a size greater or equal than `s`""" 329 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as `size = min + n*step`.
- `min` and `step` are given by the model description.
- All block size parameters n = 0, 1, 2, ... yield a valid `size`.
- A greater block size parameter n results in a greater `size`; this allows the axis size to be adjusted generically.
313 def validate_size(self, size: int) -> int: 314 if size < self.min: 315 raise ValueError(f"size {size} < {self.min}") 316 if (size - self.min) % self.step != 0: 317 raise ValueError( 318 f"axis of size {size} is not parameterized by `min + n*step` =" 319 + f" `{self.min} + n*{self.step}`" 320 ) 321 322 return size
327 def get_n(self, s: int) -> ParameterizedSize_N: 328 """return smallest n parameterizing a size greater or equal than `s`""" 329 return ceil((s - self.min) / self.step)
Return the smallest n parameterizing a size greater than or equal to `s`.
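A minimal usage sketch of these helpers (the concrete numbers are only illustrative), assuming the class is imported from bioimageio.spec.model.v0_5:

>>> from bioimageio.spec.model.v0_5 import ParameterizedSize
>>> size = ParameterizedSize(min=32, step=16)
>>> size.get_size(2)  # min + n*step = 32 + 2*16
64
>>> size.get_n(70)  # smallest n such that min + n*step >= 70
3
>>> size.validate_size(size.get_size(3))
80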
332class DataDependentSize(Node): 333 min: Annotated[int, Gt(0)] = 1 334 max: Annotated[Optional[int], Gt(1)] = None 335 336 @model_validator(mode="after") 337 def _validate_max_gt_min(self): 338 if self.max is not None and self.min >= self.max: 339 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 340 341 return self 342 343 def validate_size(self, size: int) -> int: 344 if size < self.min: 345 raise ValueError(f"size {size} < {self.min}") 346 347 if self.max is not None and size > self.max: 348 raise ValueError(f"size {size} > {self.max}") 349 350 return size
353class SizeReference(Node): 354 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 355 356 `axis.size = reference.size * reference.scale / axis.scale + offset` 357 358 Note: 359 1. The axis and the referenced axis need to have the same unit (or no unit). 360 2. Batch axes may not be referenced. 361 3. Fractions are rounded down. 362 4. If the reference axis is `concatenable` the referencing axis is assumed to be 363 `concatenable` as well with the same block order. 364 365 Example: 366 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 367 Let's assume that we want to express the image height h in relation to its width w 368 instead of only accepting input images of exactly 100*49 pixels 369 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 370 371 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 372 >>> h = SpaceInputAxis( 373 ... id=AxisId("h"), 374 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 375 ... unit="millimeter", 376 ... scale=4, 377 ... ) 378 >>> print(h.size.get_size(h, w)) 379 49 380 381 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 382 """ 383 384 tensor_id: TensorId 385 """tensor id of the reference axis""" 386 387 axis_id: AxisId 388 """axis id of the reference axis""" 389 390 offset: StrictInt = 0 391 392 def get_size( 393 self, 394 axis: Union[ 395 ChannelAxis, 396 IndexInputAxis, 397 IndexOutputAxis, 398 TimeInputAxis, 399 SpaceInputAxis, 400 TimeOutputAxis, 401 TimeOutputAxisWithHalo, 402 SpaceOutputAxis, 403 SpaceOutputAxisWithHalo, 404 ], 405 ref_axis: Union[ 406 ChannelAxis, 407 IndexInputAxis, 408 IndexOutputAxis, 409 TimeInputAxis, 410 SpaceInputAxis, 411 TimeOutputAxis, 412 TimeOutputAxisWithHalo, 413 SpaceOutputAxis, 414 SpaceOutputAxisWithHalo, 415 ], 416 n: ParameterizedSize_N = 0, 417 ref_size: Optional[int] = None, 418 ): 419 """Compute the concrete size for a given axis and its reference axis. 420 421 Args: 422 axis: The axis this `SizeReference` is the size of. 423 ref_axis: The reference axis to compute the size from. 424 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 425 and no fixed **ref_size** is given, 426 **n** is used to compute the size of the parameterized **ref_axis**. 427 ref_size: Overwrite the reference size instead of deriving it from 428 **ref_axis** 429 (**ref_axis.scale** is still used; any given **n** is ignored). 430 """ 431 assert axis.size == self, ( 432 "Given `axis.size` is not defined by this `SizeReference`" 433 ) 434 435 assert ref_axis.id == self.axis_id, ( 436 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 437 ) 438 439 assert axis.unit == ref_axis.unit, ( 440 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 441 f" but {axis.unit}!={ref_axis.unit}" 442 ) 443 if ref_size is None: 444 if isinstance(ref_axis.size, (int, float)): 445 ref_size = ref_axis.size 446 elif isinstance(ref_axis.size, ParameterizedSize): 447 ref_size = ref_axis.size.get_size(n) 448 elif isinstance(ref_axis.size, DataDependentSize): 449 raise ValueError( 450 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 451 ) 452 elif isinstance(ref_axis.size, SizeReference): 453 raise ValueError( 454 "Reference axis referenced in `SizeReference` may not be sized by a" 455 + " `SizeReference` itself." 
456 ) 457 else: 458 assert_never(ref_axis.size) 459 460 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 461 462 @staticmethod 463 def _get_unit( 464 axis: Union[ 465 ChannelAxis, 466 IndexInputAxis, 467 IndexOutputAxis, 468 TimeInputAxis, 469 SpaceInputAxis, 470 TimeOutputAxis, 471 TimeOutputAxisWithHalo, 472 SpaceOutputAxis, 473 SpaceOutputAxisWithHalo, 474 ], 475 ): 476 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
axis.size = reference.size * reference.scale / axis.scale + offset
Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is `concatenable`, the referencing axis is assumed to be `concatenable` as well, with the same block order.
Example:
An anisotropic input image of w*h = 100*49 pixels depicts a physical space of 200*196 mm².
Let's assume that we want to express the image height h in relation to its width w instead of only accepting input images of exactly 100*49 pixels (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
392 def get_size( 393 self, 394 axis: Union[ 395 ChannelAxis, 396 IndexInputAxis, 397 IndexOutputAxis, 398 TimeInputAxis, 399 SpaceInputAxis, 400 TimeOutputAxis, 401 TimeOutputAxisWithHalo, 402 SpaceOutputAxis, 403 SpaceOutputAxisWithHalo, 404 ], 405 ref_axis: Union[ 406 ChannelAxis, 407 IndexInputAxis, 408 IndexOutputAxis, 409 TimeInputAxis, 410 SpaceInputAxis, 411 TimeOutputAxis, 412 TimeOutputAxisWithHalo, 413 SpaceOutputAxis, 414 SpaceOutputAxisWithHalo, 415 ], 416 n: ParameterizedSize_N = 0, 417 ref_size: Optional[int] = None, 418 ): 419 """Compute the concrete size for a given axis and its reference axis. 420 421 Args: 422 axis: The axis this `SizeReference` is the size of. 423 ref_axis: The reference axis to compute the size from. 424 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 425 and no fixed **ref_size** is given, 426 **n** is used to compute the size of the parameterized **ref_axis**. 427 ref_size: Overwrite the reference size instead of deriving it from 428 **ref_axis** 429 (**ref_axis.scale** is still used; any given **n** is ignored). 430 """ 431 assert axis.size == self, ( 432 "Given `axis.size` is not defined by this `SizeReference`" 433 ) 434 435 assert ref_axis.id == self.axis_id, ( 436 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 437 ) 438 439 assert axis.unit == ref_axis.unit, ( 440 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 441 f" but {axis.unit}!={ref_axis.unit}" 442 ) 443 if ref_size is None: 444 if isinstance(ref_axis.size, (int, float)): 445 ref_size = ref_axis.size 446 elif isinstance(ref_axis.size, ParameterizedSize): 447 ref_size = ref_axis.size.get_size(n) 448 elif isinstance(ref_axis.size, DataDependentSize): 449 raise ValueError( 450 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 451 ) 452 elif isinstance(ref_axis.size, SizeReference): 453 raise ValueError( 454 "Reference axis referenced in `SizeReference` may not be sized by a" 455 + " `SizeReference` itself." 456 ) 457 else: 458 assert_never(ref_axis.size) 459 460 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this `SizeReference` is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If `ref_axis` is parameterized (of type `ParameterizedSize`) and no fixed `ref_size` is given, `n` is used to compute the size of the parameterized `ref_axis`.
- ref_size: Overwrite the reference size instead of deriving it from `ref_axis` (`ref_axis.scale` is still used; any given `n` is ignored).
479class AxisBase(NodeWithExplicitlySetFields): 480 id: AxisId 481 """An axis id unique across all axes of one tensor.""" 482 483 description: Annotated[str, MaxLen(128)] = "" 484 """A short description of this axis beyond its type and id."""
A short description of this axis beyond its type and id.
487class WithHalo(Node): 488 halo: Annotated[int, Ge(1)] 489 """The halo should be cropped from the output tensor to avoid boundary effects. 490 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 491 To document a halo that is already cropped by the model use `size.offset` instead.""" 492 493 size: Annotated[ 494 SizeReference, 495 Field( 496 examples=[ 497 10, 498 SizeReference( 499 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 500 ).model_dump(mode="json"), 501 ] 502 ), 503 ] 504 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
To document a halo that is already cropped by the model, use `size.offset` instead.
The `size` field is a reference to another axis with an optional offset (see `SizeReference`).
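A rough sketch of an output axis carrying a halo (the identifiers are illustrative): with halo=16 a consumer crops 16 pixels from both sides, so a 256 pixel output yields 256 - 2*16 = 224 usable pixels.

>>> out_x = SpaceOutputAxisWithHalo(
...     id=AxisId("x"),
...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
...     halo=16,
... )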
510class BatchAxis(AxisBase): 511 implemented_type: ClassVar[Literal["batch"]] = "batch" 512 if TYPE_CHECKING: 513 type: Literal["batch"] = "batch" 514 else: 515 type: Literal["batch"] 516 517 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 518 size: Optional[Literal[1]] = None 519 """The batch size may be fixed to 1, 520 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 521 522 @property 523 def scale(self): 524 return 1.0 525 526 @property 527 def concatenable(self): 528 return True 529 530 @property 531 def unit(self): 532 return None
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
535class ChannelAxis(AxisBase): 536 implemented_type: ClassVar[Literal["channel"]] = "channel" 537 if TYPE_CHECKING: 538 type: Literal["channel"] = "channel" 539 else: 540 type: Literal["channel"] 541 542 id: NonBatchAxisId = AxisId("channel") 543 544 channel_names: NotEmpty[List[Identifier]] 545 546 @property 547 def size(self) -> int: 548 return len(self.channel_names) 549 550 @property 551 def concatenable(self): 552 return False 553 554 @property 555 def scale(self) -> float: 556 return 1.0 557 558 @property 559 def unit(self): 560 return None
563class IndexAxisBase(AxisBase): 564 implemented_type: ClassVar[Literal["index"]] = "index" 565 if TYPE_CHECKING: 566 type: Literal["index"] = "index" 567 else: 568 type: Literal["index"] 569 570 id: NonBatchAxisId = AxisId("index") 571 572 @property 573 def scale(self) -> float: 574 return 1.0 575 576 @property 577 def unit(self): 578 return None
601class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 602 concatenable: bool = False 603 """If a model has a `concatenable` input axis, it can be processed blockwise, 604 splitting a longer sample axis into blocks matching its input tensor description. 605 Output axes are concatenable if they have a `SizeReference` to a concatenable 606 input axis. 607 """
If a model has a `concatenable` input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
610class IndexOutputAxis(IndexAxisBase): 611 size: Annotated[ 612 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 613 Field( 614 examples=[ 615 10, 616 SizeReference( 617 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 618 ).model_dump(mode="json"), 619 ] 620 ), 621 ] 622 """The size/length of this axis can be specified as 623 - fixed integer 624 - reference to another axis with an optional offset (`SizeReference`) 625 - data dependent size using `DataDependentSize` (size is only known after model inference) 626 """
The size/length of this axis can be specified as
- a fixed integer
- a reference to another axis with an optional offset (`SizeReference`)
- a data dependent size using `DataDependentSize` (the size is only known after model inference)
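For instance, an output axis indexing detected objects, whose length is only known after inference, might be described like this sketch (the axis id is made up):

>>> detections = IndexOutputAxis(
...     id=AxisId("object"),
...     size=DataDependentSize(min=1),
... )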
629class TimeAxisBase(AxisBase): 630 implemented_type: ClassVar[Literal["time"]] = "time" 631 if TYPE_CHECKING: 632 type: Literal["time"] = "time" 633 else: 634 type: Literal["time"] 635 636 id: NonBatchAxisId = AxisId("time") 637 unit: Optional[TimeUnit] = None 638 scale: Annotated[float, Gt(0)] = 1.0
641class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 642 concatenable: bool = False 643 """If a model has a `concatenable` input axis, it can be processed blockwise, 644 splitting a longer sample axis into blocks matching its input tensor description. 645 Output axes are concatenable if they have a `SizeReference` to a concatenable 646 input axis. 647 """
If a model has a `concatenable` input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
650class SpaceAxisBase(AxisBase): 651 implemented_type: ClassVar[Literal["space"]] = "space" 652 if TYPE_CHECKING: 653 type: Literal["space"] = "space" 654 else: 655 type: Literal["space"] 656 657 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 658 unit: Optional[SpaceUnit] = None 659 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
662class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 663 concatenable: bool = False 664 """If a model has a `concatenable` input axis, it can be processed blockwise, 665 splitting a longer sample axis into blocks matching its input tensor description. 666 Output axes are concatenable if they have a `SizeReference` to a concatenable 667 input axis. 668 """
If a model has a `concatenable` input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
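A sketch of a concatenable, parameterized space input axis (the values are illustrative):

>>> x = SpaceInputAxis(
...     id=AxisId("x"),
...     size=ParameterizedSize(min=64, step=16),
...     concatenable=True,
... )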
790class NominalOrOrdinalDataDescr(Node): 791 values: TVs 792 """A fixed set of nominal or an ascending sequence of ordinal values. 793 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 794 String `values` are interpreted as labels for tensor values 0, ..., N. 795 Note: as YAML 1.2 does not natively support a "set" datatype, 796 nominal values should be given as a sequence (aka list/array) as well. 797 """ 798 799 type: Annotated[ 800 NominalOrOrdinalDType, 801 Field( 802 examples=[ 803 "float32", 804 "uint8", 805 "uint16", 806 "int64", 807 "bool", 808 ], 809 ), 810 ] = "uint8" 811 812 @model_validator(mode="after") 813 def _validate_values_match_type( 814 self, 815 ) -> Self: 816 incompatible: List[Any] = [] 817 for v in self.values: 818 if self.type == "bool": 819 if not isinstance(v, bool): 820 incompatible.append(v) 821 elif self.type in DTYPE_LIMITS: 822 if ( 823 isinstance(v, (int, float)) 824 and ( 825 v < DTYPE_LIMITS[self.type].min 826 or v > DTYPE_LIMITS[self.type].max 827 ) 828 or (isinstance(v, str) and "uint" not in self.type) 829 or (isinstance(v, float) and "int" in self.type) 830 ): 831 incompatible.append(v) 832 else: 833 incompatible.append(v) 834 835 if len(incompatible) == 5: 836 incompatible.append("...") 837 break 838 839 if incompatible: 840 raise ValueError( 841 f"data type '{self.type}' incompatible with values {incompatible}" 842 ) 843 844 return self 845 846 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 847 848 @property 849 def range(self): 850 if isinstance(self.values[0], str): 851 return 0, len(self.values) - 1 852 else: 853 return min(self.values), max(self.values)
A fixed set of nominal values or an ascending sequence of ordinal values.
In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
String `values` are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype, nominal values should be given as a sequence (aka list/array) as well.
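A minimal sketch describing a label tensor with three classes (the label names are made up); the range property reports the corresponding value range:

>>> labels = NominalOrOrdinalDataDescr(
...     values=["background", "cell", "membrane"],
...     type="uint8",
... )
>>> labels.range
(0, 2)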
870class IntervalOrRatioDataDescr(Node): 871 type: Annotated[ # TODO: rename to dtype 872 IntervalOrRatioDType, 873 Field( 874 examples=["float32", "float64", "uint8", "uint16"], 875 ), 876 ] = "float32" 877 range: Tuple[Optional[float], Optional[float]] = ( 878 None, 879 None, 880 ) 881 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 882 `None` corresponds to min/max of what can be expressed by **type**.""" 883 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 884 scale: float = 1.0 885 """Scale for data on an interval (or ratio) scale.""" 886 offset: Optional[float] = None 887 """Offset for data on a ratio scale.""" 888 889 @model_validator(mode="before") 890 def _replace_inf(cls, data: Any): 891 if is_dict(data): 892 if "range" in data and is_sequence(data["range"]): 893 forbidden = ( 894 "inf", 895 "-inf", 896 ".inf", 897 "-.inf", 898 float("inf"), 899 float("-inf"), 900 ) 901 if any(v in forbidden for v in data["range"]): 902 issue_warning("replaced 'inf' value", value=data["range"]) 903 904 data["range"] = tuple( 905 (None if v in forbidden else v) for v in data["range"] 906 ) 907 908 return data
Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
`None` corresponds to the min/max of what can be expressed by **type**.
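For example, a float tensor expected to lie in [0, 1] could be described like this sketch:

>>> data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))
>>> data.range
(0.0, 1.0)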
Processing base class.
918class BinarizeKwargs(ProcessingKwargs): 919 """key word arguments for `BinarizeDescr`""" 920 921 threshold: float 922 """The fixed threshold"""
Keyword arguments for `BinarizeDescr`.
925class BinarizeAlongAxisKwargs(ProcessingKwargs): 926 """key word arguments for `BinarizeDescr`""" 927 928 threshold: NotEmpty[List[float]] 929 """The fixed threshold values along `axis`""" 930 931 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 932 """The `threshold` axis"""
Keyword arguments for `BinarizeDescr`.
The `threshold` axis.
935class BinarizeDescr(ProcessingDescrBase): 936 """Binarize the tensor with a fixed threshold. 937 938 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 939 will be set to one, values below the threshold to zero. 940 941 Examples: 942 - in YAML 943 ```yaml 944 postprocessing: 945 - id: binarize 946 kwargs: 947 axis: 'channel' 948 threshold: [0.25, 0.5, 0.75] 949 ``` 950 - in Python: 951 >>> postprocessing = [BinarizeDescr( 952 ... kwargs=BinarizeAlongAxisKwargs( 953 ... axis=AxisId('channel'), 954 ... threshold=[0.25, 0.5, 0.75], 955 ... ) 956 ... )] 957 """ 958 959 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 960 if TYPE_CHECKING: 961 id: Literal["binarize"] = "binarize" 962 else: 963 id: Literal["binarize"] 964 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above `BinarizeKwargs.threshold` / `BinarizeAlongAxisKwargs.threshold` will be set to one, values below the threshold to zero.
Examples:
- in YAML
postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
- in Python:
>>> postprocessing = [BinarizeDescr(
...     kwargs=BinarizeAlongAxisKwargs(
...         axis=AxisId('channel'),
...         threshold=[0.25, 0.5, 0.75],
...     )
... )]
967class ClipDescr(ProcessingDescrBase): 968 """Set tensor values below min to min and above max to max. 969 970 See `ScaleRangeDescr` for examples. 971 """ 972 973 implemented_id: ClassVar[Literal["clip"]] = "clip" 974 if TYPE_CHECKING: 975 id: Literal["clip"] = "clip" 976 else: 977 id: Literal["clip"] 978 979 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See `ScaleRangeDescr` for examples.
982class EnsureDtypeKwargs(ProcessingKwargs): 983 """key word arguments for `EnsureDtypeDescr`""" 984 985 dtype: Literal[ 986 "float32", 987 "float64", 988 "uint8", 989 "int8", 990 "uint16", 991 "int16", 992 "uint32", 993 "int32", 994 "uint64", 995 "int64", 996 "bool", 997 ]
Keyword arguments for `EnsureDtypeDescr`.
1000class EnsureDtypeDescr(ProcessingDescrBase): 1001 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 1002 1003 This can for example be used to ensure the inner neural network model gets a 1004 different input tensor data type than the fully described bioimage.io model does. 1005 1006 Examples: 1007 The described bioimage.io model (incl. preprocessing) accepts any 1008 float32-compatible tensor, normalizes it with percentiles and clipping and then 1009 casts it to uint8, which is what the neural network in this example expects. 1010 - in YAML 1011 ```yaml 1012 inputs: 1013 - data: 1014 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1015 preprocessing: 1016 - id: scale_range 1017 kwargs: 1018 axes: ['y', 'x'] 1019 max_percentile: 99.8 1020 min_percentile: 5.0 1021 - id: clip 1022 kwargs: 1023 min: 0.0 1024 max: 1.0 1025 - id: ensure_dtype # the neural network of the model requires uint8 1026 kwargs: 1027 dtype: uint8 1028 ``` 1029 - in Python: 1030 >>> preprocessing = [ 1031 ... ScaleRangeDescr( 1032 ... kwargs=ScaleRangeKwargs( 1033 ... axes= (AxisId('y'), AxisId('x')), 1034 ... max_percentile= 99.8, 1035 ... min_percentile= 5.0, 1036 ... ) 1037 ... ), 1038 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1039 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1040 ... ] 1041 """ 1042 1043 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1044 if TYPE_CHECKING: 1045 id: Literal["ensure_dtype"] = "ensure_dtype" 1046 else: 1047 id: Literal["ensure_dtype"] 1048 1049 kwargs: EnsureDtypeKwargs
Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
inputs:
- data:
    type: float32  # described bioimage.io model is compatible with any float32 input tensor
  preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
  - id: ensure_dtype  # the neural network of the model requires uint8
    kwargs:
      dtype: uint8
- in Python:
>>> preprocessing = [
...     ScaleRangeDescr(
...         kwargs=ScaleRangeKwargs(
...             axes=(AxisId('y'), AxisId('x')),
...             max_percentile=99.8,
...             min_percentile=5.0,
...         )
...     ),
...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
... ]
1052class ScaleLinearKwargs(ProcessingKwargs): 1053 """Key word arguments for `ScaleLinearDescr`""" 1054 1055 gain: float = 1.0 1056 """multiplicative factor""" 1057 1058 offset: float = 0.0 1059 """additive term""" 1060 1061 @model_validator(mode="after") 1062 def _validate(self) -> Self: 1063 if self.gain == 1.0 and self.offset == 0.0: 1064 raise ValueError( 1065 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1066 + " != 0.0." 1067 ) 1068 1069 return self
Keyword arguments for `ScaleLinearDescr`.
1072class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1073 """Key word arguments for `ScaleLinearDescr`""" 1074 1075 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1076 """The axis of gain and offset values.""" 1077 1078 gain: Union[float, NotEmpty[List[float]]] = 1.0 1079 """multiplicative factor""" 1080 1081 offset: Union[float, NotEmpty[List[float]]] = 0.0 1082 """additive term""" 1083 1084 @model_validator(mode="after") 1085 def _validate(self) -> Self: 1086 if isinstance(self.gain, list): 1087 if isinstance(self.offset, list): 1088 if len(self.gain) != len(self.offset): 1089 raise ValueError( 1090 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1091 ) 1092 else: 1093 self.offset = [float(self.offset)] * len(self.gain) 1094 elif isinstance(self.offset, list): 1095 self.gain = [float(self.gain)] * len(self.offset) 1096 else: 1097 raise ValueError( 1098 "Do not specify an `axis` for scalar gain and offset values." 1099 ) 1100 1101 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1102 raise ValueError( 1103 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1104 + " != 0.0." 1105 ) 1106 1107 return self
Keyword arguments for `ScaleLinearDescr`.
The axis of gain and offset values.
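A minimal sketch of the broadcasting behaviour enforced by the validator above: a scalar offset is expanded to match a per-channel gain list.

>>> kwargs = ScaleLinearAlongAxisKwargs(
...     axis=AxisId("channel"),
...     gain=[1.0, 2.0, 3.0],
...     offset=0.5,
... )
>>> kwargs.offset
[0.5, 0.5, 0.5]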
1110class ScaleLinearDescr(ProcessingDescrBase): 1111 """Fixed linear scaling. 1112 1113 Examples: 1114 1. Scale with scalar gain and offset 1115 - in YAML 1116 ```yaml 1117 preprocessing: 1118 - id: scale_linear 1119 kwargs: 1120 gain: 2.0 1121 offset: 3.0 1122 ``` 1123 - in Python: 1124 >>> preprocessing = [ 1125 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1126 ... ] 1127 1128 2. Independent scaling along an axis 1129 - in YAML 1130 ```yaml 1131 preprocessing: 1132 - id: scale_linear 1133 kwargs: 1134 axis: 'channel' 1135 gain: [1.0, 2.0, 3.0] 1136 ``` 1137 - in Python: 1138 >>> preprocessing = [ 1139 ... ScaleLinearDescr( 1140 ... kwargs=ScaleLinearAlongAxisKwargs( 1141 ... axis=AxisId("channel"), 1142 ... gain=[1.0, 2.0, 3.0], 1143 ... ) 1144 ... ) 1145 ... ] 1146 1147 """ 1148 1149 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1150 if TYPE_CHECKING: 1151 id: Literal["scale_linear"] = "scale_linear" 1152 else: 1153 id: Literal["scale_linear"] 1154 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
Scale with scalar gain and offset
in YAML
preprocessing:
  - id: scale_linear
    kwargs:
      gain: 2.0
      offset: 3.0
in Python:
>>> preprocessing = [
...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
... ]
Independent scaling along an axis
in YAML
preprocessing:
  - id: scale_linear
    kwargs:
      axis: 'channel'
      gain: [1.0, 2.0, 3.0]
in Python:
>>> preprocessing = [
...     ScaleLinearDescr(
...         kwargs=ScaleLinearAlongAxisKwargs(
...             axis=AxisId("channel"),
...             gain=[1.0, 2.0, 3.0],
...         )
...     )
... ]
1157class SigmoidDescr(ProcessingDescrBase): 1158 """The logistic sigmoid function, a.k.a. expit function. 1159 1160 Examples: 1161 - in YAML 1162 ```yaml 1163 postprocessing: 1164 - id: sigmoid 1165 ``` 1166 - in Python: 1167 >>> postprocessing = [SigmoidDescr()] 1168 """ 1169 1170 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1171 if TYPE_CHECKING: 1172 id: Literal["sigmoid"] = "sigmoid" 1173 else: 1174 id: Literal["sigmoid"] 1175 1176 @property 1177 def kwargs(self) -> ProcessingKwargs: 1178 """empty kwargs""" 1179 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
postprocessing:
- id: sigmoid
- in Python:
>>> postprocessing = [SigmoidDescr()]
1176 @property 1177 def kwargs(self) -> ProcessingKwargs: 1178 """empty kwargs""" 1179 return ProcessingKwargs()
empty kwargs
1182class SoftmaxKwargs(ProcessingKwargs): 1183 """key word arguments for `SoftmaxDescr`""" 1184 1185 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 1186 """The axis to apply the softmax function along. 1187 Note: 1188 Defaults to 'channel' axis 1189 (which may not exist, in which case 1190 a different axis id has to be specified). 1191 """
Keyword arguments for `SoftmaxDescr`.
The axis to apply the softmax function along.
Note:
Defaults to 'channel' axis (which may not exist, in which case a different axis id has to be specified).
1194class SoftmaxDescr(ProcessingDescrBase): 1195 """The softmax function. 1196 1197 Examples: 1198 - in YAML 1199 ```yaml 1200 postprocessing: 1201 - id: softmax 1202 kwargs: 1203 axis: channel 1204 ``` 1205 - in Python: 1206 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 1207 """ 1208 1209 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 1210 if TYPE_CHECKING: 1211 id: Literal["softmax"] = "softmax" 1212 else: 1213 id: Literal["softmax"] 1214 1215 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
The softmax function.
Examples:
- in YAML
postprocessing:
  - id: softmax
    kwargs:
      axis: channel
- in Python:
>>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1219 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1220 1221 mean: float 1222 """The mean value to normalize with.""" 1223 1224 std: Annotated[float, Ge(1e-6)] 1225 """The standard deviation value to normalize with."""
Keyword arguments for `FixedZeroMeanUnitVarianceDescr`.
1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1229 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1230 1231 mean: NotEmpty[List[float]] 1232 """The mean value(s) to normalize with.""" 1233 1234 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1235 """The standard deviation value(s) to normalize with. 1236 Size must match `mean` values.""" 1237 1238 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1239 """The axis of the mean/std values to normalize each entry along that dimension 1240 separately.""" 1241 1242 @model_validator(mode="after") 1243 def _mean_and_std_match(self) -> Self: 1244 if len(self.mean) != len(self.std): 1245 raise ValueError( 1246 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1247 + " must match." 1248 ) 1249 1250 return self
Keyword arguments for `FixedZeroMeanUnitVarianceDescr`.
The standard deviation value(s) to normalize with. Size must match the `mean` values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1254 """Subtract a given mean and divide by the standard deviation. 1255 1256 Normalize with fixed, precomputed values for 1257 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1258 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1259 axes. 1260 1261 Examples: 1262 1. scalar value for whole tensor 1263 - in YAML 1264 ```yaml 1265 preprocessing: 1266 - id: fixed_zero_mean_unit_variance 1267 kwargs: 1268 mean: 103.5 1269 std: 13.7 1270 ``` 1271 - in Python 1272 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1273 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1274 ... )] 1275 1276 2. independently along an axis 1277 - in YAML 1278 ```yaml 1279 preprocessing: 1280 - id: fixed_zero_mean_unit_variance 1281 kwargs: 1282 axis: channel 1283 mean: [101.5, 102.5, 103.5] 1284 std: [11.7, 12.7, 13.7] 1285 ``` 1286 - in Python 1287 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1288 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1289 ... axis=AxisId("channel"), 1290 ... mean=[101.5, 102.5, 103.5], 1291 ... std=[11.7, 12.7, 13.7], 1292 ... ) 1293 ... )] 1294 """ 1295 1296 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1297 "fixed_zero_mean_unit_variance" 1298 ) 1299 if TYPE_CHECKING: 1300 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1301 else: 1302 id: Literal["fixed_zero_mean_unit_variance"] 1303 1304 kwargs: Union[ 1305 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1306 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given axes.
Examples:
- scalar value for whole tensor
- in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      mean: 103.5
      std: 13.7
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
... )]
- independently along an axis
- in YAML
preprocessing:
  - id: fixed_zero_mean_unit_variance
    kwargs:
      axis: channel
      mean: [101.5, 102.5, 103.5]
      std: [11.7, 12.7, 13.7]
- in Python
>>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
...         axis=AxisId("channel"),
...         mean=[101.5, 102.5, 103.5],
...         std=[11.7, 12.7, 13.7],
...     )
... )]
1309class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1310 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1311 1312 axes: Annotated[ 1313 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1314 ] = None 1315 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1316 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1317 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1318 To normalize each sample independently leave out the 'batch' axis. 1319 Default: Scale all axes jointly.""" 1320 1321 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1322 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
Keyword arguments for `ZeroMeanUnitVarianceDescr`.
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize each sample independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.
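A minimal numpy sketch of what a consumer library might compute for axes=('batch', 'x', 'y') on a ('batch', 'channel', 'y', 'x') tensor; the axis indices below are tied to this assumed layout and are not part of the spec:

>>> import numpy as np
>>> tensor = np.random.rand(1, 3, 64, 64)  # ('batch', 'channel', 'y', 'x')
>>> reduce_axes = (0, 2, 3)  # 'batch', 'y', 'x' -> per-channel statistics
>>> mean = tensor.mean(axis=reduce_axes, keepdims=True)
>>> std = tensor.std(axis=reduce_axes, keepdims=True)
>>> out = (tensor - mean) / (std + 1e-6)
>>> out.shape
(1, 3, 64, 64)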
1325class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1326 """Subtract mean and divide by variance. 1327 1328 Examples: 1329 Subtract tensor mean and variance 1330 - in YAML 1331 ```yaml 1332 preprocessing: 1333 - id: zero_mean_unit_variance 1334 ``` 1335 - in Python 1336 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1337 """ 1338 1339 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1340 "zero_mean_unit_variance" 1341 ) 1342 if TYPE_CHECKING: 1343 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1344 else: 1345 id: Literal["zero_mean_unit_variance"] 1346 1347 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1348 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 1349 )
Subtract mean and divide by variance.
Examples:
Subtract tensor mean and variance
- in YAML
preprocessing:
  - id: zero_mean_unit_variance
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1352class ScaleRangeKwargs(ProcessingKwargs): 1353 """key word arguments for `ScaleRangeDescr` 1354 1355 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1356 this processing step normalizes data to the [0, 1] intervall. 1357 For other percentiles the normalized values will partially be outside the [0, 1] 1358 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1359 normalized values to a range. 1360 """ 1361 1362 axes: Annotated[ 1363 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1364 ] = None 1365 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1366 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1367 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1368 To normalize samples independently, leave out the "batch" axis. 1369 Default: Scale all axes jointly.""" 1370 1371 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1372 """The lower percentile used to determine the value to align with zero.""" 1373 1374 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1375 """The upper percentile used to determine the value to align with one. 1376 Has to be bigger than `min_percentile`. 1377 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1378 accepting percentiles specified in the range 0.0 to 1.0.""" 1379 1380 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1381 """Epsilon for numeric stability. 1382 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1383 with `v_lower,v_upper` values at the respective percentiles.""" 1384 1385 reference_tensor: Optional[TensorId] = None 1386 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1387 For any tensor in `inputs` only input tensor references are allowed.""" 1388 1389 @field_validator("max_percentile", mode="after") 1390 @classmethod 1391 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1392 if (min_p := info.data["min_percentile"]) >= value: 1393 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1394 1395 return value
Keyword arguments for `ScaleRangeDescr`.
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1] interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than `min_percentile`.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
`out = (tensor - v_lower) / (v_upper - v_lower + eps)`; with `v_lower`, `v_upper` values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in `inputs` only input tensor references are allowed.
1389 @field_validator("max_percentile", mode="after") 1390 @classmethod 1391 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1392 if (min_p := info.data["min_percentile"]) >= value: 1393 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1394 1395 return value
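A minimal numpy sketch of the formula above for min_percentile=5.0 and max_percentile=99.8, again assuming a ('batch', 'channel', 'y', 'x') layout and reducing over ('batch', 'x', 'y'):

>>> import numpy as np
>>> tensor = np.random.rand(1, 3, 64, 64)
>>> v_lower = np.percentile(tensor, 5.0, axis=(0, 2, 3), keepdims=True)
>>> v_upper = np.percentile(tensor, 99.8, axis=(0, 2, 3), keepdims=True)
>>> out = (tensor - v_lower) / (v_upper - v_lower + 1e-6)
>>> out.shape
(1, 3, 64, 64)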
1398class ScaleRangeDescr(ProcessingDescrBase): 1399 """Scale with percentiles. 1400 1401 Examples: 1402 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1403 - in YAML 1404 ```yaml 1405 preprocessing: 1406 - id: scale_range 1407 kwargs: 1408 axes: ['y', 'x'] 1409 max_percentile: 99.8 1410 min_percentile: 5.0 1411 ``` 1412 - in Python 1413 >>> preprocessing = [ 1414 ... ScaleRangeDescr( 1415 ... kwargs=ScaleRangeKwargs( 1416 ... axes= (AxisId('y'), AxisId('x')), 1417 ... max_percentile= 99.8, 1418 ... min_percentile= 5.0, 1419 ... ) 1420 ... ), 1421 ... ClipDescr( 1422 ... kwargs=ClipKwargs( 1423 ... min=0.0, 1424 ... max=1.0, 1425 ... ) 1426 ... ), 1427 ... ] 1428 1429 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1430 - in YAML 1431 ```yaml 1432 preprocessing: 1433 - id: scale_range 1434 kwargs: 1435 axes: ['y', 'x'] 1436 max_percentile: 99.8 1437 min_percentile: 5.0 1438 - id: scale_range 1439 - id: clip 1440 kwargs: 1441 min: 0.0 1442 max: 1.0 1443 ``` 1444 - in Python 1445 >>> preprocessing = [ScaleRangeDescr( 1446 ... kwargs=ScaleRangeKwargs( 1447 ... axes= (AxisId('y'), AxisId('x')), 1448 ... max_percentile= 99.8, 1449 ... min_percentile= 5.0, 1450 ... ) 1451 ... )] 1452 1453 """ 1454 1455 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1456 if TYPE_CHECKING: 1457 id: Literal["scale_range"] = "scale_range" 1458 else: 1459 id: Literal["scale_range"] 1460 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
Scale with percentiles.
Examples:
- Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
- in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
- in Python
>>> preprocessing = [
... ScaleRangeDescr(
... kwargs=ScaleRangeKwargs(
... axes= (AxisId('y'), AxisId('x')),
... max_percentile= 99.8,
... min_percentile= 5.0,
... )
... ),
... ClipDescr(
... kwargs=ClipKwargs(
... min=0.0,
... max=1.0,
... )
... ),
... ]
- Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
- in YAML
preprocessing:
  - id: scale_range
    kwargs:
      axes: ['y', 'x']
      max_percentile: 99.8
      min_percentile: 5.0
  - id: clip
    kwargs:
      min: 0.0
      max: 1.0
- in Python
>>> preprocessing = [ScaleRangeDescr(
...     kwargs=ScaleRangeKwargs(
...         axes=(AxisId('y'), AxisId('x')),
...         max_percentile=99.8,
...         min_percentile=5.0,
...     )
... )]
1463class ScaleMeanVarianceKwargs(ProcessingKwargs): 1464 """key word arguments for `ScaleMeanVarianceKwargs`""" 1465 1466 reference_tensor: TensorId 1467 """Name of tensor to match.""" 1468 1469 axes: Annotated[ 1470 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1471 ] = None 1472 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1473 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1474 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1475 To normalize samples independently, leave out the 'batch' axis. 1476 Default: Scale all axes jointly.""" 1477 1478 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1479 """Epsilon for numeric stability: 1480 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
Keyword arguments for `ScaleMeanVarianceDescr`.
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability: `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
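A minimal numpy sketch of this formula (tensor layout and reduction axes are assumptions for illustration):

>>> import numpy as np
>>> eps = 1e-6
>>> tensor = np.random.rand(1, 3, 64, 64)
>>> reference = np.random.rand(1, 3, 64, 64)
>>> ax = (0, 2, 3)
>>> mean, std = tensor.mean(ax, keepdims=True), tensor.std(ax, keepdims=True)
>>> ref_mean, ref_std = reference.mean(ax, keepdims=True), reference.std(ax, keepdims=True)
>>> out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean
>>> out.shape
(1, 3, 64, 64)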
1483class ScaleMeanVarianceDescr(ProcessingDescrBase): 1484 """Scale a tensor's data distribution to match another tensor's mean/std. 1485 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1486 """ 1487 1488 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1489 if TYPE_CHECKING: 1490 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1491 else: 1492 id: Literal["scale_mean_variance"] 1493 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
1529class TensorDescrBase(Node, Generic[IO_AxisT]): 1530 id: TensorId 1531 """Tensor id. No duplicates are allowed.""" 1532 1533 description: Annotated[str, MaxLen(128)] = "" 1534 """free text description""" 1535 1536 axes: NotEmpty[Sequence[IO_AxisT]] 1537 """tensor axes""" 1538 1539 @property 1540 def shape(self): 1541 return tuple(a.size for a in self.axes) 1542 1543 @field_validator("axes", mode="after", check_fields=False) 1544 @classmethod 1545 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1546 batch_axes = [a for a in axes if a.type == "batch"] 1547 if len(batch_axes) > 1: 1548 raise ValueError( 1549 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1550 ) 1551 1552 seen_ids: Set[AxisId] = set() 1553 duplicate_axes_ids: Set[AxisId] = set() 1554 for a in axes: 1555 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1556 1557 if duplicate_axes_ids: 1558 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1559 1560 return axes 1561 1562 test_tensor: FAIR[Optional[FileDescr_]] = None 1563 """An example tensor to use for testing. 1564 Using the model with the test input tensors is expected to yield the test output tensors. 1565 Each test tensor has be a an ndarray in the 1566 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1567 The file extension must be '.npy'.""" 1568 1569 sample_tensor: FAIR[Optional[FileDescr_]] = None 1570 """A sample tensor to illustrate a possible input/output for the model, 1571 The sample image primarily serves to inform a human user about an example use case 1572 and is typically stored as .hdf5, .png or .tiff. 1573 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1574 (numpy's `.npy` format is not supported). 1575 The image dimensionality has to match the number of axes specified in this tensor description. 1576 """ 1577 1578 @model_validator(mode="after") 1579 def _validate_sample_tensor(self) -> Self: 1580 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1581 return self 1582 1583 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1584 tensor: NDArray[Any] = imread( 1585 reader.read(), 1586 extension=PurePosixPath(reader.original_file_name).suffix, 1587 ) 1588 n_dims = len(tensor.squeeze().shape) 1589 n_dims_min = n_dims_max = len(self.axes) 1590 1591 for a in self.axes: 1592 if isinstance(a, BatchAxis): 1593 n_dims_min -= 1 1594 elif isinstance(a.size, int): 1595 if a.size == 1: 1596 n_dims_min -= 1 1597 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1598 if a.size.min == 1: 1599 n_dims_min -= 1 1600 elif isinstance(a.size, SizeReference): 1601 if a.size.offset < 2: 1602 # size reference may result in singleton axis 1603 n_dims_min -= 1 1604 else: 1605 assert_never(a.size) 1606 1607 n_dims_min = max(0, n_dims_min) 1608 if n_dims < n_dims_min or n_dims > n_dims_max: 1609 raise ValueError( 1610 f"Expected sample tensor to have {n_dims_min} to" 1611 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1612 ) 1613 1614 return self 1615 1616 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1617 IntervalOrRatioDataDescr() 1618 ) 1619 """Description of the tensor's data values, optionally per channel. 
1620 If specified per channel, the data `type` needs to match across channels.""" 1621 1622 @property 1623 def dtype( 1624 self, 1625 ) -> Literal[ 1626 "float32", 1627 "float64", 1628 "uint8", 1629 "int8", 1630 "uint16", 1631 "int16", 1632 "uint32", 1633 "int32", 1634 "uint64", 1635 "int64", 1636 "bool", 1637 ]: 1638 """dtype as specified under `data.type` or `data[i].type`""" 1639 if isinstance(self.data, collections.abc.Sequence): 1640 return self.data[0].type 1641 else: 1642 return self.data.type 1643 1644 @field_validator("data", mode="after") 1645 @classmethod 1646 def _check_data_type_across_channels( 1647 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1648 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1649 if not isinstance(value, list): 1650 return value 1651 1652 dtypes = {t.type for t in value} 1653 if len(dtypes) > 1: 1654 raise ValueError( 1655 "Tensor data descriptions per channel need to agree in their data" 1656 + f" `type`, but found {dtypes}." 1657 ) 1658 1659 return value 1660 1661 @model_validator(mode="after") 1662 def _check_data_matches_channelaxis(self) -> Self: 1663 if not isinstance(self.data, (list, tuple)): 1664 return self 1665 1666 for a in self.axes: 1667 if isinstance(a, ChannelAxis): 1668 size = a.size 1669 assert isinstance(size, int) 1670 break 1671 else: 1672 return self 1673 1674 if len(self.data) != size: 1675 raise ValueError( 1676 f"Got tensor data descriptions for {len(self.data)} channels, but" 1677 + f" '{a.id}' axis has size {size}." 1678 ) 1679 1680 return self 1681 1682 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1683 if len(array.shape) != len(self.axes): 1684 raise ValueError( 1685 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1686 + f" incompatible with {len(self.axes)} axes." 1687 ) 1688 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library (numpy's `.npy` format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data `type` needs to match across channels.
1622 @property 1623 def dtype( 1624 self, 1625 ) -> Literal[ 1626 "float32", 1627 "float64", 1628 "uint8", 1629 "int8", 1630 "uint16", 1631 "int16", 1632 "uint32", 1633 "int32", 1634 "uint64", 1635 "int64", 1636 "bool", 1637 ]: 1638 """dtype as specified under `data.type` or `data[i].type`""" 1639 if isinstance(self.data, collections.abc.Sequence): 1640 return self.data[0].type 1641 else: 1642 return self.data.type
dtype as specified under data.type or data[i].type
1691class InputTensorDescr(TensorDescrBase[InputAxis]): 1692 id: TensorId = TensorId("input") 1693 """Input tensor id. 1694 No duplicates are allowed across all inputs and outputs.""" 1695 1696 optional: bool = False 1697 """indicates that this tensor may be `None`""" 1698 1699 preprocessing: List[PreprocessingDescr] = Field( 1700 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1701 ) 1702 1703 """Description of how this input should be preprocessed. 1704 1705 notes: 1706 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1707 to ensure an input tensor's data type matches the input tensor's data description. 1708 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1709 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1710 changing the data type. 1711 """ 1712 1713 @model_validator(mode="after") 1714 def _validate_preprocessing_kwargs(self) -> Self: 1715 axes_ids = [a.id for a in self.axes] 1716 for p in self.preprocessing: 1717 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1718 if kwargs_axes is None: 1719 continue 1720 1721 if not isinstance(kwargs_axes, collections.abc.Sequence): 1722 raise ValueError( 1723 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1724 ) 1725 1726 if any(a not in axes_ids for a in kwargs_axes): 1727 raise ValueError( 1728 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1729 ) 1730 1731 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1732 dtype = self.data.type 1733 else: 1734 dtype = self.data[0].type 1735 1736 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1737 if not self.preprocessing or not isinstance( 1738 self.preprocessing[0], EnsureDtypeDescr 1739 ): 1740 self.preprocessing.insert( 1741 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1742 ) 1743 1744 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1745 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1746 self.preprocessing.append( 1747 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1748 ) 1749 1750 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
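A minimal, hypothetical sketch of this normalization (the exact constructor arguments are assumptions; the axis, data and processing classes are the ones defined in this module): constructing an input tensor without explicit preprocessing results in a single 'ensure_dtype' step.

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    EnsureDtypeDescr,
    Identifier,
    InputTensorDescr,
    IntervalOrRatioDataDescr,
    ParameterizedSize,
    SpaceInputAxis,
    TensorId,
)

ipt = InputTensorDescr(
    id=TensorId("raw"),
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("c0")]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16), scale=1.0),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16), scale=1.0),
    ],
    data=IntervalOrRatioDataDescr(type="float32"),
    preprocessing=[],  # no explicit steps given
)
# the validator has inserted an 'ensure_dtype' step matching the data description
assert isinstance(ipt.preprocessing[0], EnsureDtypeDescr)
```

An explicit trailing 'binarize' step would be left untouched, since it already fixes the output dtype.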
1753def convert_axes( 1754 axes: str, 1755 *, 1756 shape: Union[ 1757 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1758 ], 1759 tensor_type: Literal["input", "output"], 1760 halo: Optional[Sequence[int]], 1761 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1762): 1763 ret: List[AnyAxis] = [] 1764 for i, a in enumerate(axes): 1765 axis_type = _AXIS_TYPE_MAP.get(a, a) 1766 if axis_type == "batch": 1767 ret.append(BatchAxis()) 1768 continue 1769 1770 scale = 1.0 1771 if isinstance(shape, _ParameterizedInputShape_v0_4): 1772 if shape.step[i] == 0: 1773 size = shape.min[i] 1774 else: 1775 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1776 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1777 ref_t = str(shape.reference_tensor) 1778 if ref_t.count(".") == 1: 1779 t_id, orig_a_id = ref_t.split(".") 1780 else: 1781 t_id = ref_t 1782 orig_a_id = a 1783 1784 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1785 if not (orig_scale := shape.scale[i]): 1786 # old way to insert a new axis dimension 1787 size = int(2 * shape.offset[i]) 1788 else: 1789 scale = 1 / orig_scale 1790 if axis_type in ("channel", "index"): 1791 # these axes no longer have a scale 1792 offset_from_scale = orig_scale * size_refs.get( 1793 _TensorName_v0_4(t_id), {} 1794 ).get(orig_a_id, 0) 1795 else: 1796 offset_from_scale = 0 1797 size = SizeReference( 1798 tensor_id=TensorId(t_id), 1799 axis_id=AxisId(a_id), 1800 offset=int(offset_from_scale + 2 * shape.offset[i]), 1801 ) 1802 else: 1803 size = shape[i] 1804 1805 if axis_type == "time": 1806 if tensor_type == "input": 1807 ret.append(TimeInputAxis(size=size, scale=scale)) 1808 else: 1809 assert not isinstance(size, ParameterizedSize) 1810 if halo is None: 1811 ret.append(TimeOutputAxis(size=size, scale=scale)) 1812 else: 1813 assert not isinstance(size, int) 1814 ret.append( 1815 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1816 ) 1817 1818 elif axis_type == "index": 1819 if tensor_type == "input": 1820 ret.append(IndexInputAxis(size=size)) 1821 else: 1822 if isinstance(size, ParameterizedSize): 1823 size = DataDependentSize(min=size.min) 1824 1825 ret.append(IndexOutputAxis(size=size)) 1826 elif axis_type == "channel": 1827 assert not isinstance(size, ParameterizedSize) 1828 if isinstance(size, SizeReference): 1829 warnings.warn( 1830 "Conversion of channel size from an implicit output shape may be" 1831 + " wrong" 1832 ) 1833 ret.append( 1834 ChannelAxis( 1835 channel_names=[ 1836 Identifier(f"channel{i}") for i in range(size.offset) 1837 ] 1838 ) 1839 ) 1840 else: 1841 ret.append( 1842 ChannelAxis( 1843 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1844 ) 1845 ) 1846 elif axis_type == "space": 1847 if tensor_type == "input": 1848 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1849 else: 1850 assert not isinstance(size, ParameterizedSize) 1851 if halo is None or halo[i] == 0: 1852 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1853 elif isinstance(size, int): 1854 raise NotImplementedError( 1855 f"output axis with halo and fixed size (here {size}) not allowed" 1856 ) 1857 else: 1858 ret.append( 1859 SpaceOutputAxisWithHalo( 1860 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1861 ) 1862 ) 1863 1864 return ret
2029class OutputTensorDescr(TensorDescrBase[OutputAxis]): 2030 id: TensorId = TensorId("output") 2031 """Output tensor id. 2032 No duplicates are allowed across all inputs and outputs.""" 2033 2034 postprocessing: List[PostprocessingDescr] = Field( 2035 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 2036 ) 2037 """Description of how this output should be postprocessed. 2038 2039 note: `postprocessing` always ends with an 'ensure_dtype' operation. 2040 If not given this is added to cast to this tensor's `data.type`. 2041 """ 2042 2043 @model_validator(mode="after") 2044 def _validate_postprocessing_kwargs(self) -> Self: 2045 axes_ids = [a.id for a in self.axes] 2046 for p in self.postprocessing: 2047 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2048 if kwargs_axes is None: 2049 continue 2050 2051 if not isinstance(kwargs_axes, collections.abc.Sequence): 2052 raise ValueError( 2053 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2054 ) 2055 2056 if any(a not in axes_ids for a in kwargs_axes): 2057 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2058 2059 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2060 dtype = self.data.type 2061 else: 2062 dtype = self.data[0].type 2063 2064 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2065 if not self.postprocessing or not isinstance( 2066 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2067 ): 2068 self.postprocessing.append( 2069 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2070 ) 2071 return self
Description of how this output should be postprocessed.
note: postprocessing always ends with an 'ensure_dtype' operation. If not given, this is added to cast to this tensor's data.type.
2121def validate_tensors( 2122 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 2123 tensor_origin: Literal[ 2124 "test_tensor" 2125 ], # for more precise error messages, e.g. 'test_tensor' 2126): 2127 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 2128 2129 def e_msg(d: TensorDescr): 2130 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2131 2132 for descr, array in tensors.values(): 2133 if array is None: 2134 axis_sizes = {a.id: None for a in descr.axes} 2135 else: 2136 try: 2137 axis_sizes = descr.get_axis_sizes_for_array(array) 2138 except ValueError as e: 2139 raise ValueError(f"{e_msg(descr)} {e}") 2140 2141 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 2142 2143 for descr, array in tensors.values(): 2144 if array is None: 2145 continue 2146 2147 if descr.dtype in ("float32", "float64"): 2148 invalid_test_tensor_dtype = array.dtype.name not in ( 2149 "float32", 2150 "float64", 2151 "uint8", 2152 "int8", 2153 "uint16", 2154 "int16", 2155 "uint32", 2156 "int32", 2157 "uint64", 2158 "int64", 2159 ) 2160 else: 2161 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2162 2163 if invalid_test_tensor_dtype: 2164 raise ValueError( 2165 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2166 + f" match described dtype '{descr.dtype}'" 2167 ) 2168 2169 if array.min() > -1e-4 and array.max() < 1e-4: 2170 raise ValueError( 2171 "Output values are too small for reliable testing." 2172 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2173 ) 2174 2175 for a in descr.axes: 2176 actual_size = all_tensor_axes[descr.id][a.id][1] 2177 if actual_size is None: 2178 continue 2179 2180 if a.size is None: 2181 continue 2182 2183 if isinstance(a.size, int): 2184 if actual_size != a.size: 2185 raise ValueError( 2186 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2187 + f"has incompatible size {actual_size}, expected {a.size}" 2188 ) 2189 elif isinstance(a.size, ParameterizedSize): 2190 _ = a.size.validate_size(actual_size) 2191 elif isinstance(a.size, DataDependentSize): 2192 _ = a.size.validate_size(actual_size) 2193 elif isinstance(a.size, SizeReference): 2194 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2195 if ref_tensor_axes is None: 2196 raise ValueError( 2197 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2198 + f" reference '{a.size.tensor_id}'" 2199 ) 2200 2201 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2202 if ref_axis is None or ref_size is None: 2203 raise ValueError( 2204 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2205 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2206 ) 2207 2208 if a.unit != ref_axis.unit: 2209 raise ValueError( 2210 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2211 + " axis and reference axis to have the same `unit`, but" 2212 + f" {a.unit}!={ref_axis.unit}" 2213 ) 2214 2215 if actual_size != ( 2216 expected_size := ( 2217 ref_size * ref_axis.scale / a.scale + a.size.offset 2218 ) 2219 ): 2220 raise ValueError( 2221 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2222 + f" {actual_size} invalid for referenced size {ref_size};" 2223 + f" expected {expected_size}" 2224 ) 2225 else: 2226 assert_never(a.size)
2246class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2247 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2248 """Architecture source file""" 2249 2250 @model_serializer(mode="wrap", when_used="unless-none") 2251 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2252 return package_file_descr_serializer(self, nxt, info)
A file description
Architecture source file
2255class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2256 import_from: str 2257 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2317class WeightsEntryDescrBase(FileDescr): 2318 type: ClassVar[WeightsFormat] 2319 weights_format_name: ClassVar[str] # human readable 2320 2321 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2322 """Source of the weights file.""" 2323 2324 authors: Optional[List[Author]] = None 2325 """Authors 2326 Either the person(s) that have trained this model resulting in the original weights file. 2327 (If this is the initial weights entry, i.e. it does not have a `parent`) 2328 Or the person(s) who have converted the weights to this weights format. 2329 (If this is a child weight, i.e. it has a `parent` field) 2330 """ 2331 2332 parent: Annotated[ 2333 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2334 ] = None 2335 """The source weights these weights were converted from. 2336 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2337 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2338 All weight entries except one (the initial set of weights resulting from training the model), 2339 need to have this field.""" 2340 2341 comment: str = "" 2342 """A comment about this weights entry, for example how these weights were created.""" 2343 2344 @model_validator(mode="after") 2345 def _validate(self) -> Self: 2346 if self.type == self.parent: 2347 raise ValueError("Weights entry can't be it's own parent.") 2348 2349 return self 2350 2351 @model_serializer(mode="wrap", when_used="unless-none") 2352 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2353 return package_file_descr_serializer(self, nxt, info)
A file description
Source of the weights file.
The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
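For illustration, a weights section following this rule might look like the sketch below, written as a plain dict rather than validated objects; the values are hypothetical and several required fields (e.g. architecture for pytorch_state_dict) are omitted for brevity.

```python
weights = {
    "pytorch_state_dict": {
        "source": "weights.pt",
        # no `parent`: this entry is the original set of trained weights
    },
    "torchscript": {
        "source": "weights_torchscript.pt",
        "pytorch_version": "2.1",
        "parent": "pytorch_state_dict",  # converted from the original weights
    },
}
```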
2356class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2357 type = "keras_hdf5" 2358 weights_format_name: ClassVar[str] = "Keras HDF5" 2359 tensorflow_version: Version 2360 """TensorFlow version used to create these weights."""
A file description
TensorFlow version used to create these weights.
2363class OnnxWeightsDescr(WeightsEntryDescrBase): 2364 type = "onnx" 2365 weights_format_name: ClassVar[str] = "ONNX" 2366 opset_version: Annotated[int, Ge(7)] 2367 """ONNX opset version"""
A file description
2370class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2371 type = "pytorch_state_dict" 2372 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2373 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2374 pytorch_version: Version 2375 """Version of the PyTorch library used. 2376 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2377 """ 2378 dependencies: Optional[FileDescr_dependencies] = None 2379 """Custom depencies beyond pytorch described in a Conda environment file. 2380 Allows to specify custom dependencies, see conda docs: 2381 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2382 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2383 2384 The conda environment file should include pytorch and any version pinning has to be compatible with 2385 **pytorch_version**. 2386 """
A file description
Version of the PyTorch library used. If architecture.dependencies is specified, it has to include pytorch and any version pinning has to be compatible.
Custom dependencies beyond pytorch, described in a Conda environment file. Allows specifying custom dependencies; see the conda docs on exporting an environment file across platforms and on creating an environment file manually.
The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.
2389class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2390 type = "tensorflow_js" 2391 weights_format_name: ClassVar[str] = "Tensorflow.js" 2392 tensorflow_version: Version 2393 """Version of the TensorFlow library used.""" 2394 2395 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2396 """The multi-file weights. 2397 All required files/folders should be a zip archive."""
A file description
Version of the TensorFlow library used.
The multi-file weights. All required files/folders should be a zip archive.
2400class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2401 type = "tensorflow_saved_model_bundle" 2402 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2403 tensorflow_version: Version 2404 """Version of the TensorFlow library used.""" 2405 2406 dependencies: Optional[FileDescr_dependencies] = None 2407 """Custom dependencies beyond tensorflow. 2408 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2409 2410 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2411 """The multi-file weights. 2412 All required files/folders should be a zip archive."""
A file description
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.
The multi-file weights. All required files/folders should be a zip archive.
2415class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2416 type = "torchscript" 2417 weights_format_name: ClassVar[str] = "TorchScript" 2418 pytorch_version: Version 2419 """Version of the PyTorch library used."""
A file description
Version of the PyTorch library used.
2422class WeightsDescr(Node): 2423 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2424 onnx: Optional[OnnxWeightsDescr] = None 2425 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2426 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2427 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2428 None 2429 ) 2430 torchscript: Optional[TorchscriptWeightsDescr] = None 2431 2432 @model_validator(mode="after") 2433 def check_entries(self) -> Self: 2434 entries = {wtype for wtype, entry in self if entry is not None} 2435 2436 if not entries: 2437 raise ValueError("Missing weights entry") 2438 2439 entries_wo_parent = { 2440 wtype 2441 for wtype, entry in self 2442 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2443 } 2444 if len(entries_wo_parent) != 1: 2445 issue_warning( 2446 "Exactly one weights entry may not specify the `parent` field (got" 2447 + " {value}). That entry is considered the original set of model weights." 2448 + " Other weight formats are created through conversion of the orignal or" 2449 + " already converted weights. They have to reference the weights format" 2450 + " they were converted from as their `parent`.", 2451 value=len(entries_wo_parent), 2452 field="weights", 2453 ) 2454 2455 for wtype, entry in self: 2456 if entry is None: 2457 continue 2458 2459 assert hasattr(entry, "type") 2460 assert hasattr(entry, "parent") 2461 assert wtype == entry.type 2462 if ( 2463 entry.parent is not None and entry.parent not in entries 2464 ): # self reference checked for `parent` field 2465 raise ValueError( 2466 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2467 + f" formats: {entries}" 2468 ) 2469 2470 return self 2471 2472 def __getitem__( 2473 self, 2474 key: Literal[ 2475 "keras_hdf5", 2476 "onnx", 2477 "pytorch_state_dict", 2478 "tensorflow_js", 2479 "tensorflow_saved_model_bundle", 2480 "torchscript", 2481 ], 2482 ): 2483 if key == "keras_hdf5": 2484 ret = self.keras_hdf5 2485 elif key == "onnx": 2486 ret = self.onnx 2487 elif key == "pytorch_state_dict": 2488 ret = self.pytorch_state_dict 2489 elif key == "tensorflow_js": 2490 ret = self.tensorflow_js 2491 elif key == "tensorflow_saved_model_bundle": 2492 ret = self.tensorflow_saved_model_bundle 2493 elif key == "torchscript": 2494 ret = self.torchscript 2495 else: 2496 raise KeyError(key) 2497 2498 if ret is None: 2499 raise KeyError(key) 2500 2501 return ret 2502 2503 @property 2504 def available_formats(self): 2505 return { 2506 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2507 **({} if self.onnx is None else {"onnx": self.onnx}), 2508 **( 2509 {} 2510 if self.pytorch_state_dict is None 2511 else {"pytorch_state_dict": self.pytorch_state_dict} 2512 ), 2513 **( 2514 {} 2515 if self.tensorflow_js is None 2516 else {"tensorflow_js": self.tensorflow_js} 2517 ), 2518 **( 2519 {} 2520 if self.tensorflow_saved_model_bundle is None 2521 else { 2522 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2523 } 2524 ), 2525 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2526 } 2527 2528 @property 2529 def missing_formats(self): 2530 return { 2531 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2532 }
2539class LinkedModel(LinkedResourceBase): 2540 """Reference to a bioimage.io model.""" 2541 2542 id: ModelId 2543 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2565class ReproducibilityTolerance(Node, extra="allow"): 2566 """Describes what small numerical differences -- if any -- may be tolerated 2567 in the generated output when executing in different environments. 2568 2569 A tensor element *output* is considered mismatched to the **test_tensor** if 2570 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2571 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2572 2573 Motivation: 2574 For testing we can request the respective deep learning frameworks to be as 2575 reproducible as possible by setting seeds and chosing deterministic algorithms, 2576 but differences in operating systems, available hardware and installed drivers 2577 may still lead to numerical differences. 2578 """ 2579 2580 relative_tolerance: RelativeTolerance = 1e-3 2581 """Maximum relative tolerance of reproduced test tensor.""" 2582 2583 absolute_tolerance: AbsoluteTolerance = 1e-4 2584 """Maximum absolute tolerance of reproduced test tensor.""" 2585 2586 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2587 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2588 2589 output_ids: Sequence[TensorId] = () 2590 """Limits the output tensor IDs these reproducibility details apply to.""" 2591 2592 weights_formats: Sequence[WeightsFormat] = () 2593 """Limits the weights formats these details apply to."""
Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.
A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)
Motivation:
For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
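A hedged numpy sketch of the mismatch criterion quoted above (array values and variable names are illustrative; the tolerances correspond to the default field values listed below):

```python
import numpy as np

rtol, atol = 1e-3, 1e-4  # relative_tolerance, absolute_tolerance defaults
test_tensor = np.array([0.0, 1.0, 100.0])
output = np.array([0.00005, 1.0005, 100.2])

mismatched = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
mismatched_per_million = mismatched.sum() / mismatched.size * 1e6  # here: ~333333
# numpy.testing.assert_allclose(output, test_tensor, rtol=rtol, atol=atol) applies
# the same criterion and raises if any element is mismatched.
```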
Maximum relative tolerance of reproduced test tensor.
Maximum absolute tolerance of reproduced test tensor.
Maximum number of mismatched elements/pixels per million to tolerate.
Limits the output tensor IDs these reproducibility details apply to.
Limits the weights formats these details apply to.
2596class BioimageioConfig(Node, extra="allow"): 2597 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2598 """Tolerances to allow when reproducing the model's test outputs 2599 from the model's test inputs. 2600 Only the first entry matching tensor id and weights format is considered. 2601 """
Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.
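A minimal sketch (the class names are the ones defined in this module; the tolerance values are arbitrary) of how such tolerances could be attached to a model's config:

```python
from bioimageio.spec.model.v0_5 import (
    BioimageioConfig,
    Config,
    ReproducibilityTolerance,
)

config = Config(
    bioimageio=BioimageioConfig(
        reproducibility_tolerance=[
            ReproducibilityTolerance(
                relative_tolerance=1e-3,
                absolute_tolerance=1e-4,
                mismatched_elements_per_million=100,
            )
        ]
    )
)
# this `config` would then be passed as the `config` field of a ModelDescr
```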
2604class Config(Node, extra="allow"): 2605 bioimageio: BioimageioConfig = Field( 2606 default_factory=BioimageioConfig.model_construct 2607 )
2610class ModelDescr(GenericModelDescrBase): 2611 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2612 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2613 """ 2614 2615 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5" 2616 if TYPE_CHECKING: 2617 format_version: Literal["0.5.5"] = "0.5.5" 2618 else: 2619 format_version: Literal["0.5.5"] 2620 """Version of the bioimage.io model description specification used. 2621 When creating a new model always use the latest micro/patch version described here. 2622 The `format_version` is important for any consumer software to understand how to parse the fields. 2623 """ 2624 2625 implemented_type: ClassVar[Literal["model"]] = "model" 2626 if TYPE_CHECKING: 2627 type: Literal["model"] = "model" 2628 else: 2629 type: Literal["model"] 2630 """Specialized resource type 'model'""" 2631 2632 id: Optional[ModelId] = None 2633 """bioimage.io-wide unique resource identifier 2634 assigned by bioimage.io; version **un**specific.""" 2635 2636 authors: FAIR[List[Author]] = Field( 2637 default_factory=cast(Callable[[], List[Author]], list) 2638 ) 2639 """The authors are the creators of the model RDF and the primary points of contact.""" 2640 2641 documentation: FAIR[Optional[FileSource_documentation]] = None 2642 """URL or relative path to a markdown file with additional documentation. 2643 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2644 The documentation should include a '#[#] Validation' (sub)section 2645 with details on how to quantitatively validate the model on unseen data.""" 2646 2647 @field_validator("documentation", mode="after") 2648 @classmethod 2649 def _validate_documentation( 2650 cls, value: Optional[FileSource_documentation] 2651 ) -> Optional[FileSource_documentation]: 2652 if not get_validation_context().perform_io_checks or value is None: 2653 return value 2654 2655 doc_reader = get_reader(value) 2656 doc_content = doc_reader.read().decode(encoding="utf-8") 2657 if not re.search("#.*[vV]alidation", doc_content): 2658 issue_warning( 2659 "No '# Validation' (sub)section found in {value}.", 2660 value=value, 2661 field="documentation", 2662 ) 2663 2664 return value 2665 2666 inputs: NotEmpty[Sequence[InputTensorDescr]] 2667 """Describes the input tensors expected by this model.""" 2668 2669 @field_validator("inputs", mode="after") 2670 @classmethod 2671 def _validate_input_axes( 2672 cls, inputs: Sequence[InputTensorDescr] 2673 ) -> Sequence[InputTensorDescr]: 2674 input_size_refs = cls._get_axes_with_independent_size(inputs) 2675 2676 for i, ipt in enumerate(inputs): 2677 valid_independent_refs: Dict[ 2678 Tuple[TensorId, AxisId], 2679 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2680 ] = { 2681 **{ 2682 (ipt.id, a.id): (ipt, a, a.size) 2683 for a in ipt.axes 2684 if not isinstance(a, BatchAxis) 2685 and isinstance(a.size, (int, ParameterizedSize)) 2686 }, 2687 **input_size_refs, 2688 } 2689 for a, ax in enumerate(ipt.axes): 2690 cls._validate_axis( 2691 "inputs", 2692 i=i, 2693 tensor_id=ipt.id, 2694 a=a, 2695 axis=ax, 2696 valid_independent_refs=valid_independent_refs, 2697 ) 2698 return inputs 2699 2700 @staticmethod 2701 def _validate_axis( 2702 field_name: str, 2703 i: int, 2704 tensor_id: TensorId, 2705 a: int, 2706 axis: AnyAxis, 2707 valid_independent_refs: Dict[ 2708 Tuple[TensorId, AxisId], 2709 Tuple[TensorDescr, AnyAxis, Union[int, 
ParameterizedSize]], 2710 ], 2711 ): 2712 if isinstance(axis, BatchAxis) or isinstance( 2713 axis.size, (int, ParameterizedSize, DataDependentSize) 2714 ): 2715 return 2716 elif not isinstance(axis.size, SizeReference): 2717 assert_never(axis.size) 2718 2719 # validate axis.size SizeReference 2720 ref = (axis.size.tensor_id, axis.size.axis_id) 2721 if ref not in valid_independent_refs: 2722 raise ValueError( 2723 "Invalid tensor axis reference at" 2724 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2725 ) 2726 if ref == (tensor_id, axis.id): 2727 raise ValueError( 2728 "Self-referencing not allowed for" 2729 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2730 ) 2731 if axis.type == "channel": 2732 if valid_independent_refs[ref][1].type != "channel": 2733 raise ValueError( 2734 "A channel axis' size may only reference another fixed size" 2735 + " channel axis." 2736 ) 2737 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2738 ref_size = valid_independent_refs[ref][2] 2739 assert isinstance(ref_size, int), ( 2740 "channel axis ref (another channel axis) has to specify fixed" 2741 + " size" 2742 ) 2743 generated_channel_names = [ 2744 Identifier(axis.channel_names.format(i=i)) 2745 for i in range(1, ref_size + 1) 2746 ] 2747 axis.channel_names = generated_channel_names 2748 2749 if (ax_unit := getattr(axis, "unit", None)) != ( 2750 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2751 ): 2752 raise ValueError( 2753 "The units of an axis and its reference axis need to match, but" 2754 + f" '{ax_unit}' != '{ref_unit}'." 2755 ) 2756 ref_axis = valid_independent_refs[ref][1] 2757 if isinstance(ref_axis, BatchAxis): 2758 raise ValueError( 2759 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2760 + " (a batch axis is not allowed as reference)." 2761 ) 2762 2763 if isinstance(axis, WithHalo): 2764 min_size = axis.size.get_size(axis, ref_axis, n=0) 2765 if (min_size - 2 * axis.halo) < 1: 2766 raise ValueError( 2767 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2768 + f" {axis.halo}." 2769 ) 2770 2771 input_halo = axis.halo * axis.scale / ref_axis.scale 2772 if input_halo != int(input_halo) or input_halo % 2 == 1: 2773 raise ValueError( 2774 f"input_halo {input_halo} (output_halo {axis.halo} *" 2775 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2776 + f" {tensor_id}.{axis.id}." 
2777 ) 2778 2779 @model_validator(mode="after") 2780 def _validate_test_tensors(self) -> Self: 2781 if not get_validation_context().perform_io_checks: 2782 return self 2783 2784 test_output_arrays = [ 2785 None if descr.test_tensor is None else load_array(descr.test_tensor) 2786 for descr in self.outputs 2787 ] 2788 test_input_arrays = [ 2789 None if descr.test_tensor is None else load_array(descr.test_tensor) 2790 for descr in self.inputs 2791 ] 2792 2793 tensors = { 2794 descr.id: (descr, array) 2795 for descr, array in zip( 2796 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2797 ) 2798 } 2799 validate_tensors(tensors, tensor_origin="test_tensor") 2800 2801 output_arrays = { 2802 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2803 } 2804 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2805 if not rep_tol.absolute_tolerance: 2806 continue 2807 2808 if rep_tol.output_ids: 2809 out_arrays = { 2810 oid: a 2811 for oid, a in output_arrays.items() 2812 if oid in rep_tol.output_ids 2813 } 2814 else: 2815 out_arrays = output_arrays 2816 2817 for out_id, array in out_arrays.items(): 2818 if array is None: 2819 continue 2820 2821 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2822 raise ValueError( 2823 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2824 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2825 + f" (1% of the maximum value of the test tensor '{out_id}')" 2826 ) 2827 2828 return self 2829 2830 @model_validator(mode="after") 2831 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2832 ipt_refs = {t.id for t in self.inputs} 2833 out_refs = {t.id for t in self.outputs} 2834 for ipt in self.inputs: 2835 for p in ipt.preprocessing: 2836 ref = p.kwargs.get("reference_tensor") 2837 if ref is None: 2838 continue 2839 if ref not in ipt_refs: 2840 raise ValueError( 2841 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2842 + f" references are: {ipt_refs}." 2843 ) 2844 2845 for out in self.outputs: 2846 for p in out.postprocessing: 2847 ref = p.kwargs.get("reference_tensor") 2848 if ref is None: 2849 continue 2850 2851 if ref not in ipt_refs and ref not in out_refs: 2852 raise ValueError( 2853 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2854 + f" are: {ipt_refs | out_refs}." 2855 ) 2856 2857 return self 2858 2859 # TODO: use validate funcs in validate_test_tensors 2860 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2861 2862 name: Annotated[ 2863 str, 2864 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 2865 MinLen(5), 2866 MaxLen(128), 2867 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2868 ] 2869 """A human-readable name of this model. 2870 It should be no longer than 64 characters 2871 and may only contain letter, number, underscore, minus, parentheses and spaces. 2872 We recommend to chose a name that refers to the model's task and image modality. 
2873 """ 2874 2875 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2876 """Describes the output tensors.""" 2877 2878 @field_validator("outputs", mode="after") 2879 @classmethod 2880 def _validate_tensor_ids( 2881 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2882 ) -> Sequence[OutputTensorDescr]: 2883 tensor_ids = [ 2884 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2885 ] 2886 duplicate_tensor_ids: List[str] = [] 2887 seen: Set[str] = set() 2888 for t in tensor_ids: 2889 if t in seen: 2890 duplicate_tensor_ids.append(t) 2891 2892 seen.add(t) 2893 2894 if duplicate_tensor_ids: 2895 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2896 2897 return outputs 2898 2899 @staticmethod 2900 def _get_axes_with_parameterized_size( 2901 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2902 ): 2903 return { 2904 f"{t.id}.{a.id}": (t, a, a.size) 2905 for t in io 2906 for a in t.axes 2907 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2908 } 2909 2910 @staticmethod 2911 def _get_axes_with_independent_size( 2912 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2913 ): 2914 return { 2915 (t.id, a.id): (t, a, a.size) 2916 for t in io 2917 for a in t.axes 2918 if not isinstance(a, BatchAxis) 2919 and isinstance(a.size, (int, ParameterizedSize)) 2920 } 2921 2922 @field_validator("outputs", mode="after") 2923 @classmethod 2924 def _validate_output_axes( 2925 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2926 ) -> List[OutputTensorDescr]: 2927 input_size_refs = cls._get_axes_with_independent_size( 2928 info.data.get("inputs", []) 2929 ) 2930 output_size_refs = cls._get_axes_with_independent_size(outputs) 2931 2932 for i, out in enumerate(outputs): 2933 valid_independent_refs: Dict[ 2934 Tuple[TensorId, AxisId], 2935 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2936 ] = { 2937 **{ 2938 (out.id, a.id): (out, a, a.size) 2939 for a in out.axes 2940 if not isinstance(a, BatchAxis) 2941 and isinstance(a.size, (int, ParameterizedSize)) 2942 }, 2943 **input_size_refs, 2944 **output_size_refs, 2945 } 2946 for a, ax in enumerate(out.axes): 2947 cls._validate_axis( 2948 "outputs", 2949 i, 2950 out.id, 2951 a, 2952 ax, 2953 valid_independent_refs=valid_independent_refs, 2954 ) 2955 2956 return outputs 2957 2958 packaged_by: List[Author] = Field( 2959 default_factory=cast(Callable[[], List[Author]], list) 2960 ) 2961 """The persons that have packaged and uploaded this model. 2962 Only required if those persons differ from the `authors`.""" 2963 2964 parent: Optional[LinkedModel] = None 2965 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2966 2967 @model_validator(mode="after") 2968 def _validate_parent_is_not_self(self) -> Self: 2969 if self.parent is not None and self.parent.id == self.id: 2970 raise ValueError("A model description may not reference itself as parent.") 2971 2972 return self 2973 2974 run_mode: Annotated[ 2975 Optional[RunMode], 2976 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2977 ] = None 2978 """Custom run mode for this model: for more complex prediction procedures like test time 2979 data augmentation that currently cannot be expressed in the specification. 
2980 No standard run modes are defined yet.""" 2981 2982 timestamp: Datetime = Field(default_factory=Datetime.now) 2983 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2984 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2985 (In Python a datetime object is valid, too).""" 2986 2987 training_data: Annotated[ 2988 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2989 Field(union_mode="left_to_right"), 2990 ] = None 2991 """The dataset used to train this model""" 2992 2993 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2994 """The weights for this model. 2995 Weights can be given for different formats, but should otherwise be equivalent. 2996 The available weight formats determine which consumers can use this model.""" 2997 2998 config: Config = Field(default_factory=Config.model_construct) 2999 3000 @model_validator(mode="after") 3001 def _add_default_cover(self) -> Self: 3002 if not get_validation_context().perform_io_checks or self.covers: 3003 return self 3004 3005 try: 3006 generated_covers = generate_covers( 3007 [ 3008 (t, load_array(t.test_tensor)) 3009 for t in self.inputs 3010 if t.test_tensor is not None 3011 ], 3012 [ 3013 (t, load_array(t.test_tensor)) 3014 for t in self.outputs 3015 if t.test_tensor is not None 3016 ], 3017 ) 3018 except Exception as e: 3019 issue_warning( 3020 "Failed to generate cover image(s): {e}", 3021 value=self.covers, 3022 msg_context=dict(e=e), 3023 field="covers", 3024 ) 3025 else: 3026 self.covers.extend(generated_covers) 3027 3028 return self 3029 3030 def get_input_test_arrays(self) -> List[NDArray[Any]]: 3031 return self._get_test_arrays(self.inputs) 3032 3033 def get_output_test_arrays(self) -> List[NDArray[Any]]: 3034 return self._get_test_arrays(self.outputs) 3035 3036 @staticmethod 3037 def _get_test_arrays( 3038 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 3039 ): 3040 ts: List[FileDescr] = [] 3041 for d in io_descr: 3042 if d.test_tensor is None: 3043 raise ValueError( 3044 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 3045 ) 3046 ts.append(d.test_tensor) 3047 3048 data = [load_array(t) for t in ts] 3049 assert all(isinstance(d, np.ndarray) for d in data) 3050 return data 3051 3052 @staticmethod 3053 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3054 batch_size = 1 3055 tensor_with_batchsize: Optional[TensorId] = None 3056 for tid in tensor_sizes: 3057 for aid, s in tensor_sizes[tid].items(): 3058 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3059 continue 3060 3061 if batch_size != 1: 3062 assert tensor_with_batchsize is not None 3063 raise ValueError( 3064 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3065 ) 3066 3067 batch_size = s 3068 tensor_with_batchsize = tid 3069 3070 return batch_size 3071 3072 def get_output_tensor_sizes( 3073 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3074 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3075 """Returns the tensor output sizes for given **input_sizes**. 3076 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
3077 Otherwise it might be larger than the actual (valid) output""" 3078 batch_size = self.get_batch_size(input_sizes) 3079 ns = self.get_ns(input_sizes) 3080 3081 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3082 return tensor_sizes.outputs 3083 3084 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3085 """get parameter `n` for each parameterized axis 3086 such that the valid input size is >= the given input size""" 3087 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3088 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3089 for tid in input_sizes: 3090 for aid, s in input_sizes[tid].items(): 3091 size_descr = axes[tid][aid].size 3092 if isinstance(size_descr, ParameterizedSize): 3093 ret[(tid, aid)] = size_descr.get_n(s) 3094 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3095 pass 3096 else: 3097 assert_never(size_descr) 3098 3099 return ret 3100 3101 def get_tensor_sizes( 3102 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3103 ) -> _TensorSizes: 3104 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3105 return _TensorSizes( 3106 { 3107 t: { 3108 aa: axis_sizes.inputs[(tt, aa)] 3109 for tt, aa in axis_sizes.inputs 3110 if tt == t 3111 } 3112 for t in {tt for tt, _ in axis_sizes.inputs} 3113 }, 3114 { 3115 t: { 3116 aa: axis_sizes.outputs[(tt, aa)] 3117 for tt, aa in axis_sizes.outputs 3118 if tt == t 3119 } 3120 for t in {tt for tt, _ in axis_sizes.outputs} 3121 }, 3122 ) 3123 3124 def get_axis_sizes( 3125 self, 3126 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3127 batch_size: Optional[int] = None, 3128 *, 3129 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3130 ) -> _AxisSizes: 3131 """Determine input and output block shape for scale factors **ns** 3132 of parameterized input sizes. 3133 3134 Args: 3135 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3136 that is parameterized as `size = min + n * step`. 3137 batch_size: The desired size of the batch dimension. 3138 If given **batch_size** overwrites any batch size present in 3139 **max_input_shape**. Default 1. 3140 max_input_shape: Limits the derived block shapes. 3141 Each axis for which the input size, parameterized by `n`, is larger 3142 than **max_input_shape** is set to the minimal value `n_min` for which 3143 this is still true. 3144 Use this for small input samples or large values of **ns**. 3145 Or simply whenever you know the full input shape. 3146 3147 Returns: 3148 Resolved axis sizes for model inputs and outputs. 
3149 """ 3150 max_input_shape = max_input_shape or {} 3151 if batch_size is None: 3152 for (_t_id, a_id), s in max_input_shape.items(): 3153 if a_id == BATCH_AXIS_ID: 3154 batch_size = s 3155 break 3156 else: 3157 batch_size = 1 3158 3159 all_axes = { 3160 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3161 } 3162 3163 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3164 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3165 3166 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3167 if isinstance(a, BatchAxis): 3168 if (t_descr.id, a.id) in ns: 3169 logger.warning( 3170 "Ignoring unexpected size increment factor (n) for batch axis" 3171 + " of tensor '{}'.", 3172 t_descr.id, 3173 ) 3174 return batch_size 3175 elif isinstance(a.size, int): 3176 if (t_descr.id, a.id) in ns: 3177 logger.warning( 3178 "Ignoring unexpected size increment factor (n) for fixed size" 3179 + " axis '{}' of tensor '{}'.", 3180 a.id, 3181 t_descr.id, 3182 ) 3183 return a.size 3184 elif isinstance(a.size, ParameterizedSize): 3185 if (t_descr.id, a.id) not in ns: 3186 raise ValueError( 3187 "Size increment factor (n) missing for parametrized axis" 3188 + f" '{a.id}' of tensor '{t_descr.id}'." 3189 ) 3190 n = ns[(t_descr.id, a.id)] 3191 s_max = max_input_shape.get((t_descr.id, a.id)) 3192 if s_max is not None: 3193 n = min(n, a.size.get_n(s_max)) 3194 3195 return a.size.get_size(n) 3196 3197 elif isinstance(a.size, SizeReference): 3198 if (t_descr.id, a.id) in ns: 3199 logger.warning( 3200 "Ignoring unexpected size increment factor (n) for axis '{}'" 3201 + " of tensor '{}' with size reference.", 3202 a.id, 3203 t_descr.id, 3204 ) 3205 assert not isinstance(a, BatchAxis) 3206 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3207 assert not isinstance(ref_axis, BatchAxis) 3208 ref_key = (a.size.tensor_id, a.size.axis_id) 3209 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3210 assert ref_size is not None, ref_key 3211 assert not isinstance(ref_size, _DataDepSize), ref_key 3212 return a.size.get_size( 3213 axis=a, 3214 ref_axis=ref_axis, 3215 ref_size=ref_size, 3216 ) 3217 elif isinstance(a.size, DataDependentSize): 3218 if (t_descr.id, a.id) in ns: 3219 logger.warning( 3220 "Ignoring unexpected increment factor (n) for data dependent" 3221 + " size axis '{}' of tensor '{}'.", 3222 a.id, 3223 t_descr.id, 3224 ) 3225 return _DataDepSize(a.size.min, a.size.max) 3226 else: 3227 assert_never(a.size) 3228 3229 # first resolve all , but the `SizeReference` input sizes 3230 for t_descr in self.inputs: 3231 for a in t_descr.axes: 3232 if not isinstance(a.size, SizeReference): 3233 s = get_axis_size(a) 3234 assert not isinstance(s, _DataDepSize) 3235 inputs[t_descr.id, a.id] = s 3236 3237 # resolve all other input axis sizes 3238 for t_descr in self.inputs: 3239 for a in t_descr.axes: 3240 if isinstance(a.size, SizeReference): 3241 s = get_axis_size(a) 3242 assert not isinstance(s, _DataDepSize) 3243 inputs[t_descr.id, a.id] = s 3244 3245 # resolve all output axis sizes 3246 for t_descr in self.outputs: 3247 for a in t_descr.axes: 3248 assert not isinstance(a.size, ParameterizedSize) 3249 s = get_axis_size(a) 3250 outputs[t_descr.id, a.id] = s 3251 3252 return _AxisSizes(inputs=inputs, outputs=outputs) 3253 3254 @model_validator(mode="before") 3255 @classmethod 3256 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3257 cls.convert_from_old_format_wo_validation(data) 3258 return data 3259 3260 @classmethod 3261 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3262 """Convert metadata following an older format version to this classes' format 3263 without validating the result. 3264 """ 3265 if ( 3266 data.get("type") == "model" 3267 and isinstance(fv := data.get("format_version"), str) 3268 and fv.count(".") == 2 3269 ): 3270 fv_parts = fv.split(".") 3271 if any(not p.isdigit() for p in fv_parts): 3272 return 3273 3274 fv_tuple = tuple(map(int, fv_parts)) 3275 3276 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3277 if fv_tuple[:2] in ((0, 3), (0, 4)): 3278 m04 = _ModelDescr_v0_4.load(data) 3279 if isinstance(m04, InvalidDescr): 3280 try: 3281 updated = _model_conv.convert_as_dict( 3282 m04 # pyright: ignore[reportArgumentType] 3283 ) 3284 except Exception as e: 3285 logger.error( 3286 "Failed to convert from invalid model 0.4 description." 3287 + f"\nerror: {e}" 3288 + "\nProceeding with model 0.5 validation without conversion." 3289 ) 3290 updated = None 3291 else: 3292 updated = _model_conv.convert_as_dict(m04) 3293 3294 if updated is not None: 3295 data.clear() 3296 data.update(updated) 3297 3298 elif fv_tuple[:2] == (0, 5): 3299 # bump patch version 3300 data["format_version"] = cls.implemented_format_version
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
URL or relative path to a markdown file with additional documentation.
The recommended documentation file name is README.md
. An .md
suffix is mandatory.
The documentation should include a '#[#] Validation' (sub)section
with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
Returns the tensor output sizes for given input_sizes. Only if input_sizes has a valid input shape, the tensor output size is exact. Otherwise it might be larger than the actual (valid) output
get parameter n for each parameterized axis such that the valid input size is >= the given input size
Determine input and output block shape for scale factors **ns** of parameterized input sizes.

Arguments:
- ns: Scale factor `n` for each axis (keyed by `(tensor_id, axis_id)`) that is parameterized as `size = min + n * step`.
- batch_size: The desired size of the batch dimension. If given, **batch_size** overwrites any batch size present in **max_input_shape**. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than **max_input_shape** is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of **ns**. Or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
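A sketch of how **max_input_shape** caps the scale factor, under the same assumed parameterization (`size = 64 + n * 16` for the x axis of input `raw`):

```python
axis_sizes = model.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 10},  # alone this would give 64 + 10 * 16 = 224
    max_input_shape={(TensorId("raw"), AxisId("x")): 128},
)
# n is clamped to the smallest n whose size reaches 128, i.e. n = 4, so:
# axis_sizes.inputs[(TensorId("raw"), AxisId("x"))] == 64 + 4 * 16 == 128
```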
3260 @classmethod 3261 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 3262 """Convert metadata following an older format version to this classes' format 3263 without validating the result. 3264 """ 3265 if ( 3266 data.get("type") == "model" 3267 and isinstance(fv := data.get("format_version"), str) 3268 and fv.count(".") == 2 3269 ): 3270 fv_parts = fv.split(".") 3271 if any(not p.isdigit() for p in fv_parts): 3272 return 3273 3274 fv_tuple = tuple(map(int, fv_parts)) 3275 3276 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3277 if fv_tuple[:2] in ((0, 3), (0, 4)): 3278 m04 = _ModelDescr_v0_4.load(data) 3279 if isinstance(m04, InvalidDescr): 3280 try: 3281 updated = _model_conv.convert_as_dict( 3282 m04 # pyright: ignore[reportArgumentType] 3283 ) 3284 except Exception as e: 3285 logger.error( 3286 "Failed to convert from invalid model 0.4 description." 3287 + f"\nerror: {e}" 3288 + "\nProceeding with model 0.5 validation without conversion." 3289 ) 3290 updated = None 3291 else: 3292 updated = _model_conv.convert_as_dict(m04) 3293 3294 if updated is not None: 3295 data.clear() 3296 data.update(updated) 3297 3298 elif fv_tuple[:2] == (0, 5): 3299 # bump patch version 3300 data["format_version"] = cls.implemented_format_version
Convert metadata following an older format version to this class's format without validating the result.
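A minimal sketch of the call pattern (the dict below is a toy stand-in; a real 0.4 model RDF contains many more fields such as inputs, outputs, and weights):

```python
from bioimageio.spec.model.v0_5 import ModelDescr

rdf = {"type": "model", "format_version": "0.4.10", "name": "example"}

ModelDescr.convert_from_old_format_wo_validation(rdf)
# For a convertible 0.4 description, `rdf` is rewritten in place to the 0.5 layout;
# no validation of the result is performed.
```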
Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
337def init_private_attributes(self: BaseModel, context: Any, /) -> None: 338 """This function is meant to behave like a BaseModel method to initialise private attributes. 339 340 It takes context as an argument since that's what pydantic-core passes when calling it. 341 342 Args: 343 self: The BaseModel instance. 344 context: The context. 345 """ 346 if getattr(self, '__pydantic_private__', None) is None: 347 pydantic_private = {} 348 for name, private_attr in self.__private_attributes__.items(): 349 default = private_attr.get_default() 350 if default is not PydanticUndefined: 351 pydantic_private[name] = default 352 object_setattr(self, '__pydantic_private__', pydantic_private)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
3517def generate_covers( 3518 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3519 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3520) -> List[Path]: 3521 def squeeze( 3522 data: NDArray[Any], axes: Sequence[AnyAxis] 3523 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3524 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3525 if data.ndim != len(axes): 3526 raise ValueError( 3527 f"tensor shape {data.shape} does not match described axes" 3528 + f" {[a.id for a in axes]}" 3529 ) 3530 3531 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3532 return data.squeeze(), axes 3533 3534 def normalize( 3535 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3536 ) -> NDArray[np.float32]: 3537 data = data.astype("float32") 3538 data -= data.min(axis=axis, keepdims=True) 3539 data /= data.max(axis=axis, keepdims=True) + eps 3540 return data 3541 3542 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3543 original_shape = data.shape 3544 data, axes = squeeze(data, axes) 3545 3546 # take slice fom any batch or index axis if needed 3547 # and convert the first channel axis and take a slice from any additional channel axes 3548 slices: Tuple[slice, ...] = () 3549 ndim = data.ndim 3550 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3551 has_c_axis = False 3552 for i, a in enumerate(axes): 3553 s = data.shape[i] 3554 assert s > 1 3555 if ( 3556 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3557 and ndim > ndim_need 3558 ): 3559 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3560 ndim -= 1 3561 elif isinstance(a, ChannelAxis): 3562 if has_c_axis: 3563 # second channel axis 3564 data = data[slices + (slice(0, 1),)] 3565 ndim -= 1 3566 else: 3567 has_c_axis = True 3568 if s == 2: 3569 # visualize two channels with cyan and magenta 3570 data = np.concatenate( 3571 [ 3572 data[slices + (slice(1, 2),)], 3573 data[slices + (slice(0, 1),)], 3574 ( 3575 data[slices + (slice(0, 1),)] 3576 + data[slices + (slice(1, 2),)] 3577 ) 3578 / 2, # TODO: take maximum instead? 
3579 ], 3580 axis=i, 3581 ) 3582 elif data.shape[i] == 3: 3583 pass # visualize 3 channels as RGB 3584 else: 3585 # visualize first 3 channels as RGB 3586 data = data[slices + (slice(3),)] 3587 3588 assert data.shape[i] == 3 3589 3590 slices += (slice(None),) 3591 3592 data, axes = squeeze(data, axes) 3593 assert len(axes) == ndim 3594 # take slice from z axis if needed 3595 slices = () 3596 if ndim > ndim_need: 3597 for i, a in enumerate(axes): 3598 s = data.shape[i] 3599 if a.id == AxisId("z"): 3600 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3601 data, axes = squeeze(data, axes) 3602 ndim -= 1 3603 break 3604 3605 slices += (slice(None),) 3606 3607 # take slice from any space or time axis 3608 slices = () 3609 3610 for i, a in enumerate(axes): 3611 if ndim <= ndim_need: 3612 break 3613 3614 s = data.shape[i] 3615 assert s > 1 3616 if isinstance( 3617 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3618 ): 3619 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3620 ndim -= 1 3621 3622 slices += (slice(None),) 3623 3624 del slices 3625 data, axes = squeeze(data, axes) 3626 assert len(axes) == ndim 3627 3628 if (has_c_axis and ndim != 3) or ndim != 2: 3629 raise ValueError( 3630 f"Failed to construct cover image from shape {original_shape}" 3631 ) 3632 3633 if not has_c_axis: 3634 assert ndim == 2 3635 data = np.repeat(data[:, :, None], 3, axis=2) 3636 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3637 ndim += 1 3638 3639 assert ndim == 3 3640 3641 # transpose axis order such that longest axis comes first... 3642 axis_order: List[int] = list(np.argsort(list(data.shape))) 3643 axis_order.reverse() 3644 # ... and channel axis is last 3645 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3646 axis_order.append(axis_order.pop(c)) 3647 axes = [axes[ao] for ao in axis_order] 3648 data = data.transpose(axis_order) 3649 3650 # h, w = data.shape[:2] 3651 # if h / w in (1.0 or 2.0): 3652 # pass 3653 # elif h / w < 2: 3654 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3655 3656 norm_along = ( 3657 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3658 ) 3659 # normalize the data and map to 8 bit 3660 data = normalize(data, norm_along) 3661 data = (data * 255).astype("uint8") 3662 3663 return data 3664 3665 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3666 assert im0.dtype == im1.dtype == np.uint8 3667 assert im0.shape == im1.shape 3668 assert im0.ndim == 3 3669 N, M, C = im0.shape 3670 assert C == 3 3671 out = np.ones((N, M, C), dtype="uint8") 3672 for c in range(C): 3673 outc = np.tril(im0[..., c]) 3674 mask = outc == 0 3675 outc[mask] = np.triu(im1[..., c])[mask] 3676 out[..., c] = outc 3677 3678 return out 3679 3680 if not inputs: 3681 raise ValueError("Missing test input tensor for cover generation.") 3682 3683 if not outputs: 3684 raise ValueError("Missing test output tensor for cover generation.") 3685 3686 ipt_descr, ipt = inputs[0] 3687 out_descr, out = outputs[0] 3688 3689 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3690 out_img = to_2d_image(out, out_descr.axes) 3691 3692 cover_folder = Path(mkdtemp()) 3693 if ipt_img.shape == out_img.shape: 3694 covers = [cover_folder / "cover.png"] 3695 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3696 else: 3697 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3698 imwrite(covers[0], ipt_img) 3699 imwrite(covers[1], out_img) 3700 3701 return covers
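A usage sketch for `generate_covers` (the `model` description and the random arrays standing in for test tensors are assumptions; each array's shape must match the tensor's described axes):

```python
import numpy as np

# `model` is an assumed ModelDescr whose first input and output each have
# (batch, channel, y, x) axes; random data stands in for real test tensors.
ipt = np.random.rand(1, 1, 256, 256).astype("float32")
out = np.random.rand(1, 1, 256, 256).astype("float32")
covers = generate_covers([(model.inputs[0], ipt)], [(model.outputs[0], out)])
# Returns a single diagonal-split "cover.png" if both tensors render to the same
# 2D image shape, otherwise separate "input.png" and "output.png" files, all
# written to a temporary folder.
```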