bioimageio.spec.model.v0_5
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from itertools import chain 10from math import ceil 11from pathlib import Path, PurePosixPath 12from tempfile import mkdtemp 13from typing import ( 14 TYPE_CHECKING, 15 Any, 16 Callable, 17 ClassVar, 18 Dict, 19 Generic, 20 List, 21 Literal, 22 Mapping, 23 NamedTuple, 24 Optional, 25 Sequence, 26 Set, 27 Tuple, 28 Type, 29 TypeVar, 30 Union, 31 cast, 32) 33 34import numpy as np 35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 37from loguru import logger 38from numpy.typing import NDArray 39from pydantic import ( 40 AfterValidator, 41 Discriminator, 42 Field, 43 RootModel, 44 SerializationInfo, 45 SerializerFunctionWrapHandler, 46 StrictInt, 47 Tag, 48 ValidationInfo, 49 WrapSerializer, 50 field_validator, 51 model_serializer, 52 model_validator, 53) 54from typing_extensions import Annotated, Self, assert_never, get_args 55 56from .._internal.common_nodes import ( 57 InvalidDescr, 58 Node, 59 NodeWithExplicitlySetFields, 60) 61from .._internal.constants import DTYPE_LIMITS 62from .._internal.field_warning import issue_warning, warn 63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 64from .._internal.io import FileDescr as FileDescr 65from .._internal.io import ( 66 FileSource, 67 WithSuffix, 68 YamlValue, 69 get_reader, 70 wo_special_file_name, 71) 72from .._internal.io_basics import Sha256 as Sha256 73from .._internal.io_packaging import ( 74 FileDescr_, 75 FileSource_, 76 package_file_descr_serializer, 77) 78from .._internal.io_utils import load_array 79from .._internal.node_converter import Converter 80from .._internal.type_guards import is_dict, is_sequence 81from .._internal.types import ( 82 FAIR, 83 AbsoluteTolerance, 84 LowerCaseIdentifier, 85 LowerCaseIdentifierAnno, 86 MismatchedElementsPerMillion, 87 RelativeTolerance, 88) 89from .._internal.types import Datetime as Datetime 90from .._internal.types import Identifier as Identifier 91from .._internal.types import NotEmpty as NotEmpty 92from .._internal.types import SiUnit as SiUnit 93from .._internal.url import HttpUrl as HttpUrl 94from .._internal.validation_context import get_validation_context 95from .._internal.validator_annotations import RestrictCharacters 96from .._internal.version_type import Version as Version 97from .._internal.warning_levels import INFO 98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 100from ..dataset.v0_3 import DatasetDescr as DatasetDescr 101from ..dataset.v0_3 import DatasetId as DatasetId 102from ..dataset.v0_3 import LinkedDataset as LinkedDataset 103from ..dataset.v0_3 import Uploader as Uploader 104from ..generic.v0_3 import ( 105 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 106) 107from ..generic.v0_3 import Author as Author 108from ..generic.v0_3 import BadgeDescr as BadgeDescr 109from ..generic.v0_3 import CiteEntry as CiteEntry 110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 111from ..generic.v0_3 import Doi as Doi 112from ..generic.v0_3 import ( 113 FileSource_documentation, 114 GenericModelDescrBase, 115 LinkedResourceBase, 116 _author_conv, # pyright: ignore[reportPrivateUsage] 117 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 118) 119from ..generic.v0_3 import LicenseId as 
LicenseId 120from ..generic.v0_3 import LinkedResource as LinkedResource 121from ..generic.v0_3 import Maintainer as Maintainer 122from ..generic.v0_3 import OrcidId as OrcidId 123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 124from ..generic.v0_3 import ResourceId as ResourceId 125from .v0_4 import Author as _Author_v0_4 126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 127from .v0_4 import CallableFromDepencency as CallableFromDepencency 128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 130from .v0_4 import ClipDescr as _ClipDescr_v0_4 131from .v0_4 import ClipKwargs as ClipKwargs 132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 134from .v0_4 import KnownRunMode as KnownRunMode 135from .v0_4 import ModelDescr as _ModelDescr_v0_4 136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 140from .v0_4 import ProcessingKwargs as ProcessingKwargs 141from .v0_4 import RunMode as RunMode 142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 146from .v0_4 import TensorName as _TensorName_v0_4 147from .v0_4 import WeightsFormat as WeightsFormat 148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 149from .v0_4 import package_weights 150 151SpaceUnit = Literal[ 152 "attometer", 153 "angstrom", 154 "centimeter", 155 "decimeter", 156 "exameter", 157 "femtometer", 158 "foot", 159 "gigameter", 160 "hectometer", 161 "inch", 162 "kilometer", 163 "megameter", 164 "meter", 165 "micrometer", 166 "mile", 167 "millimeter", 168 "nanometer", 169 "parsec", 170 "petameter", 171 "picometer", 172 "terameter", 173 "yard", 174 "yoctometer", 175 "yottameter", 176 "zeptometer", 177 "zettameter", 178] 179"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 180 181TimeUnit = Literal[ 182 "attosecond", 183 "centisecond", 184 "day", 185 "decisecond", 186 "exasecond", 187 "femtosecond", 188 "gigasecond", 189 "hectosecond", 190 "hour", 191 "kilosecond", 192 "megasecond", 193 "microsecond", 194 "millisecond", 195 "minute", 196 "nanosecond", 197 "petasecond", 198 "picosecond", 199 "second", 200 "terasecond", 201 "yoctosecond", 202 "yottasecond", 203 "zeptosecond", 204 "zettasecond", 205] 206"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 207 208AxisType = Literal["batch", "channel", "index", "time", "space"] 209 210_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 211 "b": "batch", 212 "t": "time", 213 "i": "index", 214 "c": "channel", 215 "x": "space", 216 "y": "space", 217 "z": "space", 218} 219 220_AXIS_ID_MAP = { 221 "b": "batch", 222 "t": "time", 223 "i": "index", 224 "c": "channel", 225} 226 227 228class TensorId(LowerCaseIdentifier): 229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 230 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 231 ] 232 233 234def _normalize_axis_id(a: str): 235 a = str(a) 236 normalized = _AXIS_ID_MAP.get(a, a) 237 if a 
!= normalized: 238 logger.opt(depth=3).warning( 239 "Normalized axis id from '{}' to '{}'.", a, normalized 240 ) 241 return normalized 242 243 244class AxisId(LowerCaseIdentifier): 245 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 246 Annotated[ 247 LowerCaseIdentifierAnno, 248 MaxLen(16), 249 AfterValidator(_normalize_axis_id), 250 ] 251 ] 252 253 254def _is_batch(a: str) -> bool: 255 return str(a) == "batch" 256 257 258def _is_not_batch(a: str) -> bool: 259 return not _is_batch(a) 260 261 262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 263 264PreprocessingId = Literal[ 265 "binarize", 266 "clip", 267 "ensure_dtype", 268 "fixed_zero_mean_unit_variance", 269 "scale_linear", 270 "scale_range", 271 "sigmoid", 272 "softmax", 273] 274PostprocessingId = Literal[ 275 "binarize", 276 "clip", 277 "ensure_dtype", 278 "fixed_zero_mean_unit_variance", 279 "scale_linear", 280 "scale_mean_variance", 281 "scale_range", 282 "sigmoid", 283 "softmax", 284 "zero_mean_unit_variance", 285] 286 287 288SAME_AS_TYPE = "<same as type>" 289 290 291ParameterizedSize_N = int 292""" 293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 294""" 295 296 297class ParameterizedSize(Node): 298 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 299 300 - **min** and **step** are given by the model description. 301 - All blocksize paramters n = 0,1,2,... yield a valid `size`. 302 - A greater blocksize paramter n = 0,1,2,... results in a greater **size**. 303 This allows to adjust the axis size more generically. 304 """ 305 306 N: ClassVar[Type[int]] = ParameterizedSize_N 307 """Positive integer to parameterize this axis""" 308 309 min: Annotated[int, Gt(0)] 310 step: Annotated[int, Gt(0)] 311 312 def validate_size(self, size: int) -> int: 313 if size < self.min: 314 raise ValueError(f"size {size} < {self.min}") 315 if (size - self.min) % self.step != 0: 316 raise ValueError( 317 f"axis of size {size} is not parameterized by `min + n*step` =" 318 + f" `{self.min} + n*{self.step}`" 319 ) 320 321 return size 322 323 def get_size(self, n: ParameterizedSize_N) -> int: 324 return self.min + self.step * n 325 326 def get_n(self, s: int) -> ParameterizedSize_N: 327 """return smallest n parameterizing a size greater or equal than `s`""" 328 return ceil((s - self.min) / self.step) 329 330 331class DataDependentSize(Node): 332 min: Annotated[int, Gt(0)] = 1 333 max: Annotated[Optional[int], Gt(1)] = None 334 335 @model_validator(mode="after") 336 def _validate_max_gt_min(self): 337 if self.max is not None and self.min >= self.max: 338 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 339 340 return self 341 342 def validate_size(self, size: int) -> int: 343 if size < self.min: 344 raise ValueError(f"size {size} < {self.min}") 345 346 if self.max is not None and size > self.max: 347 raise ValueError(f"size {size} > {self.max}") 348 349 return size 350 351 352class SizeReference(Node): 353 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 354 355 `axis.size = reference.size * reference.scale / axis.scale + offset` 356 357 Note: 358 1. The axis and the referenced axis need to have the same unit (or no unit). 359 2. Batch axes may not be referenced. 360 3. Fractions are rounded down. 361 4. If the reference axis is `concatenable` the referencing axis is assumed to be 362 `concatenable` as well with the same block order. 
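# --- Illustrative usage sketch: `ParameterizedSize` describes the series of valid
# axis sizes `size = min + n*step`. Names below are defined in this module
# (importable as `from bioimageio.spec.model.v0_5 import ParameterizedSize`).
example_size = ParameterizedSize(min=32, step=16)  # valid sizes: 32, 48, 64, ...
assert example_size.get_size(2) == 64               # n=2 -> 32 + 2*16
assert example_size.get_n(70) == 3                  # smallest n with size >= 70 (i.e. 80)
assert example_size.validate_size(48) == 48         # 48 = 32 + 1*16 is a valid size
# example_size.validate_size(40) would raise ValueError (40 is not 32 + n*16)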
363 364 Example: 365 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 366 Let's assume that we want to express the image height h in relation to its width w 367 instead of only accepting input images of exactly 100*49 pixels 368 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 369 370 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 371 >>> h = SpaceInputAxis( 372 ... id=AxisId("h"), 373 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 374 ... unit="millimeter", 375 ... scale=4, 376 ... ) 377 >>> print(h.size.get_size(h, w)) 378 49 379 380 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 381 """ 382 383 tensor_id: TensorId 384 """tensor id of the reference axis""" 385 386 axis_id: AxisId 387 """axis id of the reference axis""" 388 389 offset: StrictInt = 0 390 391 def get_size( 392 self, 393 axis: Union[ 394 ChannelAxis, 395 IndexInputAxis, 396 IndexOutputAxis, 397 TimeInputAxis, 398 SpaceInputAxis, 399 TimeOutputAxis, 400 TimeOutputAxisWithHalo, 401 SpaceOutputAxis, 402 SpaceOutputAxisWithHalo, 403 ], 404 ref_axis: Union[ 405 ChannelAxis, 406 IndexInputAxis, 407 IndexOutputAxis, 408 TimeInputAxis, 409 SpaceInputAxis, 410 TimeOutputAxis, 411 TimeOutputAxisWithHalo, 412 SpaceOutputAxis, 413 SpaceOutputAxisWithHalo, 414 ], 415 n: ParameterizedSize_N = 0, 416 ref_size: Optional[int] = None, 417 ): 418 """Compute the concrete size for a given axis and its reference axis. 419 420 Args: 421 axis: The axis this `SizeReference` is the size of. 422 ref_axis: The reference axis to compute the size from. 423 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 424 and no fixed **ref_size** is given, 425 **n** is used to compute the size of the parameterized **ref_axis**. 426 ref_size: Overwrite the reference size instead of deriving it from 427 **ref_axis** 428 (**ref_axis.scale** is still used; any given **n** is ignored). 429 """ 430 assert ( 431 axis.size == self 432 ), "Given `axis.size` is not defined by this `SizeReference`" 433 434 assert ( 435 ref_axis.id == self.axis_id 436 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 437 438 assert axis.unit == ref_axis.unit, ( 439 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 440 f" but {axis.unit}!={ref_axis.unit}" 441 ) 442 if ref_size is None: 443 if isinstance(ref_axis.size, (int, float)): 444 ref_size = ref_axis.size 445 elif isinstance(ref_axis.size, ParameterizedSize): 446 ref_size = ref_axis.size.get_size(n) 447 elif isinstance(ref_axis.size, DataDependentSize): 448 raise ValueError( 449 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 450 ) 451 elif isinstance(ref_axis.size, SizeReference): 452 raise ValueError( 453 "Reference axis referenced in `SizeReference` may not be sized by a" 454 + " `SizeReference` itself." 
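# --- Illustrative usage sketch: when the referenced axis is parameterized, the
# blocksize parameter `n` selects the concrete reference size. `SpaceInputAxis` is
# defined further below; the tensor id "input" is a hypothetical example.
ref_axis = SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=32, step=16))
dep_axis = SpaceInputAxis(
    id=AxisId("y"),
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
)
assert dep_axis.size.get_size(dep_axis, ref_axis, n=2) == 64  # (32 + 2*16) * 1.0/1.0 + 0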
455 ) 456 else: 457 assert_never(ref_axis.size) 458 459 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 460 461 @staticmethod 462 def _get_unit( 463 axis: Union[ 464 ChannelAxis, 465 IndexInputAxis, 466 IndexOutputAxis, 467 TimeInputAxis, 468 SpaceInputAxis, 469 TimeOutputAxis, 470 TimeOutputAxisWithHalo, 471 SpaceOutputAxis, 472 SpaceOutputAxisWithHalo, 473 ], 474 ): 475 return axis.unit 476 477 478class AxisBase(NodeWithExplicitlySetFields): 479 id: AxisId 480 """An axis id unique across all axes of one tensor.""" 481 482 description: Annotated[str, MaxLen(128)] = "" 483 """A short description of this axis beyond its type and id.""" 484 485 486class WithHalo(Node): 487 halo: Annotated[int, Ge(1)] 488 """The halo should be cropped from the output tensor to avoid boundary effects. 489 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 490 To document a halo that is already cropped by the model use `size.offset` instead.""" 491 492 size: Annotated[ 493 SizeReference, 494 Field( 495 examples=[ 496 10, 497 SizeReference( 498 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 499 ).model_dump(mode="json"), 500 ] 501 ), 502 ] 503 """reference to another axis with an optional offset (see `SizeReference`)""" 504 505 506BATCH_AXIS_ID = AxisId("batch") 507 508 509class BatchAxis(AxisBase): 510 implemented_type: ClassVar[Literal["batch"]] = "batch" 511 if TYPE_CHECKING: 512 type: Literal["batch"] = "batch" 513 else: 514 type: Literal["batch"] 515 516 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 517 size: Optional[Literal[1]] = None 518 """The batch size may be fixed to 1, 519 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 520 521 @property 522 def scale(self): 523 return 1.0 524 525 @property 526 def concatenable(self): 527 return True 528 529 @property 530 def unit(self): 531 return None 532 533 534class ChannelAxis(AxisBase): 535 implemented_type: ClassVar[Literal["channel"]] = "channel" 536 if TYPE_CHECKING: 537 type: Literal["channel"] = "channel" 538 else: 539 type: Literal["channel"] 540 541 id: NonBatchAxisId = AxisId("channel") 542 543 channel_names: NotEmpty[List[Identifier]] 544 545 @property 546 def size(self) -> int: 547 return len(self.channel_names) 548 549 @property 550 def concatenable(self): 551 return False 552 553 @property 554 def scale(self) -> float: 555 return 1.0 556 557 @property 558 def unit(self): 559 return None 560 561 562class IndexAxisBase(AxisBase): 563 implemented_type: ClassVar[Literal["index"]] = "index" 564 if TYPE_CHECKING: 565 type: Literal["index"] = "index" 566 else: 567 type: Literal["index"] 568 569 id: NonBatchAxisId = AxisId("index") 570 571 @property 572 def scale(self) -> float: 573 return 1.0 574 575 @property 576 def unit(self): 577 return None 578 579 580class _WithInputAxisSize(Node): 581 size: Annotated[ 582 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 583 Field( 584 examples=[ 585 10, 586 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 587 SizeReference( 588 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 589 ).model_dump(mode="json"), 590 ] 591 ), 592 ] 593 """The size/length of this axis can be specified as 594 - fixed integer 595 - parameterized series of valid sizes (`ParameterizedSize`) 596 - reference to another axis with an optional offset (`SizeReference`) 597 """ 598 599 600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 601 concatenable: bool = False 602 """If a model has a `concatenable` 
input axis, it can be processed blockwise, 603 splitting a longer sample axis into blocks matching its input tensor description. 604 Output axes are concatenable if they have a `SizeReference` to a concatenable 605 input axis. 606 """ 607 608 609class IndexOutputAxis(IndexAxisBase): 610 size: Annotated[ 611 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 612 Field( 613 examples=[ 614 10, 615 SizeReference( 616 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 617 ).model_dump(mode="json"), 618 ] 619 ), 620 ] 621 """The size/length of this axis can be specified as 622 - fixed integer 623 - reference to another axis with an optional offset (`SizeReference`) 624 - data dependent size using `DataDependentSize` (size is only known after model inference) 625 """ 626 627 628class TimeAxisBase(AxisBase): 629 implemented_type: ClassVar[Literal["time"]] = "time" 630 if TYPE_CHECKING: 631 type: Literal["time"] = "time" 632 else: 633 type: Literal["time"] 634 635 id: NonBatchAxisId = AxisId("time") 636 unit: Optional[TimeUnit] = None 637 scale: Annotated[float, Gt(0)] = 1.0 638 639 640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 641 concatenable: bool = False 642 """If a model has a `concatenable` input axis, it can be processed blockwise, 643 splitting a longer sample axis into blocks matching its input tensor description. 644 Output axes are concatenable if they have a `SizeReference` to a concatenable 645 input axis. 646 """ 647 648 649class SpaceAxisBase(AxisBase): 650 implemented_type: ClassVar[Literal["space"]] = "space" 651 if TYPE_CHECKING: 652 type: Literal["space"] = "space" 653 else: 654 type: Literal["space"] 655 656 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 657 unit: Optional[SpaceUnit] = None 658 scale: Annotated[float, Gt(0)] = 1.0 659 660 661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 662 concatenable: bool = False 663 """If a model has a `concatenable` input axis, it can be processed blockwise, 664 splitting a longer sample axis into blocks matching its input tensor description. 665 Output axes are concatenable if they have a `SizeReference` to a concatenable 666 input axis. 
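# --- Illustrative usage sketch: input axes for a (batch, channel, y, x) image tensor
# with parameterized spatial extents. The channel name "raw" is a hypothetical example.
example_input_axes = [
    BatchAxis(),
    ChannelAxis(channel_names=[Identifier("raw")]),
    SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=32)),
    SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=32)),
]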
667 """ 668 669 670INPUT_AXIS_TYPES = ( 671 BatchAxis, 672 ChannelAxis, 673 IndexInputAxis, 674 TimeInputAxis, 675 SpaceInputAxis, 676) 677"""intended for isinstance comparisons in py<3.10""" 678 679_InputAxisUnion = Union[ 680 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 681] 682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 683 684 685class _WithOutputAxisSize(Node): 686 size: Annotated[ 687 Union[Annotated[int, Gt(0)], SizeReference], 688 Field( 689 examples=[ 690 10, 691 SizeReference( 692 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 693 ).model_dump(mode="json"), 694 ] 695 ), 696 ] 697 """The size/length of this axis can be specified as 698 - fixed integer 699 - reference to another axis with an optional offset (see `SizeReference`) 700 """ 701 702 703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 704 pass 705 706 707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 708 pass 709 710 711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 712 if isinstance(v, dict): 713 return "with_halo" if "halo" in v else "wo_halo" 714 else: 715 return "with_halo" if hasattr(v, "halo") else "wo_halo" 716 717 718_TimeOutputAxisUnion = Annotated[ 719 Union[ 720 Annotated[TimeOutputAxis, Tag("wo_halo")], 721 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 722 ], 723 Discriminator(_get_halo_axis_discriminator_value), 724] 725 726 727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 728 pass 729 730 731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 732 pass 733 734 735_SpaceOutputAxisUnion = Annotated[ 736 Union[ 737 Annotated[SpaceOutputAxis, Tag("wo_halo")], 738 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 739 ], 740 Discriminator(_get_halo_axis_discriminator_value), 741] 742 743 744_OutputAxisUnion = Union[ 745 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 746] 747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 748 749OUTPUT_AXIS_TYPES = ( 750 BatchAxis, 751 ChannelAxis, 752 IndexOutputAxis, 753 TimeOutputAxis, 754 TimeOutputAxisWithHalo, 755 SpaceOutputAxis, 756 SpaceOutputAxisWithHalo, 757) 758"""intended for isinstance comparisons in py<3.10""" 759 760 761AnyAxis = Union[InputAxis, OutputAxis] 762 763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 764"""intended for isinstance comparisons in py<3.10""" 765 766TVs = Union[ 767 NotEmpty[List[int]], 768 NotEmpty[List[float]], 769 NotEmpty[List[bool]], 770 NotEmpty[List[str]], 771] 772 773 774NominalOrOrdinalDType = Literal[ 775 "float32", 776 "float64", 777 "uint8", 778 "int8", 779 "uint16", 780 "int16", 781 "uint32", 782 "int32", 783 "uint64", 784 "int64", 785 "bool", 786] 787 788 789class NominalOrOrdinalDataDescr(Node): 790 values: TVs 791 """A fixed set of nominal or an ascending sequence of ordinal values. 792 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 793 String `values` are interpreted as labels for tensor values 0, ..., N. 794 Note: as YAML 1.2 does not natively support a "set" datatype, 795 nominal values should be given as a sequence (aka list/array) as well. 
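# --- Illustrative usage sketch: an output spatial axis that mirrors input axis "y"
# of a (hypothetical) tensor "input" via `SizeReference` and declares a halo of 8,
# i.e. 8 pixels are cropped from both sides (size_after_crop = size - 2*8).
example_output_axis = SpaceOutputAxisWithHalo(
    id=AxisId("y"),
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
    halo=8,
)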
796 """ 797 798 type: Annotated[ 799 NominalOrOrdinalDType, 800 Field( 801 examples=[ 802 "float32", 803 "uint8", 804 "uint16", 805 "int64", 806 "bool", 807 ], 808 ), 809 ] = "uint8" 810 811 @model_validator(mode="after") 812 def _validate_values_match_type( 813 self, 814 ) -> Self: 815 incompatible: List[Any] = [] 816 for v in self.values: 817 if self.type == "bool": 818 if not isinstance(v, bool): 819 incompatible.append(v) 820 elif self.type in DTYPE_LIMITS: 821 if ( 822 isinstance(v, (int, float)) 823 and ( 824 v < DTYPE_LIMITS[self.type].min 825 or v > DTYPE_LIMITS[self.type].max 826 ) 827 or (isinstance(v, str) and "uint" not in self.type) 828 or (isinstance(v, float) and "int" in self.type) 829 ): 830 incompatible.append(v) 831 else: 832 incompatible.append(v) 833 834 if len(incompatible) == 5: 835 incompatible.append("...") 836 break 837 838 if incompatible: 839 raise ValueError( 840 f"data type '{self.type}' incompatible with values {incompatible}" 841 ) 842 843 return self 844 845 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 846 847 @property 848 def range(self): 849 if isinstance(self.values[0], str): 850 return 0, len(self.values) - 1 851 else: 852 return min(self.values), max(self.values) 853 854 855IntervalOrRatioDType = Literal[ 856 "float32", 857 "float64", 858 "uint8", 859 "int8", 860 "uint16", 861 "int16", 862 "uint32", 863 "int32", 864 "uint64", 865 "int64", 866] 867 868 869class IntervalOrRatioDataDescr(Node): 870 type: Annotated[ # TODO: rename to dtype 871 IntervalOrRatioDType, 872 Field( 873 examples=["float32", "float64", "uint8", "uint16"], 874 ), 875 ] = "float32" 876 range: Tuple[Optional[float], Optional[float]] = ( 877 None, 878 None, 879 ) 880 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 881 `None` corresponds to min/max of what can be expressed by **type**.""" 882 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 883 scale: float = 1.0 884 """Scale for data on an interval (or ratio) scale.""" 885 offset: Optional[float] = None 886 """Offset for data on a ratio scale.""" 887 888 @model_validator(mode="before") 889 def _replace_inf(cls, data: Any): 890 if is_dict(data): 891 if "range" in data and is_sequence(data["range"]): 892 forbidden = ( 893 "inf", 894 "-inf", 895 ".inf", 896 "-.inf", 897 float("inf"), 898 float("-inf"), 899 ) 900 if any(v in forbidden for v in data["range"]): 901 issue_warning("replaced 'inf' value", value=data["range"]) 902 903 data["range"] = tuple( 904 (None if v in forbidden else v) for v in data["range"] 905 ) 906 907 return data 908 909 910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 911 912 913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 914 """processing base class""" 915 916 917class BinarizeKwargs(ProcessingKwargs): 918 """key word arguments for `BinarizeDescr`""" 919 920 threshold: float 921 """The fixed threshold""" 922 923 924class BinarizeAlongAxisKwargs(ProcessingKwargs): 925 """key word arguments for `BinarizeDescr`""" 926 927 threshold: NotEmpty[List[float]] 928 """The fixed threshold values along `axis`""" 929 930 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 931 """The `threshold` axis""" 932 933 934class BinarizeDescr(ProcessingDescrBase): 935 """Binarize the tensor with a fixed threshold. 936 937 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 938 will be set to one, values below the threshold to zero. 
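# --- Illustrative usage sketch: `NominalOrOrdinalDataDescr` lists the allowed values
# (strings act as labels for 0..N), `IntervalOrRatioDataDescr` describes continuous
# data with an optional range. The label names are hypothetical examples.
example_label_data = NominalOrOrdinalDataDescr(
    values=["background", "cell", "membrane"], type="uint8"
)
assert example_label_data.range == (0, 2)  # string values label tensor values 0, 1, 2
example_image_data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))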
939 940 Examples: 941 - in YAML 942 ```yaml 943 postprocessing: 944 - id: binarize 945 kwargs: 946 axis: 'channel' 947 threshold: [0.25, 0.5, 0.75] 948 ``` 949 - in Python: 950 >>> postprocessing = [BinarizeDescr( 951 ... kwargs=BinarizeAlongAxisKwargs( 952 ... axis=AxisId('channel'), 953 ... threshold=[0.25, 0.5, 0.75], 954 ... ) 955 ... )] 956 """ 957 958 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 959 if TYPE_CHECKING: 960 id: Literal["binarize"] = "binarize" 961 else: 962 id: Literal["binarize"] 963 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 964 965 966class ClipDescr(ProcessingDescrBase): 967 """Set tensor values below min to min and above max to max. 968 969 See `ScaleRangeDescr` for examples. 970 """ 971 972 implemented_id: ClassVar[Literal["clip"]] = "clip" 973 if TYPE_CHECKING: 974 id: Literal["clip"] = "clip" 975 else: 976 id: Literal["clip"] 977 978 kwargs: ClipKwargs 979 980 981class EnsureDtypeKwargs(ProcessingKwargs): 982 """key word arguments for `EnsureDtypeDescr`""" 983 984 dtype: Literal[ 985 "float32", 986 "float64", 987 "uint8", 988 "int8", 989 "uint16", 990 "int16", 991 "uint32", 992 "int32", 993 "uint64", 994 "int64", 995 "bool", 996 ] 997 998 999class EnsureDtypeDescr(ProcessingDescrBase): 1000 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 1001 1002 This can for example be used to ensure the inner neural network model gets a 1003 different input tensor data type than the fully described bioimage.io model does. 1004 1005 Examples: 1006 The described bioimage.io model (incl. preprocessing) accepts any 1007 float32-compatible tensor, normalizes it with percentiles and clipping and then 1008 casts it to uint8, which is what the neural network in this example expects. 1009 - in YAML 1010 ```yaml 1011 inputs: 1012 - data: 1013 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1014 preprocessing: 1015 - id: scale_range 1016 kwargs: 1017 axes: ['y', 'x'] 1018 max_percentile: 99.8 1019 min_percentile: 5.0 1020 - id: clip 1021 kwargs: 1022 min: 0.0 1023 max: 1.0 1024 - id: ensure_dtype # the neural network of the model requires uint8 1025 kwargs: 1026 dtype: uint8 1027 ``` 1028 - in Python: 1029 >>> preprocessing = [ 1030 ... ScaleRangeDescr( 1031 ... kwargs=ScaleRangeKwargs( 1032 ... axes= (AxisId('y'), AxisId('x')), 1033 ... max_percentile= 99.8, 1034 ... min_percentile= 5.0, 1035 ... ) 1036 ... ), 1037 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1038 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1039 ... ] 1040 """ 1041 1042 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1043 if TYPE_CHECKING: 1044 id: Literal["ensure_dtype"] = "ensure_dtype" 1045 else: 1046 id: Literal["ensure_dtype"] 1047 1048 kwargs: EnsureDtypeKwargs 1049 1050 1051class ScaleLinearKwargs(ProcessingKwargs): 1052 """Key word arguments for `ScaleLinearDescr`""" 1053 1054 gain: float = 1.0 1055 """multiplicative factor""" 1056 1057 offset: float = 0.0 1058 """additive term""" 1059 1060 @model_validator(mode="after") 1061 def _validate(self) -> Self: 1062 if self.gain == 1.0 and self.offset == 0.0: 1063 raise ValueError( 1064 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1065 + " != 0.0." 
1066 ) 1067 1068 return self 1069 1070 1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1072 """Key word arguments for `ScaleLinearDescr`""" 1073 1074 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1075 """The axis of gain and offset values.""" 1076 1077 gain: Union[float, NotEmpty[List[float]]] = 1.0 1078 """multiplicative factor""" 1079 1080 offset: Union[float, NotEmpty[List[float]]] = 0.0 1081 """additive term""" 1082 1083 @model_validator(mode="after") 1084 def _validate(self) -> Self: 1085 1086 if isinstance(self.gain, list): 1087 if isinstance(self.offset, list): 1088 if len(self.gain) != len(self.offset): 1089 raise ValueError( 1090 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1091 ) 1092 else: 1093 self.offset = [float(self.offset)] * len(self.gain) 1094 elif isinstance(self.offset, list): 1095 self.gain = [float(self.gain)] * len(self.offset) 1096 else: 1097 raise ValueError( 1098 "Do not specify an `axis` for scalar gain and offset values." 1099 ) 1100 1101 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1102 raise ValueError( 1103 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1104 + " != 0.0." 1105 ) 1106 1107 return self 1108 1109 1110class ScaleLinearDescr(ProcessingDescrBase): 1111 """Fixed linear scaling. 1112 1113 Examples: 1114 1. Scale with scalar gain and offset 1115 - in YAML 1116 ```yaml 1117 preprocessing: 1118 - id: scale_linear 1119 kwargs: 1120 gain: 2.0 1121 offset: 3.0 1122 ``` 1123 - in Python: 1124 >>> preprocessing = [ 1125 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1126 ... ] 1127 1128 2. Independent scaling along an axis 1129 - in YAML 1130 ```yaml 1131 preprocessing: 1132 - id: scale_linear 1133 kwargs: 1134 axis: 'channel' 1135 gain: [1.0, 2.0, 3.0] 1136 ``` 1137 - in Python: 1138 >>> preprocessing = [ 1139 ... ScaleLinearDescr( 1140 ... kwargs=ScaleLinearAlongAxisKwargs( 1141 ... axis=AxisId("channel"), 1142 ... gain=[1.0, 2.0, 3.0], 1143 ... ) 1144 ... ) 1145 ... ] 1146 1147 """ 1148 1149 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1150 if TYPE_CHECKING: 1151 id: Literal["scale_linear"] = "scale_linear" 1152 else: 1153 id: Literal["scale_linear"] 1154 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 1155 1156 1157class SigmoidDescr(ProcessingDescrBase): 1158 """The logistic sigmoid function, a.k.a. expit function. 1159 1160 Examples: 1161 - in YAML 1162 ```yaml 1163 postprocessing: 1164 - id: sigmoid 1165 ``` 1166 - in Python: 1167 >>> postprocessing = [SigmoidDescr()] 1168 """ 1169 1170 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1171 if TYPE_CHECKING: 1172 id: Literal["sigmoid"] = "sigmoid" 1173 else: 1174 id: Literal["sigmoid"] 1175 1176 @property 1177 def kwargs(self) -> ProcessingKwargs: 1178 """empty kwargs""" 1179 return ProcessingKwargs() 1180 1181 1182class SoftmaxKwargs(ProcessingKwargs): 1183 """key word arguments for `SoftmaxDescr`""" 1184 1185 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 1186 """The axis to apply the softmax function along. 1187 Note: 1188 Defaults to 'channel' axis 1189 (which may not exist, in which case 1190 a different axis id has to be specified). 1191 """ 1192 1193 1194class SoftmaxDescr(ProcessingDescrBase): 1195 """The softmax function. 
1196 1197 Examples: 1198 - in YAML 1199 ```yaml 1200 postprocessing: 1201 - id: softmax 1202 kwargs: 1203 axis: channel 1204 ``` 1205 - in Python: 1206 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 1207 """ 1208 1209 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 1210 if TYPE_CHECKING: 1211 id: Literal["softmax"] = "softmax" 1212 else: 1213 id: Literal["softmax"] 1214 1215 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 1216 1217 1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1219 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1220 1221 mean: float 1222 """The mean value to normalize with.""" 1223 1224 std: Annotated[float, Ge(1e-6)] 1225 """The standard deviation value to normalize with.""" 1226 1227 1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1229 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1230 1231 mean: NotEmpty[List[float]] 1232 """The mean value(s) to normalize with.""" 1233 1234 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1235 """The standard deviation value(s) to normalize with. 1236 Size must match `mean` values.""" 1237 1238 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1239 """The axis of the mean/std values to normalize each entry along that dimension 1240 separately.""" 1241 1242 @model_validator(mode="after") 1243 def _mean_and_std_match(self) -> Self: 1244 if len(self.mean) != len(self.std): 1245 raise ValueError( 1246 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1247 + " must match." 1248 ) 1249 1250 return self 1251 1252 1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1254 """Subtract a given mean and divide by the standard deviation. 1255 1256 Normalize with fixed, precomputed values for 1257 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1258 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1259 axes. 1260 1261 Examples: 1262 1. scalar value for whole tensor 1263 - in YAML 1264 ```yaml 1265 preprocessing: 1266 - id: fixed_zero_mean_unit_variance 1267 kwargs: 1268 mean: 103.5 1269 std: 13.7 1270 ``` 1271 - in Python 1272 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1273 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1274 ... )] 1275 1276 2. independently along an axis 1277 - in YAML 1278 ```yaml 1279 preprocessing: 1280 - id: fixed_zero_mean_unit_variance 1281 kwargs: 1282 axis: channel 1283 mean: [101.5, 102.5, 103.5] 1284 std: [11.7, 12.7, 13.7] 1285 ``` 1286 - in Python 1287 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1288 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1289 ... axis=AxisId("channel"), 1290 ... mean=[101.5, 102.5, 103.5], 1291 ... std=[11.7, 12.7, 13.7], 1292 ... ) 1293 ... 
)] 1294 """ 1295 1296 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1297 "fixed_zero_mean_unit_variance" 1298 ) 1299 if TYPE_CHECKING: 1300 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1301 else: 1302 id: Literal["fixed_zero_mean_unit_variance"] 1303 1304 kwargs: Union[ 1305 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1306 ] 1307 1308 1309class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1310 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1311 1312 axes: Annotated[ 1313 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1314 ] = None 1315 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1316 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1317 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1318 To normalize each sample independently leave out the 'batch' axis. 1319 Default: Scale all axes jointly.""" 1320 1321 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1322 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 1323 1324 1325class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1326 """Subtract mean and divide by variance. 1327 1328 Examples: 1329 Subtract tensor mean and variance 1330 - in YAML 1331 ```yaml 1332 preprocessing: 1333 - id: zero_mean_unit_variance 1334 ``` 1335 - in Python 1336 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1337 """ 1338 1339 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1340 "zero_mean_unit_variance" 1341 ) 1342 if TYPE_CHECKING: 1343 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1344 else: 1345 id: Literal["zero_mean_unit_variance"] 1346 1347 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1348 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 1349 ) 1350 1351 1352class ScaleRangeKwargs(ProcessingKwargs): 1353 """key word arguments for `ScaleRangeDescr` 1354 1355 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1356 this processing step normalizes data to the [0, 1] intervall. 1357 For other percentiles the normalized values will partially be outside the [0, 1] 1358 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1359 normalized values to a range. 1360 """ 1361 1362 axes: Annotated[ 1363 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1364 ] = None 1365 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1366 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1367 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1368 To normalize samples independently, leave out the "batch" axis. 1369 Default: Scale all axes jointly.""" 1370 1371 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1372 """The lower percentile used to determine the value to align with zero.""" 1373 1374 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1375 """The upper percentile used to determine the value to align with one. 1376 Has to be bigger than `min_percentile`. 1377 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1378 accepting percentiles specified in the range 0.0 to 1.0.""" 1379 1380 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1381 """Epsilon for numeric stability. 
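# --- Illustrative usage sketch: reducing only over 'x' and 'y' (leaving out 'batch'
# and 'channel') normalizes each sample and each channel independently.
example_zmuv = ZeroMeanUnitVarianceDescr(
    kwargs=ZeroMeanUnitVarianceKwargs(axes=(AxisId("x"), AxisId("y")))
)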
1382 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1383 with `v_lower,v_upper` values at the respective percentiles.""" 1384 1385 reference_tensor: Optional[TensorId] = None 1386 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1387 For any tensor in `inputs` only input tensor references are allowed.""" 1388 1389 @field_validator("max_percentile", mode="after") 1390 @classmethod 1391 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1392 if (min_p := info.data["min_percentile"]) >= value: 1393 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1394 1395 return value 1396 1397 1398class ScaleRangeDescr(ProcessingDescrBase): 1399 """Scale with percentiles. 1400 1401 Examples: 1402 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1403 - in YAML 1404 ```yaml 1405 preprocessing: 1406 - id: scale_range 1407 kwargs: 1408 axes: ['y', 'x'] 1409 max_percentile: 99.8 1410 min_percentile: 5.0 1411 ``` 1412 - in Python 1413 >>> preprocessing = [ 1414 ... ScaleRangeDescr( 1415 ... kwargs=ScaleRangeKwargs( 1416 ... axes= (AxisId('y'), AxisId('x')), 1417 ... max_percentile= 99.8, 1418 ... min_percentile= 5.0, 1419 ... ) 1420 ... ), 1421 ... ClipDescr( 1422 ... kwargs=ClipKwargs( 1423 ... min=0.0, 1424 ... max=1.0, 1425 ... ) 1426 ... ), 1427 ... ] 1428 1429 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1430 - in YAML 1431 ```yaml 1432 preprocessing: 1433 - id: scale_range 1434 kwargs: 1435 axes: ['y', 'x'] 1436 max_percentile: 99.8 1437 min_percentile: 5.0 1438 - id: scale_range 1439 - id: clip 1440 kwargs: 1441 min: 0.0 1442 max: 1.0 1443 ``` 1444 - in Python 1445 >>> preprocessing = [ScaleRangeDescr( 1446 ... kwargs=ScaleRangeKwargs( 1447 ... axes= (AxisId('y'), AxisId('x')), 1448 ... max_percentile= 99.8, 1449 ... min_percentile= 5.0, 1450 ... ) 1451 ... )] 1452 1453 """ 1454 1455 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1456 if TYPE_CHECKING: 1457 id: Literal["scale_range"] = "scale_range" 1458 else: 1459 id: Literal["scale_range"] 1460 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct) 1461 1462 1463class ScaleMeanVarianceKwargs(ProcessingKwargs): 1464 """key word arguments for `ScaleMeanVarianceKwargs`""" 1465 1466 reference_tensor: TensorId 1467 """Name of tensor to match.""" 1468 1469 axes: Annotated[ 1470 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1471 ] = None 1472 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1473 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1474 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1475 To normalize samples independently, leave out the 'batch' axis. 1476 Default: Scale all axes jointly.""" 1477 1478 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1479 """Epsilon for numeric stability: 1480 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 1481 1482 1483class ScaleMeanVarianceDescr(ProcessingDescrBase): 1484 """Scale a tensor's data distribution to match another tensor's mean/std. 
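# --- Illustrative usage sketch: match an output tensor's mean/std to the
# (hypothetical) input tensor "raw", computed jointly over the spatial axes.
example_scale_mean_variance = ScaleMeanVarianceDescr(
    kwargs=ScaleMeanVarianceKwargs(
        reference_tensor=TensorId("raw"),
        axes=(AxisId("x"), AxisId("y")),
    )
)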
1485 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1486 """ 1487 1488 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1489 if TYPE_CHECKING: 1490 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1491 else: 1492 id: Literal["scale_mean_variance"] 1493 kwargs: ScaleMeanVarianceKwargs 1494 1495 1496PreprocessingDescr = Annotated[ 1497 Union[ 1498 BinarizeDescr, 1499 ClipDescr, 1500 EnsureDtypeDescr, 1501 FixedZeroMeanUnitVarianceDescr, 1502 ScaleLinearDescr, 1503 ScaleRangeDescr, 1504 SigmoidDescr, 1505 SoftmaxDescr, 1506 ZeroMeanUnitVarianceDescr, 1507 ], 1508 Discriminator("id"), 1509] 1510PostprocessingDescr = Annotated[ 1511 Union[ 1512 BinarizeDescr, 1513 ClipDescr, 1514 EnsureDtypeDescr, 1515 FixedZeroMeanUnitVarianceDescr, 1516 ScaleLinearDescr, 1517 ScaleMeanVarianceDescr, 1518 ScaleRangeDescr, 1519 SigmoidDescr, 1520 SoftmaxDescr, 1521 ZeroMeanUnitVarianceDescr, 1522 ], 1523 Discriminator("id"), 1524] 1525 1526IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 1527 1528 1529class TensorDescrBase(Node, Generic[IO_AxisT]): 1530 id: TensorId 1531 """Tensor id. No duplicates are allowed.""" 1532 1533 description: Annotated[str, MaxLen(128)] = "" 1534 """free text description""" 1535 1536 axes: NotEmpty[Sequence[IO_AxisT]] 1537 """tensor axes""" 1538 1539 @property 1540 def shape(self): 1541 return tuple(a.size for a in self.axes) 1542 1543 @field_validator("axes", mode="after", check_fields=False) 1544 @classmethod 1545 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1546 batch_axes = [a for a in axes if a.type == "batch"] 1547 if len(batch_axes) > 1: 1548 raise ValueError( 1549 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1550 ) 1551 1552 seen_ids: Set[AxisId] = set() 1553 duplicate_axes_ids: Set[AxisId] = set() 1554 for a in axes: 1555 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1556 1557 if duplicate_axes_ids: 1558 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1559 1560 return axes 1561 1562 test_tensor: FAIR[Optional[FileDescr_]] = None 1563 """An example tensor to use for testing. 1564 Using the model with the test input tensors is expected to yield the test output tensors. 1565 Each test tensor has be a an ndarray in the 1566 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1567 The file extension must be '.npy'.""" 1568 1569 sample_tensor: FAIR[Optional[FileDescr_]] = None 1570 """A sample tensor to illustrate a possible input/output for the model, 1571 The sample image primarily serves to inform a human user about an example use case 1572 and is typically stored as .hdf5, .png or .tiff. 1573 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1574 (numpy's `.npy` format is not supported). 1575 The image dimensionality has to match the number of axes specified in this tensor description. 
1576 """ 1577 1578 @model_validator(mode="after") 1579 def _validate_sample_tensor(self) -> Self: 1580 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1581 return self 1582 1583 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1584 tensor: NDArray[Any] = imread( 1585 reader.read(), 1586 extension=PurePosixPath(reader.original_file_name).suffix, 1587 ) 1588 n_dims = len(tensor.squeeze().shape) 1589 n_dims_min = n_dims_max = len(self.axes) 1590 1591 for a in self.axes: 1592 if isinstance(a, BatchAxis): 1593 n_dims_min -= 1 1594 elif isinstance(a.size, int): 1595 if a.size == 1: 1596 n_dims_min -= 1 1597 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1598 if a.size.min == 1: 1599 n_dims_min -= 1 1600 elif isinstance(a.size, SizeReference): 1601 if a.size.offset < 2: 1602 # size reference may result in singleton axis 1603 n_dims_min -= 1 1604 else: 1605 assert_never(a.size) 1606 1607 n_dims_min = max(0, n_dims_min) 1608 if n_dims < n_dims_min or n_dims > n_dims_max: 1609 raise ValueError( 1610 f"Expected sample tensor to have {n_dims_min} to" 1611 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1612 ) 1613 1614 return self 1615 1616 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1617 IntervalOrRatioDataDescr() 1618 ) 1619 """Description of the tensor's data values, optionally per channel. 1620 If specified per channel, the data `type` needs to match across channels.""" 1621 1622 @property 1623 def dtype( 1624 self, 1625 ) -> Literal[ 1626 "float32", 1627 "float64", 1628 "uint8", 1629 "int8", 1630 "uint16", 1631 "int16", 1632 "uint32", 1633 "int32", 1634 "uint64", 1635 "int64", 1636 "bool", 1637 ]: 1638 """dtype as specified under `data.type` or `data[i].type`""" 1639 if isinstance(self.data, collections.abc.Sequence): 1640 return self.data[0].type 1641 else: 1642 return self.data.type 1643 1644 @field_validator("data", mode="after") 1645 @classmethod 1646 def _check_data_type_across_channels( 1647 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1648 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1649 if not isinstance(value, list): 1650 return value 1651 1652 dtypes = {t.type for t in value} 1653 if len(dtypes) > 1: 1654 raise ValueError( 1655 "Tensor data descriptions per channel need to agree in their data" 1656 + f" `type`, but found {dtypes}." 1657 ) 1658 1659 return value 1660 1661 @model_validator(mode="after") 1662 def _check_data_matches_channelaxis(self) -> Self: 1663 if not isinstance(self.data, (list, tuple)): 1664 return self 1665 1666 for a in self.axes: 1667 if isinstance(a, ChannelAxis): 1668 size = a.size 1669 assert isinstance(size, int) 1670 break 1671 else: 1672 return self 1673 1674 if len(self.data) != size: 1675 raise ValueError( 1676 f"Got tensor data descriptions for {len(self.data)} channels, but" 1677 + f" '{a.id}' axis has size {size}." 1678 ) 1679 1680 return self 1681 1682 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1683 if len(array.shape) != len(self.axes): 1684 raise ValueError( 1685 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1686 + f" incompatible with {len(self.axes)} axes." 1687 ) 1688 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1689 1690 1691class InputTensorDescr(TensorDescrBase[InputAxis]): 1692 id: TensorId = TensorId("input") 1693 """Input tensor id. 
1694 No duplicates are allowed across all inputs and outputs.""" 1695 1696 optional: bool = False 1697 """indicates that this tensor may be `None`""" 1698 1699 preprocessing: List[PreprocessingDescr] = Field( 1700 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1701 ) 1702 1703 """Description of how this input should be preprocessed. 1704 1705 notes: 1706 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1707 to ensure an input tensor's data type matches the input tensor's data description. 1708 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1709 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1710 changing the data type. 1711 """ 1712 1713 @model_validator(mode="after") 1714 def _validate_preprocessing_kwargs(self) -> Self: 1715 axes_ids = [a.id for a in self.axes] 1716 for p in self.preprocessing: 1717 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1718 if kwargs_axes is None: 1719 continue 1720 1721 if not isinstance(kwargs_axes, collections.abc.Sequence): 1722 raise ValueError( 1723 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1724 ) 1725 1726 if any(a not in axes_ids for a in kwargs_axes): 1727 raise ValueError( 1728 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1729 ) 1730 1731 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1732 dtype = self.data.type 1733 else: 1734 dtype = self.data[0].type 1735 1736 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1737 if not self.preprocessing or not isinstance( 1738 self.preprocessing[0], EnsureDtypeDescr 1739 ): 1740 self.preprocessing.insert( 1741 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1742 ) 1743 1744 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1745 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1746 self.preprocessing.append( 1747 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1748 ) 1749 1750 return self 1751 1752 1753def convert_axes( 1754 axes: str, 1755 *, 1756 shape: Union[ 1757 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1758 ], 1759 tensor_type: Literal["input", "output"], 1760 halo: Optional[Sequence[int]], 1761 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1762): 1763 ret: List[AnyAxis] = [] 1764 for i, a in enumerate(axes): 1765 axis_type = _AXIS_TYPE_MAP.get(a, a) 1766 if axis_type == "batch": 1767 ret.append(BatchAxis()) 1768 continue 1769 1770 scale = 1.0 1771 if isinstance(shape, _ParameterizedInputShape_v0_4): 1772 if shape.step[i] == 0: 1773 size = shape.min[i] 1774 else: 1775 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1776 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1777 ref_t = str(shape.reference_tensor) 1778 if ref_t.count(".") == 1: 1779 t_id, orig_a_id = ref_t.split(".") 1780 else: 1781 t_id = ref_t 1782 orig_a_id = a 1783 1784 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1785 if not (orig_scale := shape.scale[i]): 1786 # old way to insert a new axis dimension 1787 size = int(2 * shape.offset[i]) 1788 else: 1789 scale = 1 / orig_scale 1790 if axis_type in ("channel", "index"): 1791 # these axes no longer have a scale 1792 offset_from_scale = orig_scale * size_refs.get( 1793 _TensorName_v0_4(t_id), {} 1794 ).get(orig_a_id, 0) 1795 else: 1796 offset_from_scale = 0 1797 size = SizeReference( 1798 tensor_id=TensorId(t_id), 1799 axis_id=AxisId(a_id), 1800 
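# --- Illustrative usage sketch: a minimal `InputTensorDescr`. During validation an
# 'ensure_dtype' step is inserted before and after the given preprocessing (here:
# sigmoid), so the chain starts and ends with the described data type. `test_tensor`
# is omitted for brevity; a published model should provide a '.npy' test tensor.
# Axis and channel names are hypothetical examples.
example_input = InputTensorDescr(
    id=TensorId("input"),
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("raw")]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=32)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=32)),
    ],
    data=IntervalOrRatioDataDescr(type="float32"),
    preprocessing=[SigmoidDescr()],
)
assert isinstance(example_input.preprocessing[0], EnsureDtypeDescr)
assert isinstance(example_input.preprocessing[-1], EnsureDtypeDescr)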
offset=int(offset_from_scale + 2 * shape.offset[i]), 1801 ) 1802 else: 1803 size = shape[i] 1804 1805 if axis_type == "time": 1806 if tensor_type == "input": 1807 ret.append(TimeInputAxis(size=size, scale=scale)) 1808 else: 1809 assert not isinstance(size, ParameterizedSize) 1810 if halo is None: 1811 ret.append(TimeOutputAxis(size=size, scale=scale)) 1812 else: 1813 assert not isinstance(size, int) 1814 ret.append( 1815 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1816 ) 1817 1818 elif axis_type == "index": 1819 if tensor_type == "input": 1820 ret.append(IndexInputAxis(size=size)) 1821 else: 1822 if isinstance(size, ParameterizedSize): 1823 size = DataDependentSize(min=size.min) 1824 1825 ret.append(IndexOutputAxis(size=size)) 1826 elif axis_type == "channel": 1827 assert not isinstance(size, ParameterizedSize) 1828 if isinstance(size, SizeReference): 1829 warnings.warn( 1830 "Conversion of channel size from an implicit output shape may be" 1831 + " wrong" 1832 ) 1833 ret.append( 1834 ChannelAxis( 1835 channel_names=[ 1836 Identifier(f"channel{i}") for i in range(size.offset) 1837 ] 1838 ) 1839 ) 1840 else: 1841 ret.append( 1842 ChannelAxis( 1843 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1844 ) 1845 ) 1846 elif axis_type == "space": 1847 if tensor_type == "input": 1848 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1849 else: 1850 assert not isinstance(size, ParameterizedSize) 1851 if halo is None or halo[i] == 0: 1852 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1853 elif isinstance(size, int): 1854 raise NotImplementedError( 1855 f"output axis with halo and fixed size (here {size}) not allowed" 1856 ) 1857 else: 1858 ret.append( 1859 SpaceOutputAxisWithHalo( 1860 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1861 ) 1862 ) 1863 1864 return ret 1865 1866 1867def _axes_letters_to_ids( 1868 axes: Optional[str], 1869) -> Optional[List[AxisId]]: 1870 if axes is None: 1871 return None 1872 1873 return [AxisId(a) for a in axes] 1874 1875 1876def _get_complement_v04_axis( 1877 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1878) -> Optional[AxisId]: 1879 if axes is None: 1880 return None 1881 1882 non_complement_axes = set(axes) | {"b"} 1883 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1884 if len(complement_axes) > 1: 1885 raise ValueError( 1886 f"Expected none or a single complement axis, but axes '{axes}' " 1887 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
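# --- Illustrative usage sketch: `convert_axes` translates a model 0.4 axes string
# (one letter per axis) plus an explicit shape into the axis descriptions above.
example_converted = convert_axes(
    "bcyx", shape=[1, 3, 256, 256], tensor_type="input", halo=None, size_refs={}
)
assert [type(a).__name__ for a in example_converted] == [
    "BatchAxis", "ChannelAxis", "SpaceInputAxis", "SpaceInputAxis"
]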
1888 ) 1889 1890 return None if not complement_axes else AxisId(complement_axes[0]) 1891 1892 1893def _convert_proc( 1894 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1895 tensor_axes: Sequence[str], 1896) -> Union[PreprocessingDescr, PostprocessingDescr]: 1897 if isinstance(p, _BinarizeDescr_v0_4): 1898 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1899 elif isinstance(p, _ClipDescr_v0_4): 1900 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1901 elif isinstance(p, _SigmoidDescr_v0_4): 1902 return SigmoidDescr() 1903 elif isinstance(p, _ScaleLinearDescr_v0_4): 1904 axes = _axes_letters_to_ids(p.kwargs.axes) 1905 if p.kwargs.axes is None: 1906 axis = None 1907 else: 1908 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1909 1910 if axis is None: 1911 assert not isinstance(p.kwargs.gain, list) 1912 assert not isinstance(p.kwargs.offset, list) 1913 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1914 else: 1915 kwargs = ScaleLinearAlongAxisKwargs( 1916 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1917 ) 1918 return ScaleLinearDescr(kwargs=kwargs) 1919 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1920 return ScaleMeanVarianceDescr( 1921 kwargs=ScaleMeanVarianceKwargs( 1922 axes=_axes_letters_to_ids(p.kwargs.axes), 1923 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1924 eps=p.kwargs.eps, 1925 ) 1926 ) 1927 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1928 if p.kwargs.mode == "fixed": 1929 mean = p.kwargs.mean 1930 std = p.kwargs.std 1931 assert mean is not None 1932 assert std is not None 1933 1934 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1935 1936 if axis is None: 1937 return FixedZeroMeanUnitVarianceDescr( 1938 kwargs=FixedZeroMeanUnitVarianceKwargs( 1939 mean=mean, std=std # pyright: ignore[reportArgumentType] 1940 ) 1941 ) 1942 else: 1943 if not isinstance(mean, list): 1944 mean = [float(mean)] 1945 if not isinstance(std, list): 1946 std = [float(std)] 1947 1948 return FixedZeroMeanUnitVarianceDescr( 1949 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1950 axis=axis, mean=mean, std=std 1951 ) 1952 ) 1953 1954 else: 1955 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1956 if p.kwargs.mode == "per_dataset": 1957 axes = [AxisId("batch")] + axes 1958 if not axes: 1959 axes = None 1960 return ZeroMeanUnitVarianceDescr( 1961 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1962 ) 1963 1964 elif isinstance(p, _ScaleRangeDescr_v0_4): 1965 return ScaleRangeDescr( 1966 kwargs=ScaleRangeKwargs( 1967 axes=_axes_letters_to_ids(p.kwargs.axes), 1968 min_percentile=p.kwargs.min_percentile, 1969 max_percentile=p.kwargs.max_percentile, 1970 eps=p.kwargs.eps, 1971 ) 1972 ) 1973 else: 1974 assert_never(p) 1975 1976 1977class _InputTensorConv( 1978 Converter[ 1979 _InputTensorDescr_v0_4, 1980 InputTensorDescr, 1981 FileSource_, 1982 Optional[FileSource_], 1983 Mapping[_TensorName_v0_4, Mapping[str, int]], 1984 ] 1985): 1986 def _convert( 1987 self, 1988 src: _InputTensorDescr_v0_4, 1989 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1990 test_tensor: FileSource_, 1991 sample_tensor: Optional[FileSource_], 1992 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1993 ) -> "InputTensorDescr | dict[str, Any]": 1994 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1995 src.axes, 1996 shape=src.shape, 1997 tensor_type="input", 1998 halo=None, 1999 size_refs=size_refs, 2000 ) 2001 prep: 
List[PreprocessingDescr] = [] 2002 for p in src.preprocessing: 2003 cp = _convert_proc(p, src.axes) 2004 assert not isinstance(cp, ScaleMeanVarianceDescr) 2005 prep.append(cp) 2006 2007 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 2008 2009 return tgt( 2010 axes=axes, 2011 id=TensorId(str(src.name)), 2012 test_tensor=FileDescr(source=test_tensor), 2013 sample_tensor=( 2014 None if sample_tensor is None else FileDescr(source=sample_tensor) 2015 ), 2016 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 2017 preprocessing=prep, 2018 ) 2019 2020 2021_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 2022 2023 2024class OutputTensorDescr(TensorDescrBase[OutputAxis]): 2025 id: TensorId = TensorId("output") 2026 """Output tensor id. 2027 No duplicates are allowed across all inputs and outputs.""" 2028 2029 postprocessing: List[PostprocessingDescr] = Field( 2030 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 2031 ) 2032 """Description of how this output should be postprocessed. 2033 2034 note: `postprocessing` always ends with an 'ensure_dtype' operation. 2035 If not given this is added to cast to this tensor's `data.type`. 2036 """ 2037 2038 @model_validator(mode="after") 2039 def _validate_postprocessing_kwargs(self) -> Self: 2040 axes_ids = [a.id for a in self.axes] 2041 for p in self.postprocessing: 2042 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2043 if kwargs_axes is None: 2044 continue 2045 2046 if not isinstance(kwargs_axes, collections.abc.Sequence): 2047 raise ValueError( 2048 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2049 ) 2050 2051 if any(a not in axes_ids for a in kwargs_axes): 2052 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2053 2054 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2055 dtype = self.data.type 2056 else: 2057 dtype = self.data[0].type 2058 2059 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2060 if not self.postprocessing or not isinstance( 2061 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2062 ): 2063 self.postprocessing.append( 2064 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2065 ) 2066 return self 2067 2068 2069class _OutputTensorConv( 2070 Converter[ 2071 _OutputTensorDescr_v0_4, 2072 OutputTensorDescr, 2073 FileSource_, 2074 Optional[FileSource_], 2075 Mapping[_TensorName_v0_4, Mapping[str, int]], 2076 ] 2077): 2078 def _convert( 2079 self, 2080 src: _OutputTensorDescr_v0_4, 2081 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 2082 test_tensor: FileSource_, 2083 sample_tensor: Optional[FileSource_], 2084 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 2085 ) -> "OutputTensorDescr | dict[str, Any]": 2086 # TODO: split convert_axes into convert_output_axes and convert_input_axes 2087 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 2088 src.axes, 2089 shape=src.shape, 2090 tensor_type="output", 2091 halo=src.halo, 2092 size_refs=size_refs, 2093 ) 2094 data_descr: Dict[str, Any] = dict(type=src.data_type) 2095 if data_descr["type"] == "bool": 2096 data_descr["values"] = [False, True] 2097 2098 return tgt( 2099 axes=axes, 2100 id=TensorId(str(src.name)), 2101 test_tensor=FileDescr(source=test_tensor), 2102 sample_tensor=( 2103 None if sample_tensor is None else FileDescr(source=sample_tensor) 2104 ), 2105 data=data_descr, # pyright: ignore[reportArgumentType] 2106 
postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 2107 ) 2108 2109 2110_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 2111 2112 2113TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 2114 2115 2116def validate_tensors( 2117 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 2118 tensor_origin: Literal[ 2119 "test_tensor" 2120 ], # for more precise error messages, e.g. 'test_tensor' 2121): 2122 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 2123 2124 def e_msg(d: TensorDescr): 2125 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2126 2127 for descr, array in tensors.values(): 2128 if array is None: 2129 axis_sizes = {a.id: None for a in descr.axes} 2130 else: 2131 try: 2132 axis_sizes = descr.get_axis_sizes_for_array(array) 2133 except ValueError as e: 2134 raise ValueError(f"{e_msg(descr)} {e}") 2135 2136 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 2137 2138 for descr, array in tensors.values(): 2139 if array is None: 2140 continue 2141 2142 if descr.dtype in ("float32", "float64"): 2143 invalid_test_tensor_dtype = array.dtype.name not in ( 2144 "float32", 2145 "float64", 2146 "uint8", 2147 "int8", 2148 "uint16", 2149 "int16", 2150 "uint32", 2151 "int32", 2152 "uint64", 2153 "int64", 2154 ) 2155 else: 2156 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2157 2158 if invalid_test_tensor_dtype: 2159 raise ValueError( 2160 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2161 + f" match described dtype '{descr.dtype}'" 2162 ) 2163 2164 if array.min() > -1e-4 and array.max() < 1e-4: 2165 raise ValueError( 2166 "Output values are too small for reliable testing."
2167 + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}" 2168 ) 2169 2170 for a in descr.axes: 2171 actual_size = all_tensor_axes[descr.id][a.id][1] 2172 if actual_size is None: 2173 continue 2174 2175 if a.size is None: 2176 continue 2177 2178 if isinstance(a.size, int): 2179 if actual_size != a.size: 2180 raise ValueError( 2181 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2182 + f"has incompatible size {actual_size}, expected {a.size}" 2183 ) 2184 elif isinstance(a.size, ParameterizedSize): 2185 _ = a.size.validate_size(actual_size) 2186 elif isinstance(a.size, DataDependentSize): 2187 _ = a.size.validate_size(actual_size) 2188 elif isinstance(a.size, SizeReference): 2189 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2190 if ref_tensor_axes is None: 2191 raise ValueError( 2192 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2193 + f" reference '{a.size.tensor_id}'" 2194 ) 2195 2196 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2197 if ref_axis is None or ref_size is None: 2198 raise ValueError( 2199 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2200 + f" reference '{a.size.tensor_id}.{a.size.axis_id}'" 2201 ) 2202 2203 if a.unit != ref_axis.unit: 2204 raise ValueError( 2205 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2206 + " axis and reference axis to have the same `unit`, but" 2207 + f" {a.unit}!={ref_axis.unit}" 2208 ) 2209 2210 if actual_size != ( 2211 expected_size := ( 2212 ref_size * ref_axis.scale / a.scale + a.size.offset 2213 ) 2214 ): 2215 raise ValueError( 2216 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2217 + f" {actual_size} invalid for referenced size {ref_size};" 2218 + f" expected {expected_size}" 2219 ) 2220 else: 2221 assert_never(a.size) 2222 2223 2224FileDescr_dependencies = Annotated[ 2225 FileDescr_, 2226 WithSuffix((".yaml", ".yml"), case_sensitive=True), 2227 Field(examples=[dict(source="environment.yaml")]), 2228] 2229 2230 2231class _ArchitectureCallableDescr(Node): 2232 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2233 """Identifier of the callable that returns a torch.nn.Module instance.""" 2234 2235 kwargs: Dict[str, YamlValue] = Field( 2236 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 2237 ) 2238 """keyword arguments for the `callable`""" 2239 2240 2241class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2242 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2243 """Architecture source file""" 2244 2245 @model_serializer(mode="wrap", when_used="unless-none") 2246 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2247 return package_file_descr_serializer(self, nxt, info) 2248 2249 2250class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2251 import_from: str 2252 """Where to import the callable from, i.e.
`from <import_from> import <callable>`""" 2253 2254 2255class _ArchFileConv( 2256 Converter[ 2257 _CallableFromFile_v0_4, 2258 ArchitectureFromFileDescr, 2259 Optional[Sha256], 2260 Dict[str, Any], 2261 ] 2262): 2263 def _convert( 2264 self, 2265 src: _CallableFromFile_v0_4, 2266 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2267 sha256: Optional[Sha256], 2268 kwargs: Dict[str, Any], 2269 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2270 if src.startswith("http") and src.count(":") == 2: 2271 http, source, callable_ = src.split(":") 2272 source = ":".join((http, source)) 2273 elif not src.startswith("http") and src.count(":") == 1: 2274 source, callable_ = src.split(":") 2275 else: 2276 source = str(src) 2277 callable_ = str(src) 2278 return tgt( 2279 callable=Identifier(callable_), 2280 source=cast(FileSource_, source), 2281 sha256=sha256, 2282 kwargs=kwargs, 2283 ) 2284 2285 2286_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2287 2288 2289class _ArchLibConv( 2290 Converter[ 2291 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2292 ] 2293): 2294 def _convert( 2295 self, 2296 src: _CallableFromDepencency_v0_4, 2297 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2298 kwargs: Dict[str, Any], 2299 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2300 *mods, callable_ = src.split(".") 2301 import_from = ".".join(mods) 2302 return tgt( 2303 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2304 ) 2305 2306 2307_arch_lib_conv = _ArchLibConv( 2308 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2309) 2310 2311 2312class WeightsEntryDescrBase(FileDescr): 2313 type: ClassVar[WeightsFormat] 2314 weights_format_name: ClassVar[str] # human readable 2315 2316 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2317 """Source of the weights file.""" 2318 2319 authors: Optional[List[Author]] = None 2320 """Authors 2321 Either the person(s) that have trained this model resulting in the original weights file. 2322 (If this is the initial weights entry, i.e. it does not have a `parent`) 2323 Or the person(s) who have converted the weights to this weights format. 2324 (If this is a child weight, i.e. it has a `parent` field) 2325 """ 2326 2327 parent: Annotated[ 2328 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2329 ] = None 2330 """The source weights these weights were converted from. 2331 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2332 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 
2333 All weight entries except one (the initial set of weights resulting from training the model) 2334 need to have this field.""" 2335 2336 comment: str = "" 2337 """A comment about this weights entry, for example how these weights were created.""" 2338 2339 @model_validator(mode="after") 2340 def _validate(self) -> Self: 2341 if self.type == self.parent: 2342 raise ValueError("Weights entry can't be its own parent.") 2343 2344 return self 2345 2346 @model_serializer(mode="wrap", when_used="unless-none") 2347 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2348 return package_file_descr_serializer(self, nxt, info) 2349 2350 2351class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2352 type = "keras_hdf5" 2353 weights_format_name: ClassVar[str] = "Keras HDF5" 2354 tensorflow_version: Version 2355 """TensorFlow version used to create these weights.""" 2356 2357 2358class OnnxWeightsDescr(WeightsEntryDescrBase): 2359 type = "onnx" 2360 weights_format_name: ClassVar[str] = "ONNX" 2361 opset_version: Annotated[int, Ge(7)] 2362 """ONNX opset version""" 2363 2364 2365class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2366 type = "pytorch_state_dict" 2367 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2368 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2369 pytorch_version: Version 2370 """Version of the PyTorch library used. 2371 If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible. 2372 """ 2373 dependencies: Optional[FileDescr_dependencies] = None 2374 """Custom dependencies beyond pytorch described in a Conda environment file. 2375 Allows specifying custom dependencies, see conda docs: 2376 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2377 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2378 2379 The conda environment file should include pytorch and any version pinning has to be compatible with 2380 **pytorch_version**. 2381 """ 2382 2383 2384class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2385 type = "tensorflow_js" 2386 weights_format_name: ClassVar[str] = "Tensorflow.js" 2387 tensorflow_version: Version 2388 """Version of the TensorFlow library used.""" 2389 2390 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2391 """The multi-file weights. 2392 All required files/folders should be a zip archive.""" 2393 2394 2395class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2396 type = "tensorflow_saved_model_bundle" 2397 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2398 tensorflow_version: Version 2399 """Version of the TensorFlow library used.""" 2400 2401 dependencies: Optional[FileDescr_dependencies] = None 2402 """Custom dependencies beyond tensorflow. 2403 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2404 2405 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2406 """The multi-file weights.
2407 All required files/folders should be a zip archive.""" 2408 2409 2410class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2411 type = "torchscript" 2412 weights_format_name: ClassVar[str] = "TorchScript" 2413 pytorch_version: Version 2414 """Version of the PyTorch library used.""" 2415 2416 2417class WeightsDescr(Node): 2418 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2419 onnx: Optional[OnnxWeightsDescr] = None 2420 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2421 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2422 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2423 None 2424 ) 2425 torchscript: Optional[TorchscriptWeightsDescr] = None 2426 2427 @model_validator(mode="after") 2428 def check_entries(self) -> Self: 2429 entries = {wtype for wtype, entry in self if entry is not None} 2430 2431 if not entries: 2432 raise ValueError("Missing weights entry") 2433 2434 entries_wo_parent = { 2435 wtype 2436 for wtype, entry in self 2437 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2438 } 2439 if len(entries_wo_parent) != 1: 2440 issue_warning( 2441 "Exactly one weights entry may not specify the `parent` field (got" 2442 + " {value}). That entry is considered the original set of model weights." 2443 + " Other weight formats are created through conversion of the orignal or" 2444 + " already converted weights. They have to reference the weights format" 2445 + " they were converted from as their `parent`.", 2446 value=len(entries_wo_parent), 2447 field="weights", 2448 ) 2449 2450 for wtype, entry in self: 2451 if entry is None: 2452 continue 2453 2454 assert hasattr(entry, "type") 2455 assert hasattr(entry, "parent") 2456 assert wtype == entry.type 2457 if ( 2458 entry.parent is not None and entry.parent not in entries 2459 ): # self reference checked for `parent` field 2460 raise ValueError( 2461 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2462 + f" formats: {entries}" 2463 ) 2464 2465 return self 2466 2467 def __getitem__( 2468 self, 2469 key: Literal[ 2470 "keras_hdf5", 2471 "onnx", 2472 "pytorch_state_dict", 2473 "tensorflow_js", 2474 "tensorflow_saved_model_bundle", 2475 "torchscript", 2476 ], 2477 ): 2478 if key == "keras_hdf5": 2479 ret = self.keras_hdf5 2480 elif key == "onnx": 2481 ret = self.onnx 2482 elif key == "pytorch_state_dict": 2483 ret = self.pytorch_state_dict 2484 elif key == "tensorflow_js": 2485 ret = self.tensorflow_js 2486 elif key == "tensorflow_saved_model_bundle": 2487 ret = self.tensorflow_saved_model_bundle 2488 elif key == "torchscript": 2489 ret = self.torchscript 2490 else: 2491 raise KeyError(key) 2492 2493 if ret is None: 2494 raise KeyError(key) 2495 2496 return ret 2497 2498 @property 2499 def available_formats(self): 2500 return { 2501 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2502 **({} if self.onnx is None else {"onnx": self.onnx}), 2503 **( 2504 {} 2505 if self.pytorch_state_dict is None 2506 else {"pytorch_state_dict": self.pytorch_state_dict} 2507 ), 2508 **( 2509 {} 2510 if self.tensorflow_js is None 2511 else {"tensorflow_js": self.tensorflow_js} 2512 ), 2513 **( 2514 {} 2515 if self.tensorflow_saved_model_bundle is None 2516 else { 2517 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2518 } 2519 ), 2520 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2521 } 2522 2523 @property 2524 def missing_formats(self): 2525 return { 2526 wf for wf in 
get_args(WeightsFormat) if wf not in self.available_formats 2527 } 2528 2529 2530class ModelId(ResourceId): 2531 pass 2532 2533 2534class LinkedModel(LinkedResourceBase): 2535 """Reference to a bioimage.io model.""" 2536 2537 id: ModelId 2538 """A valid model `id` from the bioimage.io collection.""" 2539 2540 2541class _DataDepSize(NamedTuple): 2542 min: StrictInt 2543 max: Optional[StrictInt] 2544 2545 2546class _AxisSizes(NamedTuple): 2547 """the lenghts of all axes of model inputs and outputs""" 2548 2549 inputs: Dict[Tuple[TensorId, AxisId], int] 2550 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2551 2552 2553class _TensorSizes(NamedTuple): 2554 """_AxisSizes as nested dicts""" 2555 2556 inputs: Dict[TensorId, Dict[AxisId, int]] 2557 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2558 2559 2560class ReproducibilityTolerance(Node, extra="allow"): 2561 """Describes what small numerical differences -- if any -- may be tolerated 2562 in the generated output when executing in different environments. 2563 2564 A tensor element *output* is considered mismatched to the **test_tensor** if 2565 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2566 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2567 2568 Motivation: 2569 For testing we can request the respective deep learning frameworks to be as 2570 reproducible as possible by setting seeds and chosing deterministic algorithms, 2571 but differences in operating systems, available hardware and installed drivers 2572 may still lead to numerical differences. 2573 """ 2574 2575 relative_tolerance: RelativeTolerance = 1e-3 2576 """Maximum relative tolerance of reproduced test tensor.""" 2577 2578 absolute_tolerance: AbsoluteTolerance = 1e-4 2579 """Maximum absolute tolerance of reproduced test tensor.""" 2580 2581 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2582 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2583 2584 output_ids: Sequence[TensorId] = () 2585 """Limits the output tensor IDs these reproducibility details apply to.""" 2586 2587 weights_formats: Sequence[WeightsFormat] = () 2588 """Limits the weights formats these details apply to.""" 2589 2590 2591class BioimageioConfig(Node, extra="allow"): 2592 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2593 """Tolerances to allow when reproducing the model's test outputs 2594 from the model's test inputs. 2595 Only the first entry matching tensor id and weights format is considered. 2596 """ 2597 2598 2599class Config(Node, extra="allow"): 2600 bioimageio: BioimageioConfig = Field( 2601 default_factory=BioimageioConfig.model_construct 2602 ) 2603 2604 2605class ModelDescr(GenericModelDescrBase): 2606 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2607 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2608 """ 2609 2610 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5" 2611 if TYPE_CHECKING: 2612 format_version: Literal["0.5.5"] = "0.5.5" 2613 else: 2614 format_version: Literal["0.5.5"] 2615 """Version of the bioimage.io model description specification used. 2616 When creating a new model always use the latest micro/patch version described here. 
2617 The `format_version` is important for any consumer software to understand how to parse the fields. 2618 """ 2619 2620 implemented_type: ClassVar[Literal["model"]] = "model" 2621 if TYPE_CHECKING: 2622 type: Literal["model"] = "model" 2623 else: 2624 type: Literal["model"] 2625 """Specialized resource type 'model'""" 2626 2627 id: Optional[ModelId] = None 2628 """bioimage.io-wide unique resource identifier 2629 assigned by bioimage.io; version **un**specific.""" 2630 2631 authors: FAIR[List[Author]] = Field( 2632 default_factory=cast(Callable[[], List[Author]], list) 2633 ) 2634 """The authors are the creators of the model RDF and the primary points of contact.""" 2635 2636 documentation: FAIR[Optional[FileSource_documentation]] = None 2637 """URL or relative path to a markdown file with additional documentation. 2638 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2639 The documentation should include a '#[#] Validation' (sub)section 2640 with details on how to quantitatively validate the model on unseen data.""" 2641 2642 @field_validator("documentation", mode="after") 2643 @classmethod 2644 def _validate_documentation( 2645 cls, value: Optional[FileSource_documentation] 2646 ) -> Optional[FileSource_documentation]: 2647 if not get_validation_context().perform_io_checks or value is None: 2648 return value 2649 2650 doc_reader = get_reader(value) 2651 doc_content = doc_reader.read().decode(encoding="utf-8") 2652 if not re.search("#.*[vV]alidation", doc_content): 2653 issue_warning( 2654 "No '# Validation' (sub)section found in {value}.", 2655 value=value, 2656 field="documentation", 2657 ) 2658 2659 return value 2660 2661 inputs: NotEmpty[Sequence[InputTensorDescr]] 2662 """Describes the input tensors expected by this model.""" 2663 2664 @field_validator("inputs", mode="after") 2665 @classmethod 2666 def _validate_input_axes( 2667 cls, inputs: Sequence[InputTensorDescr] 2668 ) -> Sequence[InputTensorDescr]: 2669 input_size_refs = cls._get_axes_with_independent_size(inputs) 2670 2671 for i, ipt in enumerate(inputs): 2672 valid_independent_refs: Dict[ 2673 Tuple[TensorId, AxisId], 2674 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2675 ] = { 2676 **{ 2677 (ipt.id, a.id): (ipt, a, a.size) 2678 for a in ipt.axes 2679 if not isinstance(a, BatchAxis) 2680 and isinstance(a.size, (int, ParameterizedSize)) 2681 }, 2682 **input_size_refs, 2683 } 2684 for a, ax in enumerate(ipt.axes): 2685 cls._validate_axis( 2686 "inputs", 2687 i=i, 2688 tensor_id=ipt.id, 2689 a=a, 2690 axis=ax, 2691 valid_independent_refs=valid_independent_refs, 2692 ) 2693 return inputs 2694 2695 @staticmethod 2696 def _validate_axis( 2697 field_name: str, 2698 i: int, 2699 tensor_id: TensorId, 2700 a: int, 2701 axis: AnyAxis, 2702 valid_independent_refs: Dict[ 2703 Tuple[TensorId, AxisId], 2704 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2705 ], 2706 ): 2707 if isinstance(axis, BatchAxis) or isinstance( 2708 axis.size, (int, ParameterizedSize, DataDependentSize) 2709 ): 2710 return 2711 elif not isinstance(axis.size, SizeReference): 2712 assert_never(axis.size) 2713 2714 # validate axis.size SizeReference 2715 ref = (axis.size.tensor_id, axis.size.axis_id) 2716 if ref not in valid_independent_refs: 2717 raise ValueError( 2718 "Invalid tensor axis reference at" 2719 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 
2720 ) 2721 if ref == (tensor_id, axis.id): 2722 raise ValueError( 2723 "Self-referencing not allowed for" 2724 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2725 ) 2726 if axis.type == "channel": 2727 if valid_independent_refs[ref][1].type != "channel": 2728 raise ValueError( 2729 "A channel axis' size may only reference another fixed size" 2730 + " channel axis." 2731 ) 2732 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2733 ref_size = valid_independent_refs[ref][2] 2734 assert isinstance(ref_size, int), ( 2735 "channel axis ref (another channel axis) has to specify fixed" 2736 + " size" 2737 ) 2738 generated_channel_names = [ 2739 Identifier(axis.channel_names.format(i=i)) 2740 for i in range(1, ref_size + 1) 2741 ] 2742 axis.channel_names = generated_channel_names 2743 2744 if (ax_unit := getattr(axis, "unit", None)) != ( 2745 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2746 ): 2747 raise ValueError( 2748 "The units of an axis and its reference axis need to match, but" 2749 + f" '{ax_unit}' != '{ref_unit}'." 2750 ) 2751 ref_axis = valid_independent_refs[ref][1] 2752 if isinstance(ref_axis, BatchAxis): 2753 raise ValueError( 2754 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2755 + " (a batch axis is not allowed as reference)." 2756 ) 2757 2758 if isinstance(axis, WithHalo): 2759 min_size = axis.size.get_size(axis, ref_axis, n=0) 2760 if (min_size - 2 * axis.halo) < 1: 2761 raise ValueError( 2762 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2763 + f" {axis.halo}." 2764 ) 2765 2766 input_halo = axis.halo * axis.scale / ref_axis.scale 2767 if input_halo != int(input_halo) or input_halo % 2 == 1: 2768 raise ValueError( 2769 f"input_halo {input_halo} (output_halo {axis.halo} *" 2770 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2771 + f" must be an even integer for {tensor_id}.{axis.id}."
2772 ) 2773 2774 @model_validator(mode="after") 2775 def _validate_test_tensors(self) -> Self: 2776 if not get_validation_context().perform_io_checks: 2777 return self 2778 2779 test_output_arrays = [ 2780 None if descr.test_tensor is None else load_array(descr.test_tensor) 2781 for descr in self.outputs 2782 ] 2783 test_input_arrays = [ 2784 None if descr.test_tensor is None else load_array(descr.test_tensor) 2785 for descr in self.inputs 2786 ] 2787 2788 tensors = { 2789 descr.id: (descr, array) 2790 for descr, array in zip( 2791 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2792 ) 2793 } 2794 validate_tensors(tensors, tensor_origin="test_tensor") 2795 2796 output_arrays = { 2797 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2798 } 2799 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2800 if not rep_tol.absolute_tolerance: 2801 continue 2802 2803 if rep_tol.output_ids: 2804 out_arrays = { 2805 oid: a 2806 for oid, a in output_arrays.items() 2807 if oid in rep_tol.output_ids 2808 } 2809 else: 2810 out_arrays = output_arrays 2811 2812 for out_id, array in out_arrays.items(): 2813 if array is None: 2814 continue 2815 2816 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2817 raise ValueError( 2818 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2819 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2820 + f" (1% of the maximum value of the test tensor '{out_id}')" 2821 ) 2822 2823 return self 2824 2825 @model_validator(mode="after") 2826 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2827 ipt_refs = {t.id for t in self.inputs} 2828 out_refs = {t.id for t in self.outputs} 2829 for ipt in self.inputs: 2830 for p in ipt.preprocessing: 2831 ref = p.kwargs.get("reference_tensor") 2832 if ref is None: 2833 continue 2834 if ref not in ipt_refs: 2835 raise ValueError( 2836 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2837 + f" references are: {ipt_refs}." 2838 ) 2839 2840 for out in self.outputs: 2841 for p in out.postprocessing: 2842 ref = p.kwargs.get("reference_tensor") 2843 if ref is None: 2844 continue 2845 2846 if ref not in ipt_refs and ref not in out_refs: 2847 raise ValueError( 2848 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2849 + f" are: {ipt_refs | out_refs}." 2850 ) 2851 2852 return self 2853 2854 # TODO: use validate funcs in validate_test_tensors 2855 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2856 2857 name: Annotated[ 2858 str, 2859 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 2860 MinLen(5), 2861 MaxLen(128), 2862 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2863 ] 2864 """A human-readable name of this model. 2865 It should be no longer than 64 characters 2866 and may only contain letter, number, underscore, minus, parentheses and spaces. 2867 We recommend to chose a name that refers to the model's task and image modality. 
2868 """ 2869 2870 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2871 """Describes the output tensors.""" 2872 2873 @field_validator("outputs", mode="after") 2874 @classmethod 2875 def _validate_tensor_ids( 2876 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2877 ) -> Sequence[OutputTensorDescr]: 2878 tensor_ids = [ 2879 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2880 ] 2881 duplicate_tensor_ids: List[str] = [] 2882 seen: Set[str] = set() 2883 for t in tensor_ids: 2884 if t in seen: 2885 duplicate_tensor_ids.append(t) 2886 2887 seen.add(t) 2888 2889 if duplicate_tensor_ids: 2890 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2891 2892 return outputs 2893 2894 @staticmethod 2895 def _get_axes_with_parameterized_size( 2896 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2897 ): 2898 return { 2899 f"{t.id}.{a.id}": (t, a, a.size) 2900 for t in io 2901 for a in t.axes 2902 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2903 } 2904 2905 @staticmethod 2906 def _get_axes_with_independent_size( 2907 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2908 ): 2909 return { 2910 (t.id, a.id): (t, a, a.size) 2911 for t in io 2912 for a in t.axes 2913 if not isinstance(a, BatchAxis) 2914 and isinstance(a.size, (int, ParameterizedSize)) 2915 } 2916 2917 @field_validator("outputs", mode="after") 2918 @classmethod 2919 def _validate_output_axes( 2920 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2921 ) -> List[OutputTensorDescr]: 2922 input_size_refs = cls._get_axes_with_independent_size( 2923 info.data.get("inputs", []) 2924 ) 2925 output_size_refs = cls._get_axes_with_independent_size(outputs) 2926 2927 for i, out in enumerate(outputs): 2928 valid_independent_refs: Dict[ 2929 Tuple[TensorId, AxisId], 2930 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2931 ] = { 2932 **{ 2933 (out.id, a.id): (out, a, a.size) 2934 for a in out.axes 2935 if not isinstance(a, BatchAxis) 2936 and isinstance(a.size, (int, ParameterizedSize)) 2937 }, 2938 **input_size_refs, 2939 **output_size_refs, 2940 } 2941 for a, ax in enumerate(out.axes): 2942 cls._validate_axis( 2943 "outputs", 2944 i, 2945 out.id, 2946 a, 2947 ax, 2948 valid_independent_refs=valid_independent_refs, 2949 ) 2950 2951 return outputs 2952 2953 packaged_by: List[Author] = Field( 2954 default_factory=cast(Callable[[], List[Author]], list) 2955 ) 2956 """The persons that have packaged and uploaded this model. 2957 Only required if those persons differ from the `authors`.""" 2958 2959 parent: Optional[LinkedModel] = None 2960 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2961 2962 @model_validator(mode="after") 2963 def _validate_parent_is_not_self(self) -> Self: 2964 if self.parent is not None and self.parent.id == self.id: 2965 raise ValueError("A model description may not reference itself as parent.") 2966 2967 return self 2968 2969 run_mode: Annotated[ 2970 Optional[RunMode], 2971 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2972 ] = None 2973 """Custom run mode for this model: for more complex prediction procedures like test time 2974 data augmentation that currently cannot be expressed in the specification. 
2975 No standard run modes are defined yet.""" 2976 2977 timestamp: Datetime = Field(default_factory=Datetime.now) 2978 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2979 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2980 (In Python a datetime object is valid, too).""" 2981 2982 training_data: Annotated[ 2983 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2984 Field(union_mode="left_to_right"), 2985 ] = None 2986 """The dataset used to train this model""" 2987 2988 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2989 """The weights for this model. 2990 Weights can be given for different formats, but should otherwise be equivalent. 2991 The available weight formats determine which consumers can use this model.""" 2992 2993 config: Config = Field(default_factory=Config.model_construct) 2994 2995 @model_validator(mode="after") 2996 def _add_default_cover(self) -> Self: 2997 if not get_validation_context().perform_io_checks or self.covers: 2998 return self 2999 3000 try: 3001 generated_covers = generate_covers( 3002 [ 3003 (t, load_array(t.test_tensor)) 3004 for t in self.inputs 3005 if t.test_tensor is not None 3006 ], 3007 [ 3008 (t, load_array(t.test_tensor)) 3009 for t in self.outputs 3010 if t.test_tensor is not None 3011 ], 3012 ) 3013 except Exception as e: 3014 issue_warning( 3015 "Failed to generate cover image(s): {e}", 3016 value=self.covers, 3017 msg_context=dict(e=e), 3018 field="covers", 3019 ) 3020 else: 3021 self.covers.extend(generated_covers) 3022 3023 return self 3024 3025 def get_input_test_arrays(self) -> List[NDArray[Any]]: 3026 return self._get_test_arrays(self.inputs) 3027 3028 def get_output_test_arrays(self) -> List[NDArray[Any]]: 3029 return self._get_test_arrays(self.outputs) 3030 3031 @staticmethod 3032 def _get_test_arrays( 3033 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 3034 ): 3035 ts: List[FileDescr] = [] 3036 for d in io_descr: 3037 if d.test_tensor is None: 3038 raise ValueError( 3039 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 3040 ) 3041 ts.append(d.test_tensor) 3042 3043 data = [load_array(t) for t in ts] 3044 assert all(isinstance(d, np.ndarray) for d in data) 3045 return data 3046 3047 @staticmethod 3048 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3049 batch_size = 1 3050 tensor_with_batchsize: Optional[TensorId] = None 3051 for tid in tensor_sizes: 3052 for aid, s in tensor_sizes[tid].items(): 3053 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3054 continue 3055 3056 if batch_size != 1: 3057 assert tensor_with_batchsize is not None 3058 raise ValueError( 3059 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3060 ) 3061 3062 batch_size = s 3063 tensor_with_batchsize = tid 3064 3065 return batch_size 3066 3067 def get_output_tensor_sizes( 3068 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3069 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3070 """Returns the tensor output sizes for given **input_sizes**. 3071 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
3072 Otherwise it might be larger than the actual (valid) output""" 3073 batch_size = self.get_batch_size(input_sizes) 3074 ns = self.get_ns(input_sizes) 3075 3076 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3077 return tensor_sizes.outputs 3078 3079 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3080 """get parameter `n` for each parameterized axis 3081 such that the valid input size is >= the given input size""" 3082 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3083 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3084 for tid in input_sizes: 3085 for aid, s in input_sizes[tid].items(): 3086 size_descr = axes[tid][aid].size 3087 if isinstance(size_descr, ParameterizedSize): 3088 ret[(tid, aid)] = size_descr.get_n(s) 3089 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3090 pass 3091 else: 3092 assert_never(size_descr) 3093 3094 return ret 3095 3096 def get_tensor_sizes( 3097 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3098 ) -> _TensorSizes: 3099 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3100 return _TensorSizes( 3101 { 3102 t: { 3103 aa: axis_sizes.inputs[(tt, aa)] 3104 for tt, aa in axis_sizes.inputs 3105 if tt == t 3106 } 3107 for t in {tt for tt, _ in axis_sizes.inputs} 3108 }, 3109 { 3110 t: { 3111 aa: axis_sizes.outputs[(tt, aa)] 3112 for tt, aa in axis_sizes.outputs 3113 if tt == t 3114 } 3115 for t in {tt for tt, _ in axis_sizes.outputs} 3116 }, 3117 ) 3118 3119 def get_axis_sizes( 3120 self, 3121 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3122 batch_size: Optional[int] = None, 3123 *, 3124 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3125 ) -> _AxisSizes: 3126 """Determine input and output block shape for scale factors **ns** 3127 of parameterized input sizes. 3128 3129 Args: 3130 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3131 that is parameterized as `size = min + n * step`. 3132 batch_size: The desired size of the batch dimension. 3133 If given **batch_size** overwrites any batch size present in 3134 **max_input_shape**. Default 1. 3135 max_input_shape: Limits the derived block shapes. 3136 Each axis for which the input size, parameterized by `n`, is larger 3137 than **max_input_shape** is set to the minimal value `n_min` for which 3138 this is still true. 3139 Use this for small input samples or large values of **ns**. 3140 Or simply whenever you know the full input shape. 3141 3142 Returns: 3143 Resolved axis sizes for model inputs and outputs. 
3144 """ 3145 max_input_shape = max_input_shape or {} 3146 if batch_size is None: 3147 for (_t_id, a_id), s in max_input_shape.items(): 3148 if a_id == BATCH_AXIS_ID: 3149 batch_size = s 3150 break 3151 else: 3152 batch_size = 1 3153 3154 all_axes = { 3155 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3156 } 3157 3158 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3159 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3160 3161 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3162 if isinstance(a, BatchAxis): 3163 if (t_descr.id, a.id) in ns: 3164 logger.warning( 3165 "Ignoring unexpected size increment factor (n) for batch axis" 3166 + " of tensor '{}'.", 3167 t_descr.id, 3168 ) 3169 return batch_size 3170 elif isinstance(a.size, int): 3171 if (t_descr.id, a.id) in ns: 3172 logger.warning( 3173 "Ignoring unexpected size increment factor (n) for fixed size" 3174 + " axis '{}' of tensor '{}'.", 3175 a.id, 3176 t_descr.id, 3177 ) 3178 return a.size 3179 elif isinstance(a.size, ParameterizedSize): 3180 if (t_descr.id, a.id) not in ns: 3181 raise ValueError( 3182 "Size increment factor (n) missing for parametrized axis" 3183 + f" '{a.id}' of tensor '{t_descr.id}'." 3184 ) 3185 n = ns[(t_descr.id, a.id)] 3186 s_max = max_input_shape.get((t_descr.id, a.id)) 3187 if s_max is not None: 3188 n = min(n, a.size.get_n(s_max)) 3189 3190 return a.size.get_size(n) 3191 3192 elif isinstance(a.size, SizeReference): 3193 if (t_descr.id, a.id) in ns: 3194 logger.warning( 3195 "Ignoring unexpected size increment factor (n) for axis '{}'" 3196 + " of tensor '{}' with size reference.", 3197 a.id, 3198 t_descr.id, 3199 ) 3200 assert not isinstance(a, BatchAxis) 3201 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3202 assert not isinstance(ref_axis, BatchAxis) 3203 ref_key = (a.size.tensor_id, a.size.axis_id) 3204 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3205 assert ref_size is not None, ref_key 3206 assert not isinstance(ref_size, _DataDepSize), ref_key 3207 return a.size.get_size( 3208 axis=a, 3209 ref_axis=ref_axis, 3210 ref_size=ref_size, 3211 ) 3212 elif isinstance(a.size, DataDependentSize): 3213 if (t_descr.id, a.id) in ns: 3214 logger.warning( 3215 "Ignoring unexpected increment factor (n) for data dependent" 3216 + " size axis '{}' of tensor '{}'.", 3217 a.id, 3218 t_descr.id, 3219 ) 3220 return _DataDepSize(a.size.min, a.size.max) 3221 else: 3222 assert_never(a.size) 3223 3224 # first resolve all , but the `SizeReference` input sizes 3225 for t_descr in self.inputs: 3226 for a in t_descr.axes: 3227 if not isinstance(a.size, SizeReference): 3228 s = get_axis_size(a) 3229 assert not isinstance(s, _DataDepSize) 3230 inputs[t_descr.id, a.id] = s 3231 3232 # resolve all other input axis sizes 3233 for t_descr in self.inputs: 3234 for a in t_descr.axes: 3235 if isinstance(a.size, SizeReference): 3236 s = get_axis_size(a) 3237 assert not isinstance(s, _DataDepSize) 3238 inputs[t_descr.id, a.id] = s 3239 3240 # resolve all output axis sizes 3241 for t_descr in self.outputs: 3242 for a in t_descr.axes: 3243 assert not isinstance(a.size, ParameterizedSize) 3244 s = get_axis_size(a) 3245 outputs[t_descr.id, a.id] = s 3246 3247 return _AxisSizes(inputs=inputs, outputs=outputs) 3248 3249 @model_validator(mode="before") 3250 @classmethod 3251 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3252 cls.convert_from_old_format_wo_validation(data) 3253 return data 3254 3255 @classmethod 3256 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3257 """Convert metadata following an older format version to this classes' format 3258 without validating the result. 3259 """ 3260 if ( 3261 data.get("type") == "model" 3262 and isinstance(fv := data.get("format_version"), str) 3263 and fv.count(".") == 2 3264 ): 3265 fv_parts = fv.split(".") 3266 if any(not p.isdigit() for p in fv_parts): 3267 return 3268 3269 fv_tuple = tuple(map(int, fv_parts)) 3270 3271 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3272 if fv_tuple[:2] in ((0, 3), (0, 4)): 3273 m04 = _ModelDescr_v0_4.load(data) 3274 if isinstance(m04, InvalidDescr): 3275 try: 3276 updated = _model_conv.convert_as_dict( 3277 m04 # pyright: ignore[reportArgumentType] 3278 ) 3279 except Exception as e: 3280 logger.error( 3281 "Failed to convert from invalid model 0.4 description." 3282 + f"\nerror: {e}" 3283 + "\nProceeding with model 0.5 validation without conversion." 3284 ) 3285 updated = None 3286 else: 3287 updated = _model_conv.convert_as_dict(m04) 3288 3289 if updated is not None: 3290 data.clear() 3291 data.update(updated) 3292 3293 elif fv_tuple[:2] == (0, 5): 3294 # bump patch version 3295 data["format_version"] = cls.implemented_format_version 3296 3297 3298class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 3299 def _convert( 3300 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 3301 ) -> "ModelDescr | dict[str, Any]": 3302 name = "".join( 3303 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 3304 for c in src.name 3305 ) 3306 3307 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 3308 conv = ( 3309 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 3310 ) 3311 return None if auths is None else [conv(a) for a in auths] 3312 3313 if TYPE_CHECKING: 3314 arch_file_conv = _arch_file_conv.convert 3315 arch_lib_conv = _arch_lib_conv.convert 3316 else: 3317 arch_file_conv = _arch_file_conv.convert_as_dict 3318 arch_lib_conv = _arch_lib_conv.convert_as_dict 3319 3320 input_size_refs = { 3321 ipt.name: { 3322 a: s 3323 for a, s in zip( 3324 ipt.axes, 3325 ( 3326 ipt.shape.min 3327 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 3328 else ipt.shape 3329 ), 3330 ) 3331 } 3332 for ipt in src.inputs 3333 if ipt.shape 3334 } 3335 output_size_refs = { 3336 **{ 3337 out.name: {a: s for a, s in zip(out.axes, out.shape)} 3338 for out in src.outputs 3339 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 3340 }, 3341 **input_size_refs, 3342 } 3343 3344 return tgt( 3345 attachments=( 3346 [] 3347 if src.attachments is None 3348 else [FileDescr(source=f) for f in src.attachments.files] 3349 ), 3350 authors=[ 3351 _author_conv.convert_as_dict(a) for a in src.authors 3352 ], # pyright: ignore[reportArgumentType] 3353 cite=[ 3354 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 3355 ], # pyright: ignore[reportArgumentType] 3356 config=src.config, # pyright: ignore[reportArgumentType] 3357 covers=src.covers, 3358 description=src.description, 3359 documentation=src.documentation, 3360 format_version="0.5.5", 3361 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3362 icon=src.icon, 3363 id=None if src.id is None else ModelId(src.id), 3364 id_emoji=src.id_emoji, 3365 license=src.license, # type: ignore 3366 links=src.links, 3367 maintainers=[ 3368 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 3369 ], # pyright: ignore[reportArgumentType] 3370 name=name, 3371 tags=src.tags, 3372 type=src.type, 3373 uploader=src.uploader, 3374 
version=src.version, 3375 inputs=[ # pyright: ignore[reportArgumentType] 3376 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3377 for ipt, tt, st, in zip( 3378 src.inputs, 3379 src.test_inputs, 3380 src.sample_inputs or [None] * len(src.test_inputs), 3381 ) 3382 ], 3383 outputs=[ # pyright: ignore[reportArgumentType] 3384 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3385 for out, tt, st, in zip( 3386 src.outputs, 3387 src.test_outputs, 3388 src.sample_outputs or [None] * len(src.test_outputs), 3389 ) 3390 ], 3391 parent=( 3392 None 3393 if src.parent is None 3394 else LinkedModel( 3395 id=ModelId( 3396 str(src.parent.id) 3397 + ( 3398 "" 3399 if src.parent.version_number is None 3400 else f"/{src.parent.version_number}" 3401 ) 3402 ) 3403 ) 3404 ), 3405 training_data=( 3406 None 3407 if src.training_data is None 3408 else ( 3409 LinkedDataset( 3410 id=DatasetId( 3411 str(src.training_data.id) 3412 + ( 3413 "" 3414 if src.training_data.version_number is None 3415 else f"/{src.training_data.version_number}" 3416 ) 3417 ) 3418 ) 3419 if isinstance(src.training_data, LinkedDataset02) 3420 else src.training_data 3421 ) 3422 ), 3423 packaged_by=[ 3424 _author_conv.convert_as_dict(a) for a in src.packaged_by 3425 ], # pyright: ignore[reportArgumentType] 3426 run_mode=src.run_mode, 3427 timestamp=src.timestamp, 3428 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3429 keras_hdf5=(w := src.weights.keras_hdf5) 3430 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3431 authors=conv_authors(w.authors), 3432 source=w.source, 3433 tensorflow_version=w.tensorflow_version or Version("1.15"), 3434 parent=w.parent, 3435 ), 3436 onnx=(w := src.weights.onnx) 3437 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3438 source=w.source, 3439 authors=conv_authors(w.authors), 3440 parent=w.parent, 3441 opset_version=w.opset_version or 15, 3442 ), 3443 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3444 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3445 source=w.source, 3446 authors=conv_authors(w.authors), 3447 parent=w.parent, 3448 architecture=( 3449 arch_file_conv( 3450 w.architecture, 3451 w.architecture_sha256, 3452 w.kwargs, 3453 ) 3454 if isinstance(w.architecture, _CallableFromFile_v0_4) 3455 else arch_lib_conv(w.architecture, w.kwargs) 3456 ), 3457 pytorch_version=w.pytorch_version or Version("1.10"), 3458 dependencies=( 3459 None 3460 if w.dependencies is None 3461 else (FileDescr if TYPE_CHECKING else dict)( 3462 source=cast( 3463 FileSource, 3464 str(deps := w.dependencies)[ 3465 ( 3466 len("conda:") 3467 if str(deps).startswith("conda:") 3468 else 0 3469 ) : 3470 ], 3471 ) 3472 ) 3473 ), 3474 ), 3475 tensorflow_js=(w := src.weights.tensorflow_js) 3476 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3477 source=w.source, 3478 authors=conv_authors(w.authors), 3479 parent=w.parent, 3480 tensorflow_version=w.tensorflow_version or Version("1.15"), 3481 ), 3482 tensorflow_saved_model_bundle=( 3483 w := src.weights.tensorflow_saved_model_bundle 3484 ) 3485 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3486 authors=conv_authors(w.authors), 3487 parent=w.parent, 3488 source=w.source, 3489 tensorflow_version=w.tensorflow_version or Version("1.15"), 3490 dependencies=( 3491 None 3492 if w.dependencies is None 3493 else (FileDescr if TYPE_CHECKING else dict)( 3494 source=cast( 3495 FileSource, 3496 ( 3497 str(w.dependencies)[len("conda:") :] 3498 if str(w.dependencies).startswith("conda:") 3499 else 
str(w.dependencies) 3500 ), 3501 ) 3502 ) 3503 ), 3504 ), 3505 torchscript=(w := src.weights.torchscript) 3506 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3507 source=w.source, 3508 authors=conv_authors(w.authors), 3509 parent=w.parent, 3510 pytorch_version=w.pytorch_version or Version("1.10"), 3511 ), 3512 ), 3513 ) 3514 3515 3516_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3517 3518 3519# create better cover images for 3d data and non-image outputs 3520def generate_covers( 3521 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3522 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3523) -> List[Path]: 3524 def squeeze( 3525 data: NDArray[Any], axes: Sequence[AnyAxis] 3526 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3527 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3528 if data.ndim != len(axes): 3529 raise ValueError( 3530 f"tensor shape {data.shape} does not match described axes" 3531 + f" {[a.id for a in axes]}" 3532 ) 3533 3534 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3535 return data.squeeze(), axes 3536 3537 def normalize( 3538 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3539 ) -> NDArray[np.float32]: 3540 data = data.astype("float32") 3541 data -= data.min(axis=axis, keepdims=True) 3542 data /= data.max(axis=axis, keepdims=True) + eps 3543 return data 3544 3545 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3546 original_shape = data.shape 3547 data, axes = squeeze(data, axes) 3548 3549 # take slice fom any batch or index axis if needed 3550 # and convert the first channel axis and take a slice from any additional channel axes 3551 slices: Tuple[slice, ...] = () 3552 ndim = data.ndim 3553 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3554 has_c_axis = False 3555 for i, a in enumerate(axes): 3556 s = data.shape[i] 3557 assert s > 1 3558 if ( 3559 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3560 and ndim > ndim_need 3561 ): 3562 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3563 ndim -= 1 3564 elif isinstance(a, ChannelAxis): 3565 if has_c_axis: 3566 # second channel axis 3567 data = data[slices + (slice(0, 1),)] 3568 ndim -= 1 3569 else: 3570 has_c_axis = True 3571 if s == 2: 3572 # visualize two channels with cyan and magenta 3573 data = np.concatenate( 3574 [ 3575 data[slices + (slice(1, 2),)], 3576 data[slices + (slice(0, 1),)], 3577 ( 3578 data[slices + (slice(0, 1),)] 3579 + data[slices + (slice(1, 2),)] 3580 ) 3581 / 2, # TODO: take maximum instead? 
3582 ], 3583 axis=i, 3584 ) 3585 elif data.shape[i] == 3: 3586 pass # visualize 3 channels as RGB 3587 else: 3588 # visualize first 3 channels as RGB 3589 data = data[slices + (slice(3),)] 3590 3591 assert data.shape[i] == 3 3592 3593 slices += (slice(None),) 3594 3595 data, axes = squeeze(data, axes) 3596 assert len(axes) == ndim 3597 # take slice from z axis if needed 3598 slices = () 3599 if ndim > ndim_need: 3600 for i, a in enumerate(axes): 3601 s = data.shape[i] 3602 if a.id == AxisId("z"): 3603 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3604 data, axes = squeeze(data, axes) 3605 ndim -= 1 3606 break 3607 3608 slices += (slice(None),) 3609 3610 # take slice from any space or time axis 3611 slices = () 3612 3613 for i, a in enumerate(axes): 3614 if ndim <= ndim_need: 3615 break 3616 3617 s = data.shape[i] 3618 assert s > 1 3619 if isinstance( 3620 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3621 ): 3622 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3623 ndim -= 1 3624 3625 slices += (slice(None),) 3626 3627 del slices 3628 data, axes = squeeze(data, axes) 3629 assert len(axes) == ndim 3630 3631 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 3632 raise ValueError( 3633 f"Failed to construct cover image from shape {original_shape}" 3634 ) 3635 3636 if not has_c_axis: 3637 assert ndim == 2 3638 data = np.repeat(data[:, :, None], 3, axis=2) 3639 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3640 ndim += 1 3641 3642 assert ndim == 3 3643 3644 # transpose axis order such that longest axis comes first... 3645 axis_order: List[int] = list(np.argsort(list(data.shape))) 3646 axis_order.reverse() 3647 # ... and channel axis is last 3648 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3649 axis_order.append(axis_order.pop(c)) 3650 axes = [axes[ao] for ao in axis_order] 3651 data = data.transpose(axis_order) 3652 3653 # h, w = data.shape[:2] 3654 # if h / w in (1.0 or 2.0): 3655 # pass 3656 # elif h / w < 2: 3657 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3658 3659 norm_along = ( 3660 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3661 ) 3662 # normalize the data and map to 8 bit 3663 data = normalize(data, norm_along) 3664 data = (data * 255).astype("uint8") 3665 3666 return data 3667 3668 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3669 assert im0.dtype == im1.dtype == np.uint8 3670 assert im0.shape == im1.shape 3671 assert im0.ndim == 3 3672 N, M, C = im0.shape 3673 assert C == 3 3674 out = np.ones((N, M, C), dtype="uint8") 3675 for c in range(C): 3676 outc = np.tril(im0[..., c]) 3677 mask = outc == 0 3678 outc[mask] = np.triu(im1[..., c])[mask] 3679 out[..., c] = outc 3680 3681 return out 3682 3683 if not inputs: 3684 raise ValueError("Missing test input tensor for cover generation.") 3685 3686 if not outputs: 3687 raise ValueError("Missing test output tensor for cover generation.") 3688 3689 ipt_descr, ipt = inputs[0] 3690 out_descr, out = outputs[0] 3691 3692 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3693 out_img = to_2d_image(out, out_descr.axes) 3694 3695 cover_folder = Path(mkdtemp()) 3696 if ipt_img.shape == out_img.shape: 3697 covers = [cover_folder / "cover.png"] 3698 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3699 else: 3700 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3701 imwrite(covers[0], ipt_img) 3702 imwrite(covers[1], out_img) 3703 3704 return covers
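A minimal usage sketch of the static ModelDescr.get_batch_size helper defined above; the tensor and axis ids below are made up for illustration:

from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

# hypothetical axis sizes for a single input tensor with a batch axis of size 2
tensor_sizes = {
    TensorId("raw"): {AxisId("batch"): 2, AxisId("y"): 256, AxisId("x"): 256},
}

# the batch size is read off any axis with id 'batch'; sizes of 1 are ignored and
# conflicting batch sizes across tensors raise a ValueError (see get_batch_size above)
assert ModelDescr.get_batch_size(tensor_sizes) == 2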
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
229class TensorId(LowerCaseIdentifier): 230 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 231 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 232 ]
245class AxisId(LowerCaseIdentifier): 246 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 247 Annotated[ 248 LowerCaseIdentifierAnno, 249 MaxLen(16), 250 AfterValidator(_normalize_axis_id), 251 ] 252 ]
Annotates an integer to calculate a concrete axis size from a ParameterizedSize.
298class ParameterizedSize(Node): 299 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 300 301 - **min** and **step** are given by the model description. 302 - All blocksize paramters n = 0,1,2,... yield a valid `size`. 303 - A greater blocksize paramter n = 0,1,2,... results in a greater **size**. 304 This allows to adjust the axis size more generically. 305 """ 306 307 N: ClassVar[Type[int]] = ParameterizedSize_N 308 """Positive integer to parameterize this axis""" 309 310 min: Annotated[int, Gt(0)] 311 step: Annotated[int, Gt(0)] 312 313 def validate_size(self, size: int) -> int: 314 if size < self.min: 315 raise ValueError(f"size {size} < {self.min}") 316 if (size - self.min) % self.step != 0: 317 raise ValueError( 318 f"axis of size {size} is not parameterized by `min + n*step` =" 319 + f" `{self.min} + n*{self.step}`" 320 ) 321 322 return size 323 324 def get_size(self, n: ParameterizedSize_N) -> int: 325 return self.min + self.step * n 326 327 def get_n(self, s: int) -> ParameterizedSize_N: 328 """return smallest n parameterizing a size greater or equal than `s`""" 329 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as size = min + n*step.
- min and step are given by the model description.
- All block size parameters n = 0, 1, 2, ... yield a valid size.
- A greater block size parameter n results in a greater size. This allows adjusting the axis size more generically.
313 def validate_size(self, size: int) -> int: 314 if size < self.min: 315 raise ValueError(f"size {size} < {self.min}") 316 if (size - self.min) % self.step != 0: 317 raise ValueError( 318 f"axis of size {size} is not parameterized by `min + n*step` =" 319 + f" `{self.min} + n*{self.step}`" 320 ) 321 322 return size
327 def get_n(self, s: int) -> ParameterizedSize_N: 328 """return smallest n parameterizing a size greater or equal than `s`""" 329 return ceil((s - self.min) / self.step)
Return the smallest n parameterizing a size greater than or equal to s.
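To make the parameterization concrete, here is a minimal sketch (assuming ParameterizedSize is imported from bioimageio.spec.model.v0_5) of how min, step, get_size and get_n relate:

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

# an axis that accepts sizes 16, 24, 32, ... (min=16, step=8)
size = ParameterizedSize(min=16, step=8)

assert size.get_size(n=2) == 32       # 16 + 2 * 8
assert size.get_n(30) == 2            # smallest n with 16 + n * 8 >= 30
assert size.validate_size(24) == 24   # 24 = 16 + 1 * 8 is a valid size
# size.validate_size(30) would raise a ValueError (30 is not min + n * step)
```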
332class DataDependentSize(Node): 333 min: Annotated[int, Gt(0)] = 1 334 max: Annotated[Optional[int], Gt(1)] = None 335 336 @model_validator(mode="after") 337 def _validate_max_gt_min(self): 338 if self.max is not None and self.min >= self.max: 339 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 340 341 return self 342 343 def validate_size(self, size: int) -> int: 344 if size < self.min: 345 raise ValueError(f"size {size} < {self.min}") 346 347 if self.max is not None and size > self.max: 348 raise ValueError(f"size {size} > {self.max}") 349 350 return size
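A minimal sketch of how such a data dependent size is declared and validated, using the fields defined above:

```python
from bioimageio.spec.model.v0_5 import DataDependentSize

# an output axis whose size is only known after inference, but bounded by [1, 1024]
size = DataDependentSize(min=1, max=1024)

assert size.validate_size(512) == 512
# size.validate_size(2048) would raise a ValueError (size > max)
```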
353class SizeReference(Node): 354 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 355 356 `axis.size = reference.size * reference.scale / axis.scale + offset` 357 358 Note: 359 1. The axis and the referenced axis need to have the same unit (or no unit). 360 2. Batch axes may not be referenced. 361 3. Fractions are rounded down. 362 4. If the reference axis is `concatenable` the referencing axis is assumed to be 363 `concatenable` as well with the same block order. 364 365 Example: 366 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 367 Let's assume that we want to express the image height h in relation to its width w 368 instead of only accepting input images of exactly 100*49 pixels 369 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 370 371 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 372 >>> h = SpaceInputAxis( 373 ... id=AxisId("h"), 374 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 375 ... unit="millimeter", 376 ... scale=4, 377 ... ) 378 >>> print(h.size.get_size(h, w)) 379 49 380 381 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 382 """ 383 384 tensor_id: TensorId 385 """tensor id of the reference axis""" 386 387 axis_id: AxisId 388 """axis id of the reference axis""" 389 390 offset: StrictInt = 0 391 392 def get_size( 393 self, 394 axis: Union[ 395 ChannelAxis, 396 IndexInputAxis, 397 IndexOutputAxis, 398 TimeInputAxis, 399 SpaceInputAxis, 400 TimeOutputAxis, 401 TimeOutputAxisWithHalo, 402 SpaceOutputAxis, 403 SpaceOutputAxisWithHalo, 404 ], 405 ref_axis: Union[ 406 ChannelAxis, 407 IndexInputAxis, 408 IndexOutputAxis, 409 TimeInputAxis, 410 SpaceInputAxis, 411 TimeOutputAxis, 412 TimeOutputAxisWithHalo, 413 SpaceOutputAxis, 414 SpaceOutputAxisWithHalo, 415 ], 416 n: ParameterizedSize_N = 0, 417 ref_size: Optional[int] = None, 418 ): 419 """Compute the concrete size for a given axis and its reference axis. 420 421 Args: 422 axis: The axis this `SizeReference` is the size of. 423 ref_axis: The reference axis to compute the size from. 424 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 425 and no fixed **ref_size** is given, 426 **n** is used to compute the size of the parameterized **ref_axis**. 427 ref_size: Overwrite the reference size instead of deriving it from 428 **ref_axis** 429 (**ref_axis.scale** is still used; any given **n** is ignored). 430 """ 431 assert ( 432 axis.size == self 433 ), "Given `axis.size` is not defined by this `SizeReference`" 434 435 assert ( 436 ref_axis.id == self.axis_id 437 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 438 439 assert axis.unit == ref_axis.unit, ( 440 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 441 f" but {axis.unit}!={ref_axis.unit}" 442 ) 443 if ref_size is None: 444 if isinstance(ref_axis.size, (int, float)): 445 ref_size = ref_axis.size 446 elif isinstance(ref_axis.size, ParameterizedSize): 447 ref_size = ref_axis.size.get_size(n) 448 elif isinstance(ref_axis.size, DataDependentSize): 449 raise ValueError( 450 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 451 ) 452 elif isinstance(ref_axis.size, SizeReference): 453 raise ValueError( 454 "Reference axis referenced in `SizeReference` may not be sized by a" 455 + " `SizeReference` itself." 
456 ) 457 else: 458 assert_never(ref_axis.size) 459 460 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 461 462 @staticmethod 463 def _get_unit( 464 axis: Union[ 465 ChannelAxis, 466 IndexInputAxis, 467 IndexOutputAxis, 468 TimeInputAxis, 469 SpaceInputAxis, 470 TimeOutputAxis, 471 TimeOutputAxisWithHalo, 472 SpaceOutputAxis, 473 SpaceOutputAxisWithHalo, 474 ], 475 ): 476 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
axis.size = reference.size * reference.scale / axis.scale + offset
Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is concatenable, the referencing axis is assumed to be concatenable as well, with the same block order.
Example:
An anisotropic input image of w*h = 100*49 pixels depicts a physical space of 200*196 mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see ParameterizedSize).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
392 def get_size( 393 self, 394 axis: Union[ 395 ChannelAxis, 396 IndexInputAxis, 397 IndexOutputAxis, 398 TimeInputAxis, 399 SpaceInputAxis, 400 TimeOutputAxis, 401 TimeOutputAxisWithHalo, 402 SpaceOutputAxis, 403 SpaceOutputAxisWithHalo, 404 ], 405 ref_axis: Union[ 406 ChannelAxis, 407 IndexInputAxis, 408 IndexOutputAxis, 409 TimeInputAxis, 410 SpaceInputAxis, 411 TimeOutputAxis, 412 TimeOutputAxisWithHalo, 413 SpaceOutputAxis, 414 SpaceOutputAxisWithHalo, 415 ], 416 n: ParameterizedSize_N = 0, 417 ref_size: Optional[int] = None, 418 ): 419 """Compute the concrete size for a given axis and its reference axis. 420 421 Args: 422 axis: The axis this `SizeReference` is the size of. 423 ref_axis: The reference axis to compute the size from. 424 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 425 and no fixed **ref_size** is given, 426 **n** is used to compute the size of the parameterized **ref_axis**. 427 ref_size: Overwrite the reference size instead of deriving it from 428 **ref_axis** 429 (**ref_axis.scale** is still used; any given **n** is ignored). 430 """ 431 assert ( 432 axis.size == self 433 ), "Given `axis.size` is not defined by this `SizeReference`" 434 435 assert ( 436 ref_axis.id == self.axis_id 437 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 438 439 assert axis.unit == ref_axis.unit, ( 440 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 441 f" but {axis.unit}!={ref_axis.unit}" 442 ) 443 if ref_size is None: 444 if isinstance(ref_axis.size, (int, float)): 445 ref_size = ref_axis.size 446 elif isinstance(ref_axis.size, ParameterizedSize): 447 ref_size = ref_axis.size.get_size(n) 448 elif isinstance(ref_axis.size, DataDependentSize): 449 raise ValueError( 450 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 451 ) 452 elif isinstance(ref_axis.size, SizeReference): 453 raise ValueError( 454 "Reference axis referenced in `SizeReference` may not be sized by a" 455 + " `SizeReference` itself." 456 ) 457 else: 458 assert_never(ref_axis.size) 459 460 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this SizeReference is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the ref_axis is parameterized (of type ParameterizedSize) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
- ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
479class AxisBase(NodeWithExplicitlySetFields): 480 id: AxisId 481 """An axis id unique across all axes of one tensor.""" 482 483 description: Annotated[str, MaxLen(128)] = "" 484 """A short description of this axis beyond its type and id."""
A short description of this axis beyond its type and id.
487class WithHalo(Node): 488 halo: Annotated[int, Ge(1)] 489 """The halo should be cropped from the output tensor to avoid boundary effects. 490 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 491 To document a halo that is already cropped by the model use `size.offset` instead.""" 492 493 size: Annotated[ 494 SizeReference, 495 Field( 496 examples=[ 497 10, 498 SizeReference( 499 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 500 ).model_dump(mode="json"), 501 ] 502 ), 503 ] 504 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. size_after_crop = size - 2 * halo.
To document a halo that is already cropped by the model use size.offset instead.
Reference to another axis with an optional offset (see SizeReference).
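As an illustration (with hypothetical tensor and axis ids), an output axis whose referenced size is 256 and whose halo is 16 keeps 256 - 2 * 16 = 224 values after cropping:

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    SizeReference,
    SpaceOutputAxisWithHalo,
    TensorId,
)

out_x = SpaceOutputAxisWithHalo(
    id=AxisId("x"),
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("x")),
    halo=16,
)
# if the referenced input axis has size 256:
# size_after_crop = 256 - 2 * out_x.halo == 224
```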
510class BatchAxis(AxisBase): 511 implemented_type: ClassVar[Literal["batch"]] = "batch" 512 if TYPE_CHECKING: 513 type: Literal["batch"] = "batch" 514 else: 515 type: Literal["batch"] 516 517 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 518 size: Optional[Literal[1]] = None 519 """The batch size may be fixed to 1, 520 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 521 522 @property 523 def scale(self): 524 return 1.0 525 526 @property 527 def concatenable(self): 528 return True 529 530 @property 531 def unit(self): 532 return None
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
535class ChannelAxis(AxisBase): 536 implemented_type: ClassVar[Literal["channel"]] = "channel" 537 if TYPE_CHECKING: 538 type: Literal["channel"] = "channel" 539 else: 540 type: Literal["channel"] 541 542 id: NonBatchAxisId = AxisId("channel") 543 544 channel_names: NotEmpty[List[Identifier]] 545 546 @property 547 def size(self) -> int: 548 return len(self.channel_names) 549 550 @property 551 def concatenable(self): 552 return False 553 554 @property 555 def scale(self) -> float: 556 return 1.0 557 558 @property 559 def unit(self): 560 return None
563class IndexAxisBase(AxisBase): 564 implemented_type: ClassVar[Literal["index"]] = "index" 565 if TYPE_CHECKING: 566 type: Literal["index"] = "index" 567 else: 568 type: Literal["index"] 569 570 id: NonBatchAxisId = AxisId("index") 571 572 @property 573 def scale(self) -> float: 574 return 1.0 575 576 @property 577 def unit(self): 578 return None
601class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 602 concatenable: bool = False 603 """If a model has a `concatenable` input axis, it can be processed blockwise, 604 splitting a longer sample axis into blocks matching its input tensor description. 605 Output axes are concatenable if they have a `SizeReference` to a concatenable 606 input axis. 607 """
If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a SizeReference to a concatenable input axis.
610class IndexOutputAxis(IndexAxisBase): 611 size: Annotated[ 612 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 613 Field( 614 examples=[ 615 10, 616 SizeReference( 617 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 618 ).model_dump(mode="json"), 619 ] 620 ), 621 ] 622 """The size/length of this axis can be specified as 623 - fixed integer 624 - reference to another axis with an optional offset (`SizeReference`) 625 - data dependent size using `DataDependentSize` (size is only known after model inference) 626 """
The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (SizeReference)
- data dependent size using DataDependentSize (size is only known after model inference)
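A minimal sketch of these three options, using the classes defined in this module (hypothetical tensor and axis ids):

```python
from bioimageio.spec.model.v0_5 import (
    AxisId,
    DataDependentSize,
    IndexOutputAxis,
    SizeReference,
    TensorId,
)

fixed = IndexOutputAxis(size=10)
referenced = IndexOutputAxis(
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("index"))
)
data_dependent = IndexOutputAxis(size=DataDependentSize(min=1))
```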
629class TimeAxisBase(AxisBase): 630 implemented_type: ClassVar[Literal["time"]] = "time" 631 if TYPE_CHECKING: 632 type: Literal["time"] = "time" 633 else: 634 type: Literal["time"] 635 636 id: NonBatchAxisId = AxisId("time") 637 unit: Optional[TimeUnit] = None 638 scale: Annotated[float, Gt(0)] = 1.0
641class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 642 concatenable: bool = False 643 """If a model has a `concatenable` input axis, it can be processed blockwise, 644 splitting a longer sample axis into blocks matching its input tensor description. 645 Output axes are concatenable if they have a `SizeReference` to a concatenable 646 input axis. 647 """
If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a SizeReference to a concatenable input axis.
650class SpaceAxisBase(AxisBase): 651 implemented_type: ClassVar[Literal["space"]] = "space" 652 if TYPE_CHECKING: 653 type: Literal["space"] = "space" 654 else: 655 type: Literal["space"] 656 657 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 658 unit: Optional[SpaceUnit] = None 659 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
662class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 663 concatenable: bool = False 664 """If a model has a `concatenable` input axis, it can be processed blockwise, 665 splitting a longer sample axis into blocks matching its input tensor description. 666 Output axes are concatenable if they have a `SizeReference` to a concatenable 667 input axis. 668 """
If a model has a concatenable input axis, it can be processed blockwise, splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a SizeReference to a concatenable input axis.
790class NominalOrOrdinalDataDescr(Node): 791 values: TVs 792 """A fixed set of nominal or an ascending sequence of ordinal values. 793 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 794 String `values` are interpreted as labels for tensor values 0, ..., N. 795 Note: as YAML 1.2 does not natively support a "set" datatype, 796 nominal values should be given as a sequence (aka list/array) as well. 797 """ 798 799 type: Annotated[ 800 NominalOrOrdinalDType, 801 Field( 802 examples=[ 803 "float32", 804 "uint8", 805 "uint16", 806 "int64", 807 "bool", 808 ], 809 ), 810 ] = "uint8" 811 812 @model_validator(mode="after") 813 def _validate_values_match_type( 814 self, 815 ) -> Self: 816 incompatible: List[Any] = [] 817 for v in self.values: 818 if self.type == "bool": 819 if not isinstance(v, bool): 820 incompatible.append(v) 821 elif self.type in DTYPE_LIMITS: 822 if ( 823 isinstance(v, (int, float)) 824 and ( 825 v < DTYPE_LIMITS[self.type].min 826 or v > DTYPE_LIMITS[self.type].max 827 ) 828 or (isinstance(v, str) and "uint" not in self.type) 829 or (isinstance(v, float) and "int" in self.type) 830 ): 831 incompatible.append(v) 832 else: 833 incompatible.append(v) 834 835 if len(incompatible) == 5: 836 incompatible.append("...") 837 break 838 839 if incompatible: 840 raise ValueError( 841 f"data type '{self.type}' incompatible with values {incompatible}" 842 ) 843 844 return self 845 846 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 847 848 @property 849 def range(self): 850 if isinstance(self.values[0], str): 851 return 0, len(self.values) - 1 852 else: 853 return min(self.values), max(self.values)
A fixed set of nominal or an ascending sequence of ordinal values.
In this case data.type is required to be an unsigned integer type, e.g. 'uint8'.
String values are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype,
nominal values should be given as a sequence (aka list/array) as well.
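A minimal sketch of a nominal data description with string labels (hypothetical label names):

```python
from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr

# three nominal labels for tensor values 0, 1 and 2
seg_classes = NominalOrOrdinalDataDescr(
    values=["background", "cell", "membrane"],
    type="uint8",
)
assert seg_classes.range == (0, 2)
```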
870class IntervalOrRatioDataDescr(Node): 871 type: Annotated[ # TODO: rename to dtype 872 IntervalOrRatioDType, 873 Field( 874 examples=["float32", "float64", "uint8", "uint16"], 875 ), 876 ] = "float32" 877 range: Tuple[Optional[float], Optional[float]] = ( 878 None, 879 None, 880 ) 881 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 882 `None` corresponds to min/max of what can be expressed by **type**.""" 883 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 884 scale: float = 1.0 885 """Scale for data on an interval (or ratio) scale.""" 886 offset: Optional[float] = None 887 """Offset for data on a ratio scale.""" 888 889 @model_validator(mode="before") 890 def _replace_inf(cls, data: Any): 891 if is_dict(data): 892 if "range" in data and is_sequence(data["range"]): 893 forbidden = ( 894 "inf", 895 "-inf", 896 ".inf", 897 "-.inf", 898 float("inf"), 899 float("-inf"), 900 ) 901 if any(v in forbidden for v in data["range"]): 902 issue_warning("replaced 'inf' value", value=data["range"]) 903 904 data["range"] = tuple( 905 (None if v in forbidden else v) for v in data["range"] 906 ) 907 908 return data
Tuple (minimum, maximum) specifying the allowed range of the data in this tensor.
None corresponds to min/max of what can be expressed by type.
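For example, float32 data restricted to the unit interval could be described as (a minimal sketch):

```python
from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr

data = IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0))
assert data.range == (0.0, 1.0)
```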
processing base class
918class BinarizeKwargs(ProcessingKwargs): 919 """key word arguments for `BinarizeDescr`""" 920 921 threshold: float 922 """The fixed threshold"""
Keyword arguments for BinarizeDescr
925class BinarizeAlongAxisKwargs(ProcessingKwargs): 926 """key word arguments for `BinarizeDescr`""" 927 928 threshold: NotEmpty[List[float]] 929 """The fixed threshold values along `axis`""" 930 931 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 932 """The `threshold` axis"""
Keyword arguments for BinarizeDescr
The threshold axis
935class BinarizeDescr(ProcessingDescrBase): 936 """Binarize the tensor with a fixed threshold. 937 938 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 939 will be set to one, values below the threshold to zero. 940 941 Examples: 942 - in YAML 943 ```yaml 944 postprocessing: 945 - id: binarize 946 kwargs: 947 axis: 'channel' 948 threshold: [0.25, 0.5, 0.75] 949 ``` 950 - in Python: 951 >>> postprocessing = [BinarizeDescr( 952 ... kwargs=BinarizeAlongAxisKwargs( 953 ... axis=AxisId('channel'), 954 ... threshold=[0.25, 0.5, 0.75], 955 ... ) 956 ... )] 957 """ 958 959 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 960 if TYPE_CHECKING: 961 id: Literal["binarize"] = "binarize" 962 else: 963 id: Literal["binarize"] 964 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above BinarizeKwargs.threshold / BinarizeAlongAxisKwargs.threshold will be set to one, values below the threshold to zero.
Examples:
- in YAML
    postprocessing:
      - id: binarize
        kwargs:
          axis: 'channel'
          threshold: [0.25, 0.5, 0.75]
- in Python:
    >>> postprocessing = [BinarizeDescr(
    ...     kwargs=BinarizeAlongAxisKwargs(
    ...         axis=AxisId('channel'),
    ...         threshold=[0.25, 0.5, 0.75],
    ...     )
    ... )]
967class ClipDescr(ProcessingDescrBase): 968 """Set tensor values below min to min and above max to max. 969 970 See `ScaleRangeDescr` for examples. 971 """ 972 973 implemented_id: ClassVar[Literal["clip"]] = "clip" 974 if TYPE_CHECKING: 975 id: Literal["clip"] = "clip" 976 else: 977 id: Literal["clip"] 978 979 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See ScaleRangeDescr for examples.
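Mirroring the ScaleRangeDescr examples, a minimal sketch of a clip step on its own:

```python
from bioimageio.spec.model.v0_5 import ClipDescr, ClipKwargs

# clip all tensor values into [0.0, 1.0]
postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
```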
982class EnsureDtypeKwargs(ProcessingKwargs): 983 """key word arguments for `EnsureDtypeDescr`""" 984 985 dtype: Literal[ 986 "float32", 987 "float64", 988 "uint8", 989 "int8", 990 "uint16", 991 "int16", 992 "uint32", 993 "int32", 994 "uint64", 995 "int64", 996 "bool", 997 ]
Keyword arguments for EnsureDtypeDescr
1000class EnsureDtypeDescr(ProcessingDescrBase): 1001 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 1002 1003 This can for example be used to ensure the inner neural network model gets a 1004 different input tensor data type than the fully described bioimage.io model does. 1005 1006 Examples: 1007 The described bioimage.io model (incl. preprocessing) accepts any 1008 float32-compatible tensor, normalizes it with percentiles and clipping and then 1009 casts it to uint8, which is what the neural network in this example expects. 1010 - in YAML 1011 ```yaml 1012 inputs: 1013 - data: 1014 type: float32 # described bioimage.io model is compatible with any float32 input tensor 1015 preprocessing: 1016 - id: scale_range 1017 kwargs: 1018 axes: ['y', 'x'] 1019 max_percentile: 99.8 1020 min_percentile: 5.0 1021 - id: clip 1022 kwargs: 1023 min: 0.0 1024 max: 1.0 1025 - id: ensure_dtype # the neural network of the model requires uint8 1026 kwargs: 1027 dtype: uint8 1028 ``` 1029 - in Python: 1030 >>> preprocessing = [ 1031 ... ScaleRangeDescr( 1032 ... kwargs=ScaleRangeKwargs( 1033 ... axes= (AxisId('y'), AxisId('x')), 1034 ... max_percentile= 99.8, 1035 ... min_percentile= 5.0, 1036 ... ) 1037 ... ), 1038 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1039 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1040 ... ] 1041 """ 1042 1043 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1044 if TYPE_CHECKING: 1045 id: Literal["ensure_dtype"] = "ensure_dtype" 1046 else: 1047 id: Literal["ensure_dtype"] 1048 1049 kwargs: EnsureDtypeKwargs
Cast the tensor data type to EnsureDtypeKwargs.dtype (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
    inputs:
      - data:
          type: float32  # described bioimage.io model is compatible with any float32 input tensor
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
          - id: ensure_dtype  # the neural network of the model requires uint8
            kwargs:
              dtype: uint8
- in Python:
    >>> preprocessing = [
    ...     ScaleRangeDescr(
    ...         kwargs=ScaleRangeKwargs(
    ...             axes=(AxisId('y'), AxisId('x')),
    ...             max_percentile=99.8,
    ...             min_percentile=5.0,
    ...         )
    ...     ),
    ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
    ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
    ... ]
1052class ScaleLinearKwargs(ProcessingKwargs): 1053 """Key word arguments for `ScaleLinearDescr`""" 1054 1055 gain: float = 1.0 1056 """multiplicative factor""" 1057 1058 offset: float = 0.0 1059 """additive term""" 1060 1061 @model_validator(mode="after") 1062 def _validate(self) -> Self: 1063 if self.gain == 1.0 and self.offset == 0.0: 1064 raise ValueError( 1065 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1066 + " != 0.0." 1067 ) 1068 1069 return self
Keyword arguments for ScaleLinearDescr
1072class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1073 """Key word arguments for `ScaleLinearDescr`""" 1074 1075 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1076 """The axis of gain and offset values.""" 1077 1078 gain: Union[float, NotEmpty[List[float]]] = 1.0 1079 """multiplicative factor""" 1080 1081 offset: Union[float, NotEmpty[List[float]]] = 0.0 1082 """additive term""" 1083 1084 @model_validator(mode="after") 1085 def _validate(self) -> Self: 1086 1087 if isinstance(self.gain, list): 1088 if isinstance(self.offset, list): 1089 if len(self.gain) != len(self.offset): 1090 raise ValueError( 1091 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1092 ) 1093 else: 1094 self.offset = [float(self.offset)] * len(self.gain) 1095 elif isinstance(self.offset, list): 1096 self.gain = [float(self.gain)] * len(self.offset) 1097 else: 1098 raise ValueError( 1099 "Do not specify an `axis` for scalar gain and offset values." 1100 ) 1101 1102 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1103 raise ValueError( 1104 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1105 + " != 0.0." 1106 ) 1107 1108 return self
Keyword arguments for ScaleLinearDescr
The axis of gain and offset values.
1111class ScaleLinearDescr(ProcessingDescrBase): 1112 """Fixed linear scaling. 1113 1114 Examples: 1115 1. Scale with scalar gain and offset 1116 - in YAML 1117 ```yaml 1118 preprocessing: 1119 - id: scale_linear 1120 kwargs: 1121 gain: 2.0 1122 offset: 3.0 1123 ``` 1124 - in Python: 1125 >>> preprocessing = [ 1126 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1127 ... ] 1128 1129 2. Independent scaling along an axis 1130 - in YAML 1131 ```yaml 1132 preprocessing: 1133 - id: scale_linear 1134 kwargs: 1135 axis: 'channel' 1136 gain: [1.0, 2.0, 3.0] 1137 ``` 1138 - in Python: 1139 >>> preprocessing = [ 1140 ... ScaleLinearDescr( 1141 ... kwargs=ScaleLinearAlongAxisKwargs( 1142 ... axis=AxisId("channel"), 1143 ... gain=[1.0, 2.0, 3.0], 1144 ... ) 1145 ... ) 1146 ... ] 1147 1148 """ 1149 1150 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1151 if TYPE_CHECKING: 1152 id: Literal["scale_linear"] = "scale_linear" 1153 else: 1154 id: Literal["scale_linear"] 1155 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
1. Scale with scalar gain and offset
   - in YAML
       preprocessing:
         - id: scale_linear
           kwargs:
             gain: 2.0
             offset: 3.0
   - in Python:
       >>> preprocessing = [
       ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
       ... ]
2. Independent scaling along an axis
   - in YAML
       preprocessing:
         - id: scale_linear
           kwargs:
             axis: 'channel'
             gain: [1.0, 2.0, 3.0]
   - in Python:
       >>> preprocessing = [
       ...     ScaleLinearDescr(
       ...         kwargs=ScaleLinearAlongAxisKwargs(
       ...             axis=AxisId("channel"),
       ...             gain=[1.0, 2.0, 3.0],
       ...         )
       ...     )
       ... ]
1158class SigmoidDescr(ProcessingDescrBase): 1159 """The logistic sigmoid function, a.k.a. expit function. 1160 1161 Examples: 1162 - in YAML 1163 ```yaml 1164 postprocessing: 1165 - id: sigmoid 1166 ``` 1167 - in Python: 1168 >>> postprocessing = [SigmoidDescr()] 1169 """ 1170 1171 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1172 if TYPE_CHECKING: 1173 id: Literal["sigmoid"] = "sigmoid" 1174 else: 1175 id: Literal["sigmoid"] 1176 1177 @property 1178 def kwargs(self) -> ProcessingKwargs: 1179 """empty kwargs""" 1180 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
postprocessing:
- id: sigmoid
- in Python:
>>> postprocessing = [SigmoidDescr()]
1177 @property 1178 def kwargs(self) -> ProcessingKwargs: 1179 """empty kwargs""" 1180 return ProcessingKwargs()
empty kwargs
1183class SoftmaxKwargs(ProcessingKwargs): 1184 """key word arguments for `SoftmaxDescr`""" 1185 1186 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 1187 """The axis to apply the softmax function along. 1188 Note: 1189 Defaults to 'channel' axis 1190 (which may not exist, in which case 1191 a different axis id has to be specified). 1192 """
Keyword arguments for SoftmaxDescr
The axis to apply the softmax function along.
Note:
Defaults to 'channel' axis (which may not exist, in which case a different axis id has to be specified).
1195class SoftmaxDescr(ProcessingDescrBase): 1196 """The softmax function. 1197 1198 Examples: 1199 - in YAML 1200 ```yaml 1201 postprocessing: 1202 - id: softmax 1203 kwargs: 1204 axis: channel 1205 ``` 1206 - in Python: 1207 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 1208 """ 1209 1210 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 1211 if TYPE_CHECKING: 1212 id: Literal["softmax"] = "softmax" 1213 else: 1214 id: Literal["softmax"] 1215 1216 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
The softmax function.
Examples:
- in YAML
postprocessing:
- id: softmax
kwargs:
axis: channel
- in Python:
>>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1219class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1220 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1221 1222 mean: float 1223 """The mean value to normalize with.""" 1224 1225 std: Annotated[float, Ge(1e-6)] 1226 """The standard deviation value to normalize with."""
Keyword arguments for FixedZeroMeanUnitVarianceDescr
1229class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1230 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1231 1232 mean: NotEmpty[List[float]] 1233 """The mean value(s) to normalize with.""" 1234 1235 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1236 """The standard deviation value(s) to normalize with. 1237 Size must match `mean` values.""" 1238 1239 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1240 """The axis of the mean/std values to normalize each entry along that dimension 1241 separately.""" 1242 1243 @model_validator(mode="after") 1244 def _mean_and_std_match(self) -> Self: 1245 if len(self.mean) != len(self.std): 1246 raise ValueError( 1247 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1248 + " must match." 1249 ) 1250 1251 return self
Keyword arguments for FixedZeroMeanUnitVarianceDescr
The standard deviation value(s) to normalize with. Size must match mean values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1254class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1255 """Subtract a given mean and divide by the standard deviation. 1256 1257 Normalize with fixed, precomputed values for 1258 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1259 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1260 axes. 1261 1262 Examples: 1263 1. scalar value for whole tensor 1264 - in YAML 1265 ```yaml 1266 preprocessing: 1267 - id: fixed_zero_mean_unit_variance 1268 kwargs: 1269 mean: 103.5 1270 std: 13.7 1271 ``` 1272 - in Python 1273 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1274 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1275 ... )] 1276 1277 2. independently along an axis 1278 - in YAML 1279 ```yaml 1280 preprocessing: 1281 - id: fixed_zero_mean_unit_variance 1282 kwargs: 1283 axis: channel 1284 mean: [101.5, 102.5, 103.5] 1285 std: [11.7, 12.7, 13.7] 1286 ``` 1287 - in Python 1288 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1289 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1290 ... axis=AxisId("channel"), 1291 ... mean=[101.5, 102.5, 103.5], 1292 ... std=[11.7, 12.7, 13.7], 1293 ... ) 1294 ... )] 1295 """ 1296 1297 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1298 "fixed_zero_mean_unit_variance" 1299 ) 1300 if TYPE_CHECKING: 1301 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1302 else: 1303 id: Literal["fixed_zero_mean_unit_variance"] 1304 1305 kwargs: Union[ 1306 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1307 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for FixedZeroMeanUnitVarianceKwargs.mean and FixedZeroMeanUnitVarianceKwargs.std.
Use FixedZeroMeanUnitVarianceAlongAxisKwargs for independent scaling along given axes.
Examples:
1. scalar value for whole tensor
   - in YAML
       preprocessing:
         - id: fixed_zero_mean_unit_variance
           kwargs:
             mean: 103.5
             std: 13.7
   - in Python
       >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
       ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
       ... )]
2. independently along an axis
   - in YAML
       preprocessing:
         - id: fixed_zero_mean_unit_variance
           kwargs:
             axis: channel
             mean: [101.5, 102.5, 103.5]
             std: [11.7, 12.7, 13.7]
   - in Python
       >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
       ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
       ...         axis=AxisId("channel"),
       ...         mean=[101.5, 102.5, 103.5],
       ...         std=[11.7, 12.7, 13.7],
       ...     )
       ... )]
1310class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1311 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1312 1313 axes: Annotated[ 1314 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1315 ] = None 1316 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1317 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1318 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1319 To normalize each sample independently leave out the 'batch' axis. 1320 Default: Scale all axes jointly.""" 1321 1322 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1323 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
Keyword arguments for ZeroMeanUnitVarianceDescr
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x'),
resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y').
To normalize each sample independently leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability: out = (tensor - mean) / (std + eps).
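The reduction semantics of axes can be sketched in plain numpy (hypothetical shapes; 'batch', 'y' and 'x' correspond to array axes 0, 2 and 3 of a ('batch', 'channel', 'y', 'x') tensor):

```python
import numpy as np

tensor = np.random.rand(2, 3, 64, 64).astype("float32")  # (batch, channel, y, x)
eps = 1e-6

# axes=('batch', 'x', 'y'): reduce over batch, y and x -> one mean/std per channel
mean = tensor.mean(axis=(0, 2, 3), keepdims=True)
std = tensor.std(axis=(0, 2, 3), keepdims=True)
out = (tensor - mean) / (std + eps)
```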
1326class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1327 """Subtract mean and divide by variance. 1328 1329 Examples: 1330 Subtract tensor mean and variance 1331 - in YAML 1332 ```yaml 1333 preprocessing: 1334 - id: zero_mean_unit_variance 1335 ``` 1336 - in Python 1337 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1338 """ 1339 1340 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1341 "zero_mean_unit_variance" 1342 ) 1343 if TYPE_CHECKING: 1344 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1345 else: 1346 id: Literal["zero_mean_unit_variance"] 1347 1348 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1349 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 1350 )
Subtract mean and divide by variance.
Examples:
Subtract tensor mean and variance
- in YAML
    preprocessing:
      - id: zero_mean_unit_variance
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1353class ScaleRangeKwargs(ProcessingKwargs): 1354 """key word arguments for `ScaleRangeDescr` 1355 1356 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1357 this processing step normalizes data to the [0, 1] intervall. 1358 For other percentiles the normalized values will partially be outside the [0, 1] 1359 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1360 normalized values to a range. 1361 """ 1362 1363 axes: Annotated[ 1364 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1365 ] = None 1366 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1367 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1368 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1369 To normalize samples independently, leave out the "batch" axis. 1370 Default: Scale all axes jointly.""" 1371 1372 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1373 """The lower percentile used to determine the value to align with zero.""" 1374 1375 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1376 """The upper percentile used to determine the value to align with one. 1377 Has to be bigger than `min_percentile`. 1378 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1379 accepting percentiles specified in the range 0.0 to 1.0.""" 1380 1381 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1382 """Epsilon for numeric stability. 1383 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1384 with `v_lower,v_upper` values at the respective percentiles.""" 1385 1386 reference_tensor: Optional[TensorId] = None 1387 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1388 For any tensor in `inputs` only input tensor references are allowed.""" 1389 1390 @field_validator("max_percentile", mode="after") 1391 @classmethod 1392 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1393 if (min_p := info.data["min_percentile"]) >= value: 1394 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1395 1396 return value
Keyword arguments for ScaleRangeDescr
For min_percentile=0.0 (the default) and max_percentile=100 (the default)
this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1]
interval. Use ScaleRange followed by ClipDescr if you want to limit the
normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x'),
resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y').
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than min_percentile.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
out = (tensor - v_lower) / (v_upper - v_lower + eps);
with v_lower, v_upper values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in inputs only input tensor references are allowed.
1390 @field_validator("max_percentile", mode="after") 1391 @classmethod 1392 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1393 if (min_p := info.data["min_percentile"]) >= value: 1394 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1395 1396 return value
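A plain-numpy sketch of the resulting computation (hypothetical shapes; axes=('y', 'x'), min_percentile=5.0, max_percentile=99.8):

```python
import numpy as np

tensor = np.random.rand(3, 64, 64).astype("float32")  # (channel, y, x)
eps = 1e-6

# percentiles computed per channel, reducing over the 'y' and 'x' axes
v_lower = np.percentile(tensor, 5.0, axis=(1, 2), keepdims=True)
v_upper = np.percentile(tensor, 99.8, axis=(1, 2), keepdims=True)
out = (tensor - v_lower) / (v_upper - v_lower + eps)
```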
1399class ScaleRangeDescr(ProcessingDescrBase): 1400 """Scale with percentiles. 1401 1402 Examples: 1403 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1404 - in YAML 1405 ```yaml 1406 preprocessing: 1407 - id: scale_range 1408 kwargs: 1409 axes: ['y', 'x'] 1410 max_percentile: 99.8 1411 min_percentile: 5.0 1412 ``` 1413 - in Python 1414 >>> preprocessing = [ 1415 ... ScaleRangeDescr( 1416 ... kwargs=ScaleRangeKwargs( 1417 ... axes= (AxisId('y'), AxisId('x')), 1418 ... max_percentile= 99.8, 1419 ... min_percentile= 5.0, 1420 ... ) 1421 ... ), 1422 ... ClipDescr( 1423 ... kwargs=ClipKwargs( 1424 ... min=0.0, 1425 ... max=1.0, 1426 ... ) 1427 ... ), 1428 ... ] 1429 1430 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1431 - in YAML 1432 ```yaml 1433 preprocessing: 1434 - id: scale_range 1435 kwargs: 1436 axes: ['y', 'x'] 1437 max_percentile: 99.8 1438 min_percentile: 5.0 1439 - id: scale_range 1440 - id: clip 1441 kwargs: 1442 min: 0.0 1443 max: 1.0 1444 ``` 1445 - in Python 1446 >>> preprocessing = [ScaleRangeDescr( 1447 ... kwargs=ScaleRangeKwargs( 1448 ... axes= (AxisId('y'), AxisId('x')), 1449 ... max_percentile= 99.8, 1450 ... min_percentile= 5.0, 1451 ... ) 1452 ... )] 1453 1454 """ 1455 1456 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1457 if TYPE_CHECKING: 1458 id: Literal["scale_range"] = "scale_range" 1459 else: 1460 id: Literal["scale_range"] 1461 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
Scale with percentiles.
Examples:
1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
- in YAML
preprocessing:
- id: scale_range
kwargs:
axes: ['y', 'x']
max_percentile: 99.8
min_percentile: 5.0
- in Python
>>> preprocessing = [
... ScaleRangeDescr(
... kwargs=ScaleRangeKwargs(
... axes= (AxisId('y'), AxisId('x')),
... max_percentile= 99.8,
... min_percentile= 5.0,
... )
... ),
... ClipDescr(
... kwargs=ClipKwargs(
... min=0.0,
... max=1.0,
... )
... ),
... ]
2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
   - in YAML
       preprocessing:
         - id: scale_range
           kwargs:
             axes: ['y', 'x']
             max_percentile: 99.8
             min_percentile: 5.0
         - id: scale_range
         - id: clip
           kwargs:
             min: 0.0
             max: 1.0
   - in Python
       >>> preprocessing = [ScaleRangeDescr(
       ...     kwargs=ScaleRangeKwargs(
       ...         axes=(AxisId('y'), AxisId('x')),
       ...         max_percentile=99.8,
       ...         min_percentile=5.0,
       ...     )
       ... )]
1464class ScaleMeanVarianceKwargs(ProcessingKwargs): 1465 """key word arguments for `ScaleMeanVarianceKwargs`""" 1466 1467 reference_tensor: TensorId 1468 """Name of tensor to match.""" 1469 1470 axes: Annotated[ 1471 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1472 ] = None 1473 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1474 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1475 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1476 To normalize samples independently, leave out the 'batch' axis. 1477 Default: Scale all axes jointly.""" 1478 1479 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1480 """Epsilon for numeric stability: 1481 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
Keyword arguments for ScaleMeanVarianceDescr
Name of tensor to match.
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x'),
resulting in a tensor of equal shape normalized per channel, specify axes=('batch', 'x', 'y').
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability:
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
1484class ScaleMeanVarianceDescr(ProcessingDescrBase): 1485 """Scale a tensor's data distribution to match another tensor's mean/std. 1486 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1487 """ 1488 1489 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1490 if TYPE_CHECKING: 1491 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1492 else: 1493 id: Literal["scale_mean_variance"] 1494 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.
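A plain-numpy sketch of this formula (hypothetical tensors; statistics computed over all axes):

```python
import numpy as np

tensor = np.random.rand(64, 64).astype("float32")
reference = (np.random.rand(64, 64).astype("float32") * 10.0) + 5.0
eps = 1e-6

# shift/scale `tensor` so that its mean/std match those of `reference`
out = (tensor - tensor.mean()) / (tensor.std() + eps) * (
    reference.std() + eps
) + reference.mean()
```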
1530class TensorDescrBase(Node, Generic[IO_AxisT]): 1531 id: TensorId 1532 """Tensor id. No duplicates are allowed.""" 1533 1534 description: Annotated[str, MaxLen(128)] = "" 1535 """free text description""" 1536 1537 axes: NotEmpty[Sequence[IO_AxisT]] 1538 """tensor axes""" 1539 1540 @property 1541 def shape(self): 1542 return tuple(a.size for a in self.axes) 1543 1544 @field_validator("axes", mode="after", check_fields=False) 1545 @classmethod 1546 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1547 batch_axes = [a for a in axes if a.type == "batch"] 1548 if len(batch_axes) > 1: 1549 raise ValueError( 1550 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1551 ) 1552 1553 seen_ids: Set[AxisId] = set() 1554 duplicate_axes_ids: Set[AxisId] = set() 1555 for a in axes: 1556 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1557 1558 if duplicate_axes_ids: 1559 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1560 1561 return axes 1562 1563 test_tensor: FAIR[Optional[FileDescr_]] = None 1564 """An example tensor to use for testing. 1565 Using the model with the test input tensors is expected to yield the test output tensors. 1566 Each test tensor has be a an ndarray in the 1567 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1568 The file extension must be '.npy'.""" 1569 1570 sample_tensor: FAIR[Optional[FileDescr_]] = None 1571 """A sample tensor to illustrate a possible input/output for the model, 1572 The sample image primarily serves to inform a human user about an example use case 1573 and is typically stored as .hdf5, .png or .tiff. 1574 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1575 (numpy's `.npy` format is not supported). 1576 The image dimensionality has to match the number of axes specified in this tensor description. 1577 """ 1578 1579 @model_validator(mode="after") 1580 def _validate_sample_tensor(self) -> Self: 1581 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1582 return self 1583 1584 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1585 tensor: NDArray[Any] = imread( 1586 reader.read(), 1587 extension=PurePosixPath(reader.original_file_name).suffix, 1588 ) 1589 n_dims = len(tensor.squeeze().shape) 1590 n_dims_min = n_dims_max = len(self.axes) 1591 1592 for a in self.axes: 1593 if isinstance(a, BatchAxis): 1594 n_dims_min -= 1 1595 elif isinstance(a.size, int): 1596 if a.size == 1: 1597 n_dims_min -= 1 1598 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1599 if a.size.min == 1: 1600 n_dims_min -= 1 1601 elif isinstance(a.size, SizeReference): 1602 if a.size.offset < 2: 1603 # size reference may result in singleton axis 1604 n_dims_min -= 1 1605 else: 1606 assert_never(a.size) 1607 1608 n_dims_min = max(0, n_dims_min) 1609 if n_dims < n_dims_min or n_dims > n_dims_max: 1610 raise ValueError( 1611 f"Expected sample tensor to have {n_dims_min} to" 1612 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1613 ) 1614 1615 return self 1616 1617 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1618 IntervalOrRatioDataDescr() 1619 ) 1620 """Description of the tensor's data values, optionally per channel. 
1621 If specified per channel, the data `type` needs to match across channels.""" 1622 1623 @property 1624 def dtype( 1625 self, 1626 ) -> Literal[ 1627 "float32", 1628 "float64", 1629 "uint8", 1630 "int8", 1631 "uint16", 1632 "int16", 1633 "uint32", 1634 "int32", 1635 "uint64", 1636 "int64", 1637 "bool", 1638 ]: 1639 """dtype as specified under `data.type` or `data[i].type`""" 1640 if isinstance(self.data, collections.abc.Sequence): 1641 return self.data[0].type 1642 else: 1643 return self.data.type 1644 1645 @field_validator("data", mode="after") 1646 @classmethod 1647 def _check_data_type_across_channels( 1648 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1649 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1650 if not isinstance(value, list): 1651 return value 1652 1653 dtypes = {t.type for t in value} 1654 if len(dtypes) > 1: 1655 raise ValueError( 1656 "Tensor data descriptions per channel need to agree in their data" 1657 + f" `type`, but found {dtypes}." 1658 ) 1659 1660 return value 1661 1662 @model_validator(mode="after") 1663 def _check_data_matches_channelaxis(self) -> Self: 1664 if not isinstance(self.data, (list, tuple)): 1665 return self 1666 1667 for a in self.axes: 1668 if isinstance(a, ChannelAxis): 1669 size = a.size 1670 assert isinstance(size, int) 1671 break 1672 else: 1673 return self 1674 1675 if len(self.data) != size: 1676 raise ValueError( 1677 f"Got tensor data descriptions for {len(self.data)} channels, but" 1678 + f" '{a.id}' axis has size {size}." 1679 ) 1680 1681 return self 1682 1683 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1684 if len(array.shape) != len(self.axes): 1685 raise ValueError( 1686 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1687 + f" incompatible with {len(self.axes)} axes." 1688 ) 1689 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case
and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library
(numpy's .npy format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data type needs to match across channels.
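For example, per-channel data descriptions that agree in their type might look like this (a minimal sketch):

```python
from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr

# one description per channel; all entries share the type 'float32'
data = [
    IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0)),
    IntervalOrRatioDataDescr(type="float32", range=(-1.0, 1.0)),
]
```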
1623 @property 1624 def dtype( 1625 self, 1626 ) -> Literal[ 1627 "float32", 1628 "float64", 1629 "uint8", 1630 "int8", 1631 "uint16", 1632 "int16", 1633 "uint32", 1634 "int32", 1635 "uint64", 1636 "int64", 1637 "bool", 1638 ]: 1639 """dtype as specified under `data.type` or `data[i].type`""" 1640 if isinstance(self.data, collections.abc.Sequence): 1641 return self.data[0].type 1642 else: 1643 return self.data.type
dtype as specified under data.type or data[i].type
1683 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1684 if len(array.shape) != len(self.axes): 1685 raise ValueError( 1686 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1687 + f" incompatible with {len(self.axes)} axes." 1688 ) 1689 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1692class InputTensorDescr(TensorDescrBase[InputAxis]): 1693 id: TensorId = TensorId("input") 1694 """Input tensor id. 1695 No duplicates are allowed across all inputs and outputs.""" 1696 1697 optional: bool = False 1698 """indicates that this tensor may be `None`""" 1699 1700 preprocessing: List[PreprocessingDescr] = Field( 1701 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1702 ) 1703 1704 """Description of how this input should be preprocessed. 1705 1706 notes: 1707 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1708 to ensure an input tensor's data type matches the input tensor's data description. 1709 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1710 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1711 changing the data type. 1712 """ 1713 1714 @model_validator(mode="after") 1715 def _validate_preprocessing_kwargs(self) -> Self: 1716 axes_ids = [a.id for a in self.axes] 1717 for p in self.preprocessing: 1718 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1719 if kwargs_axes is None: 1720 continue 1721 1722 if not isinstance(kwargs_axes, collections.abc.Sequence): 1723 raise ValueError( 1724 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1725 ) 1726 1727 if any(a not in axes_ids for a in kwargs_axes): 1728 raise ValueError( 1729 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1730 ) 1731 1732 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1733 dtype = self.data.type 1734 else: 1735 dtype = self.data[0].type 1736 1737 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1738 if not self.preprocessing or not isinstance( 1739 self.preprocessing[0], EnsureDtypeDescr 1740 ): 1741 self.preprocessing.insert( 1742 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1743 ) 1744 1745 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1746 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1747 self.preprocessing.append( 1748 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1749 ) 1750 1751 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
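A small sketch of the first rule, assuming an input described without explicit preprocessing; after validation the chain is expected to begin with an 'ensure_dtype' step casting to the declared data type (the analogous rule for outputs appears further below):

from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    EnsureDtypeDescr,
    InputTensorDescr,
    IntervalOrRatioDataDescr,
    SpaceInputAxis,
    TensorId,
)

# hypothetical input without any explicit preprocessing entries
ipt = InputTensorDescr(
    id=TensorId("image"),
    axes=[BatchAxis(), SpaceInputAxis(id=AxisId("x"), size=64)],
    data=IntervalOrRatioDataDescr(type="float32"),
)

# an 'ensure_dtype' step matching data.type is inserted automatically
assert isinstance(ipt.preprocessing[0], EnsureDtypeDescr)
assert ipt.preprocessing[0].kwargs.dtype == "float32"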
1754def convert_axes( 1755 axes: str, 1756 *, 1757 shape: Union[ 1758 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1759 ], 1760 tensor_type: Literal["input", "output"], 1761 halo: Optional[Sequence[int]], 1762 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1763): 1764 ret: List[AnyAxis] = [] 1765 for i, a in enumerate(axes): 1766 axis_type = _AXIS_TYPE_MAP.get(a, a) 1767 if axis_type == "batch": 1768 ret.append(BatchAxis()) 1769 continue 1770 1771 scale = 1.0 1772 if isinstance(shape, _ParameterizedInputShape_v0_4): 1773 if shape.step[i] == 0: 1774 size = shape.min[i] 1775 else: 1776 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1777 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1778 ref_t = str(shape.reference_tensor) 1779 if ref_t.count(".") == 1: 1780 t_id, orig_a_id = ref_t.split(".") 1781 else: 1782 t_id = ref_t 1783 orig_a_id = a 1784 1785 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1786 if not (orig_scale := shape.scale[i]): 1787 # old way to insert a new axis dimension 1788 size = int(2 * shape.offset[i]) 1789 else: 1790 scale = 1 / orig_scale 1791 if axis_type in ("channel", "index"): 1792 # these axes no longer have a scale 1793 offset_from_scale = orig_scale * size_refs.get( 1794 _TensorName_v0_4(t_id), {} 1795 ).get(orig_a_id, 0) 1796 else: 1797 offset_from_scale = 0 1798 size = SizeReference( 1799 tensor_id=TensorId(t_id), 1800 axis_id=AxisId(a_id), 1801 offset=int(offset_from_scale + 2 * shape.offset[i]), 1802 ) 1803 else: 1804 size = shape[i] 1805 1806 if axis_type == "time": 1807 if tensor_type == "input": 1808 ret.append(TimeInputAxis(size=size, scale=scale)) 1809 else: 1810 assert not isinstance(size, ParameterizedSize) 1811 if halo is None: 1812 ret.append(TimeOutputAxis(size=size, scale=scale)) 1813 else: 1814 assert not isinstance(size, int) 1815 ret.append( 1816 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1817 ) 1818 1819 elif axis_type == "index": 1820 if tensor_type == "input": 1821 ret.append(IndexInputAxis(size=size)) 1822 else: 1823 if isinstance(size, ParameterizedSize): 1824 size = DataDependentSize(min=size.min) 1825 1826 ret.append(IndexOutputAxis(size=size)) 1827 elif axis_type == "channel": 1828 assert not isinstance(size, ParameterizedSize) 1829 if isinstance(size, SizeReference): 1830 warnings.warn( 1831 "Conversion of channel size from an implicit output shape may be" 1832 + " wrong" 1833 ) 1834 ret.append( 1835 ChannelAxis( 1836 channel_names=[ 1837 Identifier(f"channel{i}") for i in range(size.offset) 1838 ] 1839 ) 1840 ) 1841 else: 1842 ret.append( 1843 ChannelAxis( 1844 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1845 ) 1846 ) 1847 elif axis_type == "space": 1848 if tensor_type == "input": 1849 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1850 else: 1851 assert not isinstance(size, ParameterizedSize) 1852 if halo is None or halo[i] == 0: 1853 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1854 elif isinstance(size, int): 1855 raise NotImplementedError( 1856 f"output axis with halo and fixed size (here {size}) not allowed" 1857 ) 1858 else: 1859 ret.append( 1860 SpaceOutputAxisWithHalo( 1861 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1862 ) 1863 ) 1864 1865 return ret
2025class OutputTensorDescr(TensorDescrBase[OutputAxis]): 2026 id: TensorId = TensorId("output") 2027 """Output tensor id. 2028 No duplicates are allowed across all inputs and outputs.""" 2029 2030 postprocessing: List[PostprocessingDescr] = Field( 2031 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 2032 ) 2033 """Description of how this output should be postprocessed. 2034 2035 note: `postprocessing` always ends with an 'ensure_dtype' operation. 2036 If not given this is added to cast to this tensor's `data.type`. 2037 """ 2038 2039 @model_validator(mode="after") 2040 def _validate_postprocessing_kwargs(self) -> Self: 2041 axes_ids = [a.id for a in self.axes] 2042 for p in self.postprocessing: 2043 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 2044 if kwargs_axes is None: 2045 continue 2046 2047 if not isinstance(kwargs_axes, collections.abc.Sequence): 2048 raise ValueError( 2049 f"expected `axes` sequence, but got {type(kwargs_axes)}" 2050 ) 2051 2052 if any(a not in axes_ids for a in kwargs_axes): 2053 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 2054 2055 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 2056 dtype = self.data.type 2057 else: 2058 dtype = self.data[0].type 2059 2060 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 2061 if not self.postprocessing or not isinstance( 2062 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 2063 ): 2064 self.postprocessing.append( 2065 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2066 ) 2067 return self
Description of how this output should be postprocessed.
note: postprocessing always ends with an 'ensure_dtype' operation. If not given, this is added to cast to this tensor's data.type.
2117def validate_tensors( 2118 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 2119 tensor_origin: Literal[ 2120 "test_tensor" 2121 ], # for more precise error messages, e.g. 'test_tensor' 2122): 2123 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 2124 2125 def e_msg(d: TensorDescr): 2126 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2127 2128 for descr, array in tensors.values(): 2129 if array is None: 2130 axis_sizes = {a.id: None for a in descr.axes} 2131 else: 2132 try: 2133 axis_sizes = descr.get_axis_sizes_for_array(array) 2134 except ValueError as e: 2135 raise ValueError(f"{e_msg(descr)} {e}") 2136 2137 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 2138 2139 for descr, array in tensors.values(): 2140 if array is None: 2141 continue 2142 2143 if descr.dtype in ("float32", "float64"): 2144 invalid_test_tensor_dtype = array.dtype.name not in ( 2145 "float32", 2146 "float64", 2147 "uint8", 2148 "int8", 2149 "uint16", 2150 "int16", 2151 "uint32", 2152 "int32", 2153 "uint64", 2154 "int64", 2155 ) 2156 else: 2157 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2158 2159 if invalid_test_tensor_dtype: 2160 raise ValueError( 2161 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2162 + f" match described dtype '{descr.dtype}'" 2163 ) 2164 2165 if array.min() > -1e-4 and array.max() < 1e-4: 2166 raise ValueError( 2167 "Output values are too small for reliable testing." 2168 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2169 ) 2170 2171 for a in descr.axes: 2172 actual_size = all_tensor_axes[descr.id][a.id][1] 2173 if actual_size is None: 2174 continue 2175 2176 if a.size is None: 2177 continue 2178 2179 if isinstance(a.size, int): 2180 if actual_size != a.size: 2181 raise ValueError( 2182 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2183 + f"has incompatible size {actual_size}, expected {a.size}" 2184 ) 2185 elif isinstance(a.size, ParameterizedSize): 2186 _ = a.size.validate_size(actual_size) 2187 elif isinstance(a.size, DataDependentSize): 2188 _ = a.size.validate_size(actual_size) 2189 elif isinstance(a.size, SizeReference): 2190 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2191 if ref_tensor_axes is None: 2192 raise ValueError( 2193 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2194 + f" reference '{a.size.tensor_id}'" 2195 ) 2196 2197 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2198 if ref_axis is None or ref_size is None: 2199 raise ValueError( 2200 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2201 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2202 ) 2203 2204 if a.unit != ref_axis.unit: 2205 raise ValueError( 2206 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2207 + " axis and reference axis to have the same `unit`, but" 2208 + f" {a.unit}!={ref_axis.unit}" 2209 ) 2210 2211 if actual_size != ( 2212 expected_size := ( 2213 ref_size * ref_axis.scale / a.scale + a.size.offset 2214 ) 2215 ): 2216 raise ValueError( 2217 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2218 + f" {actual_size} invalid for referenced size {ref_size};" 2219 + f" expected {expected_size}" 2220 ) 2221 else: 2222 assert_never(a.size)
2242class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2243 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2244 """Architecture source file""" 2245 2246 @model_serializer(mode="wrap", when_used="unless-none") 2247 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2248 return package_file_descr_serializer(self, nxt, info)
Architecture source file
2251class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2252 import_from: str 2253 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2313class WeightsEntryDescrBase(FileDescr): 2314 type: ClassVar[WeightsFormat] 2315 weights_format_name: ClassVar[str] # human readable 2316 2317 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2318 """Source of the weights file.""" 2319 2320 authors: Optional[List[Author]] = None 2321 """Authors 2322 Either the person(s) that have trained this model resulting in the original weights file. 2323 (If this is the initial weights entry, i.e. it does not have a `parent`) 2324 Or the person(s) who have converted the weights to this weights format. 2325 (If this is a child weight, i.e. it has a `parent` field) 2326 """ 2327 2328 parent: Annotated[ 2329 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2330 ] = None 2331 """The source weights these weights were converted from. 2332 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2333 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2334 All weight entries except one (the initial set of weights resulting from training the model), 2335 need to have this field.""" 2336 2337 comment: str = "" 2338 """A comment about this weights entry, for example how these weights were created.""" 2339 2340 @model_validator(mode="after") 2341 def _validate(self) -> Self: 2342 if self.type == self.parent: 2343 raise ValueError("Weights entry can't be it's own parent.") 2344 2345 return self 2346 2347 @model_serializer(mode="wrap", when_used="unless-none") 2348 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2349 return package_file_descr_serializer(self, nxt, info)
Source of the weights file.
The source weights these weights were converted from.
For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights.
All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
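As an illustrative sketch, a torchscript entry converted from the original pytorch_state_dict weights would reference its source format via parent; the file name and version below are made up, and with IO checks enabled the validator will try to resolve the source file:

from bioimageio.spec.model.v0_5 import TorchscriptWeightsDescr, Version

# hypothetical converted weights entry
ts_weights = TorchscriptWeightsDescr(
    source="weights_torchscript.pt",  # hypothetical file name
    pytorch_version=Version("2.1.0"),
    parent="pytorch_state_dict",      # converted from the original weights
)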
2352class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2353 type = "keras_hdf5" 2354 weights_format_name: ClassVar[str] = "Keras HDF5" 2355 tensorflow_version: Version 2356 """TensorFlow version used to create these weights."""
TensorFlow version used to create these weights.
2359class OnnxWeightsDescr(WeightsEntryDescrBase): 2360 type = "onnx" 2361 weights_format_name: ClassVar[str] = "ONNX" 2362 opset_version: Annotated[int, Ge(7)] 2363 """ONNX opset version"""
2366class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2367 type = "pytorch_state_dict" 2368 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2369 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2370 pytorch_version: Version 2371 """Version of the PyTorch library used. 2372 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2373 """ 2374 dependencies: Optional[FileDescr_dependencies] = None 2375 """Custom depencies beyond pytorch described in a Conda environment file. 2376 Allows to specify custom dependencies, see conda docs: 2377 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2378 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2379 2380 The conda environment file should include pytorch and any version pinning has to be compatible with 2381 **pytorch_version**. 2382 """
Version of the PyTorch library used.
If architecture.dependencies is specified it has to include pytorch and any version pinning has to be compatible.
Custom dependencies beyond pytorch described in a Conda environment file. Allows specifying custom dependencies; see the conda docs:
- Exporting an environment file across platforms: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms
- Creating an environment file manually: https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually
The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.
2385class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2386 type = "tensorflow_js" 2387 weights_format_name: ClassVar[str] = "Tensorflow.js" 2388 tensorflow_version: Version 2389 """Version of the TensorFlow library used.""" 2390 2391 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2392 """The multi-file weights. 2393 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
The multi-file weights. All required files/folders should be a zip archive.
2396class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2397 type = "tensorflow_saved_model_bundle" 2398 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2399 tensorflow_version: Version 2400 """Version of the TensorFlow library used.""" 2401 2402 dependencies: Optional[FileDescr_dependencies] = None 2403 """Custom dependencies beyond tensorflow. 2404 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2405 2406 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2407 """The multi-file weights. 2408 All required files/folders should be a zip archive."""
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.
The multi-file weights. All required files/folders should be a zip archive.
2411class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2412 type = "torchscript" 2413 weights_format_name: ClassVar[str] = "TorchScript" 2414 pytorch_version: Version 2415 """Version of the PyTorch library used."""
Version of the PyTorch library used.
2418class WeightsDescr(Node): 2419 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2420 onnx: Optional[OnnxWeightsDescr] = None 2421 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2422 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2423 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2424 None 2425 ) 2426 torchscript: Optional[TorchscriptWeightsDescr] = None 2427 2428 @model_validator(mode="after") 2429 def check_entries(self) -> Self: 2430 entries = {wtype for wtype, entry in self if entry is not None} 2431 2432 if not entries: 2433 raise ValueError("Missing weights entry") 2434 2435 entries_wo_parent = { 2436 wtype 2437 for wtype, entry in self 2438 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2439 } 2440 if len(entries_wo_parent) != 1: 2441 issue_warning( 2442 "Exactly one weights entry may not specify the `parent` field (got" 2443 + " {value}). That entry is considered the original set of model weights." 2444 + " Other weight formats are created through conversion of the orignal or" 2445 + " already converted weights. They have to reference the weights format" 2446 + " they were converted from as their `parent`.", 2447 value=len(entries_wo_parent), 2448 field="weights", 2449 ) 2450 2451 for wtype, entry in self: 2452 if entry is None: 2453 continue 2454 2455 assert hasattr(entry, "type") 2456 assert hasattr(entry, "parent") 2457 assert wtype == entry.type 2458 if ( 2459 entry.parent is not None and entry.parent not in entries 2460 ): # self reference checked for `parent` field 2461 raise ValueError( 2462 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2463 + f" formats: {entries}" 2464 ) 2465 2466 return self 2467 2468 def __getitem__( 2469 self, 2470 key: Literal[ 2471 "keras_hdf5", 2472 "onnx", 2473 "pytorch_state_dict", 2474 "tensorflow_js", 2475 "tensorflow_saved_model_bundle", 2476 "torchscript", 2477 ], 2478 ): 2479 if key == "keras_hdf5": 2480 ret = self.keras_hdf5 2481 elif key == "onnx": 2482 ret = self.onnx 2483 elif key == "pytorch_state_dict": 2484 ret = self.pytorch_state_dict 2485 elif key == "tensorflow_js": 2486 ret = self.tensorflow_js 2487 elif key == "tensorflow_saved_model_bundle": 2488 ret = self.tensorflow_saved_model_bundle 2489 elif key == "torchscript": 2490 ret = self.torchscript 2491 else: 2492 raise KeyError(key) 2493 2494 if ret is None: 2495 raise KeyError(key) 2496 2497 return ret 2498 2499 @property 2500 def available_formats(self): 2501 return { 2502 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2503 **({} if self.onnx is None else {"onnx": self.onnx}), 2504 **( 2505 {} 2506 if self.pytorch_state_dict is None 2507 else {"pytorch_state_dict": self.pytorch_state_dict} 2508 ), 2509 **( 2510 {} 2511 if self.tensorflow_js is None 2512 else {"tensorflow_js": self.tensorflow_js} 2513 ), 2514 **( 2515 {} 2516 if self.tensorflow_saved_model_bundle is None 2517 else { 2518 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2519 } 2520 ), 2521 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2522 } 2523 2524 @property 2525 def missing_formats(self): 2526 return { 2527 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2528 }
2428 @model_validator(mode="after") 2429 def check_entries(self) -> Self: 2430 entries = {wtype for wtype, entry in self if entry is not None} 2431 2432 if not entries: 2433 raise ValueError("Missing weights entry") 2434 2435 entries_wo_parent = { 2436 wtype 2437 for wtype, entry in self 2438 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2439 } 2440 if len(entries_wo_parent) != 1: 2441 issue_warning( 2442 "Exactly one weights entry may not specify the `parent` field (got" 2443 + " {value}). That entry is considered the original set of model weights." 2444 + " Other weight formats are created through conversion of the orignal or" 2445 + " already converted weights. They have to reference the weights format" 2446 + " they were converted from as their `parent`.", 2447 value=len(entries_wo_parent), 2448 field="weights", 2449 ) 2450 2451 for wtype, entry in self: 2452 if entry is None: 2453 continue 2454 2455 assert hasattr(entry, "type") 2456 assert hasattr(entry, "parent") 2457 assert wtype == entry.type 2458 if ( 2459 entry.parent is not None and entry.parent not in entries 2460 ): # self reference checked for `parent` field 2461 raise ValueError( 2462 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2463 + f" formats: {entries}" 2464 ) 2465 2466 return self
2499 @property 2500 def available_formats(self): 2501 return { 2502 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2503 **({} if self.onnx is None else {"onnx": self.onnx}), 2504 **( 2505 {} 2506 if self.pytorch_state_dict is None 2507 else {"pytorch_state_dict": self.pytorch_state_dict} 2508 ), 2509 **( 2510 {} 2511 if self.tensorflow_js is None 2512 else {"tensorflow_js": self.tensorflow_js} 2513 ), 2514 **( 2515 {} 2516 if self.tensorflow_saved_model_bundle is None 2517 else { 2518 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2519 } 2520 ), 2521 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2522 }
2535class LinkedModel(LinkedResourceBase): 2536 """Reference to a bioimage.io model.""" 2537 2538 id: ModelId 2539 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2561class ReproducibilityTolerance(Node, extra="allow"): 2562 """Describes what small numerical differences -- if any -- may be tolerated 2563 in the generated output when executing in different environments. 2564 2565 A tensor element *output* is considered mismatched to the **test_tensor** if 2566 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2567 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2568 2569 Motivation: 2570 For testing we can request the respective deep learning frameworks to be as 2571 reproducible as possible by setting seeds and chosing deterministic algorithms, 2572 but differences in operating systems, available hardware and installed drivers 2573 may still lead to numerical differences. 2574 """ 2575 2576 relative_tolerance: RelativeTolerance = 1e-3 2577 """Maximum relative tolerance of reproduced test tensor.""" 2578 2579 absolute_tolerance: AbsoluteTolerance = 1e-4 2580 """Maximum absolute tolerance of reproduced test tensor.""" 2581 2582 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2583 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2584 2585 output_ids: Sequence[TensorId] = () 2586 """Limits the output tensor IDs these reproducibility details apply to.""" 2587 2588 weights_formats: Sequence[WeightsFormat] = () 2589 """Limits the weights formats these details apply to."""
Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.
A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)
Motivation:
For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
Maximum relative tolerance of reproduced test tensor.
Maximum absolute tolerance of reproduced test tensor.
Maximum number of mismatched elements/pixels per million to tolerate.
Limits the output tensor IDs these reproducibility details apply to.
Limits the weights formats these details apply to.
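A sketch of the mismatch criterion with the default tolerances; the arrays below are synthetic stand-ins for a stored test output and a reproduced output:

import numpy as np

# hypothetical stored test output and a reproduced output with tiny numerical noise
test_tensor = np.random.rand(1, 1, 64, 64).astype("float32")
output = test_tensor + np.float32(1e-6) * np.random.randn(*test_tensor.shape).astype("float32")

# element-wise mismatch per the formula above, using the defaults
# relative_tolerance=1e-3 and absolute_tolerance=1e-4
mismatch = np.abs(output - test_tensor) > 1e-4 + 1e-3 * np.abs(test_tensor)
mismatched_per_million = mismatch.mean() * 1e6
assert mismatched_per_million <= 100  # default mismatched_elements_per_million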
2592class BioimageioConfig(Node, extra="allow"): 2593 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2594 """Tolerances to allow when reproducing the model's test outputs 2595 from the model's test inputs. 2596 Only the first entry matching tensor id and weights format is considered. 2597 """
Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.
2600class Config(Node, extra="allow"): 2601 bioimageio: BioimageioConfig = Field( 2602 default_factory=BioimageioConfig.model_construct 2603 )
2606class ModelDescr(GenericModelDescrBase): 2607 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2608 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2609 """ 2610 2611 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5" 2612 if TYPE_CHECKING: 2613 format_version: Literal["0.5.5"] = "0.5.5" 2614 else: 2615 format_version: Literal["0.5.5"] 2616 """Version of the bioimage.io model description specification used. 2617 When creating a new model always use the latest micro/patch version described here. 2618 The `format_version` is important for any consumer software to understand how to parse the fields. 2619 """ 2620 2621 implemented_type: ClassVar[Literal["model"]] = "model" 2622 if TYPE_CHECKING: 2623 type: Literal["model"] = "model" 2624 else: 2625 type: Literal["model"] 2626 """Specialized resource type 'model'""" 2627 2628 id: Optional[ModelId] = None 2629 """bioimage.io-wide unique resource identifier 2630 assigned by bioimage.io; version **un**specific.""" 2631 2632 authors: FAIR[List[Author]] = Field( 2633 default_factory=cast(Callable[[], List[Author]], list) 2634 ) 2635 """The authors are the creators of the model RDF and the primary points of contact.""" 2636 2637 documentation: FAIR[Optional[FileSource_documentation]] = None 2638 """URL or relative path to a markdown file with additional documentation. 2639 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2640 The documentation should include a '#[#] Validation' (sub)section 2641 with details on how to quantitatively validate the model on unseen data.""" 2642 2643 @field_validator("documentation", mode="after") 2644 @classmethod 2645 def _validate_documentation( 2646 cls, value: Optional[FileSource_documentation] 2647 ) -> Optional[FileSource_documentation]: 2648 if not get_validation_context().perform_io_checks or value is None: 2649 return value 2650 2651 doc_reader = get_reader(value) 2652 doc_content = doc_reader.read().decode(encoding="utf-8") 2653 if not re.search("#.*[vV]alidation", doc_content): 2654 issue_warning( 2655 "No '# Validation' (sub)section found in {value}.", 2656 value=value, 2657 field="documentation", 2658 ) 2659 2660 return value 2661 2662 inputs: NotEmpty[Sequence[InputTensorDescr]] 2663 """Describes the input tensors expected by this model.""" 2664 2665 @field_validator("inputs", mode="after") 2666 @classmethod 2667 def _validate_input_axes( 2668 cls, inputs: Sequence[InputTensorDescr] 2669 ) -> Sequence[InputTensorDescr]: 2670 input_size_refs = cls._get_axes_with_independent_size(inputs) 2671 2672 for i, ipt in enumerate(inputs): 2673 valid_independent_refs: Dict[ 2674 Tuple[TensorId, AxisId], 2675 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2676 ] = { 2677 **{ 2678 (ipt.id, a.id): (ipt, a, a.size) 2679 for a in ipt.axes 2680 if not isinstance(a, BatchAxis) 2681 and isinstance(a.size, (int, ParameterizedSize)) 2682 }, 2683 **input_size_refs, 2684 } 2685 for a, ax in enumerate(ipt.axes): 2686 cls._validate_axis( 2687 "inputs", 2688 i=i, 2689 tensor_id=ipt.id, 2690 a=a, 2691 axis=ax, 2692 valid_independent_refs=valid_independent_refs, 2693 ) 2694 return inputs 2695 2696 @staticmethod 2697 def _validate_axis( 2698 field_name: str, 2699 i: int, 2700 tensor_id: TensorId, 2701 a: int, 2702 axis: AnyAxis, 2703 valid_independent_refs: Dict[ 2704 Tuple[TensorId, AxisId], 2705 Tuple[TensorDescr, AnyAxis, Union[int, 
ParameterizedSize]], 2706 ], 2707 ): 2708 if isinstance(axis, BatchAxis) or isinstance( 2709 axis.size, (int, ParameterizedSize, DataDependentSize) 2710 ): 2711 return 2712 elif not isinstance(axis.size, SizeReference): 2713 assert_never(axis.size) 2714 2715 # validate axis.size SizeReference 2716 ref = (axis.size.tensor_id, axis.size.axis_id) 2717 if ref not in valid_independent_refs: 2718 raise ValueError( 2719 "Invalid tensor axis reference at" 2720 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2721 ) 2722 if ref == (tensor_id, axis.id): 2723 raise ValueError( 2724 "Self-referencing not allowed for" 2725 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2726 ) 2727 if axis.type == "channel": 2728 if valid_independent_refs[ref][1].type != "channel": 2729 raise ValueError( 2730 "A channel axis' size may only reference another fixed size" 2731 + " channel axis." 2732 ) 2733 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2734 ref_size = valid_independent_refs[ref][2] 2735 assert isinstance(ref_size, int), ( 2736 "channel axis ref (another channel axis) has to specify fixed" 2737 + " size" 2738 ) 2739 generated_channel_names = [ 2740 Identifier(axis.channel_names.format(i=i)) 2741 for i in range(1, ref_size + 1) 2742 ] 2743 axis.channel_names = generated_channel_names 2744 2745 if (ax_unit := getattr(axis, "unit", None)) != ( 2746 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2747 ): 2748 raise ValueError( 2749 "The units of an axis and its reference axis need to match, but" 2750 + f" '{ax_unit}' != '{ref_unit}'." 2751 ) 2752 ref_axis = valid_independent_refs[ref][1] 2753 if isinstance(ref_axis, BatchAxis): 2754 raise ValueError( 2755 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2756 + " (a batch axis is not allowed as reference)." 2757 ) 2758 2759 if isinstance(axis, WithHalo): 2760 min_size = axis.size.get_size(axis, ref_axis, n=0) 2761 if (min_size - 2 * axis.halo) < 1: 2762 raise ValueError( 2763 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2764 + f" {axis.halo}." 2765 ) 2766 2767 input_halo = axis.halo * axis.scale / ref_axis.scale 2768 if input_halo != int(input_halo) or input_halo % 2 == 1: 2769 raise ValueError( 2770 f"input_halo {input_halo} (output_halo {axis.halo} *" 2771 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2772 + f" {tensor_id}.{axis.id}." 
2773 ) 2774 2775 @model_validator(mode="after") 2776 def _validate_test_tensors(self) -> Self: 2777 if not get_validation_context().perform_io_checks: 2778 return self 2779 2780 test_output_arrays = [ 2781 None if descr.test_tensor is None else load_array(descr.test_tensor) 2782 for descr in self.outputs 2783 ] 2784 test_input_arrays = [ 2785 None if descr.test_tensor is None else load_array(descr.test_tensor) 2786 for descr in self.inputs 2787 ] 2788 2789 tensors = { 2790 descr.id: (descr, array) 2791 for descr, array in zip( 2792 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2793 ) 2794 } 2795 validate_tensors(tensors, tensor_origin="test_tensor") 2796 2797 output_arrays = { 2798 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2799 } 2800 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2801 if not rep_tol.absolute_tolerance: 2802 continue 2803 2804 if rep_tol.output_ids: 2805 out_arrays = { 2806 oid: a 2807 for oid, a in output_arrays.items() 2808 if oid in rep_tol.output_ids 2809 } 2810 else: 2811 out_arrays = output_arrays 2812 2813 for out_id, array in out_arrays.items(): 2814 if array is None: 2815 continue 2816 2817 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2818 raise ValueError( 2819 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2820 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2821 + f" (1% of the maximum value of the test tensor '{out_id}')" 2822 ) 2823 2824 return self 2825 2826 @model_validator(mode="after") 2827 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2828 ipt_refs = {t.id for t in self.inputs} 2829 out_refs = {t.id for t in self.outputs} 2830 for ipt in self.inputs: 2831 for p in ipt.preprocessing: 2832 ref = p.kwargs.get("reference_tensor") 2833 if ref is None: 2834 continue 2835 if ref not in ipt_refs: 2836 raise ValueError( 2837 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2838 + f" references are: {ipt_refs}." 2839 ) 2840 2841 for out in self.outputs: 2842 for p in out.postprocessing: 2843 ref = p.kwargs.get("reference_tensor") 2844 if ref is None: 2845 continue 2846 2847 if ref not in ipt_refs and ref not in out_refs: 2848 raise ValueError( 2849 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2850 + f" are: {ipt_refs | out_refs}." 2851 ) 2852 2853 return self 2854 2855 # TODO: use validate funcs in validate_test_tensors 2856 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2857 2858 name: Annotated[ 2859 str, 2860 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 2861 MinLen(5), 2862 MaxLen(128), 2863 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2864 ] 2865 """A human-readable name of this model. 2866 It should be no longer than 64 characters 2867 and may only contain letter, number, underscore, minus, parentheses and spaces. 2868 We recommend to chose a name that refers to the model's task and image modality. 
2869 """ 2870 2871 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2872 """Describes the output tensors.""" 2873 2874 @field_validator("outputs", mode="after") 2875 @classmethod 2876 def _validate_tensor_ids( 2877 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2878 ) -> Sequence[OutputTensorDescr]: 2879 tensor_ids = [ 2880 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2881 ] 2882 duplicate_tensor_ids: List[str] = [] 2883 seen: Set[str] = set() 2884 for t in tensor_ids: 2885 if t in seen: 2886 duplicate_tensor_ids.append(t) 2887 2888 seen.add(t) 2889 2890 if duplicate_tensor_ids: 2891 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2892 2893 return outputs 2894 2895 @staticmethod 2896 def _get_axes_with_parameterized_size( 2897 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2898 ): 2899 return { 2900 f"{t.id}.{a.id}": (t, a, a.size) 2901 for t in io 2902 for a in t.axes 2903 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2904 } 2905 2906 @staticmethod 2907 def _get_axes_with_independent_size( 2908 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2909 ): 2910 return { 2911 (t.id, a.id): (t, a, a.size) 2912 for t in io 2913 for a in t.axes 2914 if not isinstance(a, BatchAxis) 2915 and isinstance(a.size, (int, ParameterizedSize)) 2916 } 2917 2918 @field_validator("outputs", mode="after") 2919 @classmethod 2920 def _validate_output_axes( 2921 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2922 ) -> List[OutputTensorDescr]: 2923 input_size_refs = cls._get_axes_with_independent_size( 2924 info.data.get("inputs", []) 2925 ) 2926 output_size_refs = cls._get_axes_with_independent_size(outputs) 2927 2928 for i, out in enumerate(outputs): 2929 valid_independent_refs: Dict[ 2930 Tuple[TensorId, AxisId], 2931 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2932 ] = { 2933 **{ 2934 (out.id, a.id): (out, a, a.size) 2935 for a in out.axes 2936 if not isinstance(a, BatchAxis) 2937 and isinstance(a.size, (int, ParameterizedSize)) 2938 }, 2939 **input_size_refs, 2940 **output_size_refs, 2941 } 2942 for a, ax in enumerate(out.axes): 2943 cls._validate_axis( 2944 "outputs", 2945 i, 2946 out.id, 2947 a, 2948 ax, 2949 valid_independent_refs=valid_independent_refs, 2950 ) 2951 2952 return outputs 2953 2954 packaged_by: List[Author] = Field( 2955 default_factory=cast(Callable[[], List[Author]], list) 2956 ) 2957 """The persons that have packaged and uploaded this model. 2958 Only required if those persons differ from the `authors`.""" 2959 2960 parent: Optional[LinkedModel] = None 2961 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2962 2963 @model_validator(mode="after") 2964 def _validate_parent_is_not_self(self) -> Self: 2965 if self.parent is not None and self.parent.id == self.id: 2966 raise ValueError("A model description may not reference itself as parent.") 2967 2968 return self 2969 2970 run_mode: Annotated[ 2971 Optional[RunMode], 2972 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2973 ] = None 2974 """Custom run mode for this model: for more complex prediction procedures like test time 2975 data augmentation that currently cannot be expressed in the specification. 
2976 No standard run modes are defined yet.""" 2977 2978 timestamp: Datetime = Field(default_factory=Datetime.now) 2979 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2980 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2981 (In Python a datetime object is valid, too).""" 2982 2983 training_data: Annotated[ 2984 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2985 Field(union_mode="left_to_right"), 2986 ] = None 2987 """The dataset used to train this model""" 2988 2989 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2990 """The weights for this model. 2991 Weights can be given for different formats, but should otherwise be equivalent. 2992 The available weight formats determine which consumers can use this model.""" 2993 2994 config: Config = Field(default_factory=Config.model_construct) 2995 2996 @model_validator(mode="after") 2997 def _add_default_cover(self) -> Self: 2998 if not get_validation_context().perform_io_checks or self.covers: 2999 return self 3000 3001 try: 3002 generated_covers = generate_covers( 3003 [ 3004 (t, load_array(t.test_tensor)) 3005 for t in self.inputs 3006 if t.test_tensor is not None 3007 ], 3008 [ 3009 (t, load_array(t.test_tensor)) 3010 for t in self.outputs 3011 if t.test_tensor is not None 3012 ], 3013 ) 3014 except Exception as e: 3015 issue_warning( 3016 "Failed to generate cover image(s): {e}", 3017 value=self.covers, 3018 msg_context=dict(e=e), 3019 field="covers", 3020 ) 3021 else: 3022 self.covers.extend(generated_covers) 3023 3024 return self 3025 3026 def get_input_test_arrays(self) -> List[NDArray[Any]]: 3027 return self._get_test_arrays(self.inputs) 3028 3029 def get_output_test_arrays(self) -> List[NDArray[Any]]: 3030 return self._get_test_arrays(self.outputs) 3031 3032 @staticmethod 3033 def _get_test_arrays( 3034 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 3035 ): 3036 ts: List[FileDescr] = [] 3037 for d in io_descr: 3038 if d.test_tensor is None: 3039 raise ValueError( 3040 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 3041 ) 3042 ts.append(d.test_tensor) 3043 3044 data = [load_array(t) for t in ts] 3045 assert all(isinstance(d, np.ndarray) for d in data) 3046 return data 3047 3048 @staticmethod 3049 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3050 batch_size = 1 3051 tensor_with_batchsize: Optional[TensorId] = None 3052 for tid in tensor_sizes: 3053 for aid, s in tensor_sizes[tid].items(): 3054 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3055 continue 3056 3057 if batch_size != 1: 3058 assert tensor_with_batchsize is not None 3059 raise ValueError( 3060 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3061 ) 3062 3063 batch_size = s 3064 tensor_with_batchsize = tid 3065 3066 return batch_size 3067 3068 def get_output_tensor_sizes( 3069 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3070 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3071 """Returns the tensor output sizes for given **input_sizes**. 3072 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
3073 Otherwise it might be larger than the actual (valid) output""" 3074 batch_size = self.get_batch_size(input_sizes) 3075 ns = self.get_ns(input_sizes) 3076 3077 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3078 return tensor_sizes.outputs 3079 3080 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3081 """get parameter `n` for each parameterized axis 3082 such that the valid input size is >= the given input size""" 3083 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3084 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3085 for tid in input_sizes: 3086 for aid, s in input_sizes[tid].items(): 3087 size_descr = axes[tid][aid].size 3088 if isinstance(size_descr, ParameterizedSize): 3089 ret[(tid, aid)] = size_descr.get_n(s) 3090 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3091 pass 3092 else: 3093 assert_never(size_descr) 3094 3095 return ret 3096 3097 def get_tensor_sizes( 3098 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3099 ) -> _TensorSizes: 3100 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3101 return _TensorSizes( 3102 { 3103 t: { 3104 aa: axis_sizes.inputs[(tt, aa)] 3105 for tt, aa in axis_sizes.inputs 3106 if tt == t 3107 } 3108 for t in {tt for tt, _ in axis_sizes.inputs} 3109 }, 3110 { 3111 t: { 3112 aa: axis_sizes.outputs[(tt, aa)] 3113 for tt, aa in axis_sizes.outputs 3114 if tt == t 3115 } 3116 for t in {tt for tt, _ in axis_sizes.outputs} 3117 }, 3118 ) 3119 3120 def get_axis_sizes( 3121 self, 3122 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3123 batch_size: Optional[int] = None, 3124 *, 3125 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3126 ) -> _AxisSizes: 3127 """Determine input and output block shape for scale factors **ns** 3128 of parameterized input sizes. 3129 3130 Args: 3131 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3132 that is parameterized as `size = min + n * step`. 3133 batch_size: The desired size of the batch dimension. 3134 If given **batch_size** overwrites any batch size present in 3135 **max_input_shape**. Default 1. 3136 max_input_shape: Limits the derived block shapes. 3137 Each axis for which the input size, parameterized by `n`, is larger 3138 than **max_input_shape** is set to the minimal value `n_min` for which 3139 this is still true. 3140 Use this for small input samples or large values of **ns**. 3141 Or simply whenever you know the full input shape. 3142 3143 Returns: 3144 Resolved axis sizes for model inputs and outputs. 
3145 """ 3146 max_input_shape = max_input_shape or {} 3147 if batch_size is None: 3148 for (_t_id, a_id), s in max_input_shape.items(): 3149 if a_id == BATCH_AXIS_ID: 3150 batch_size = s 3151 break 3152 else: 3153 batch_size = 1 3154 3155 all_axes = { 3156 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3157 } 3158 3159 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3160 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3161 3162 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3163 if isinstance(a, BatchAxis): 3164 if (t_descr.id, a.id) in ns: 3165 logger.warning( 3166 "Ignoring unexpected size increment factor (n) for batch axis" 3167 + " of tensor '{}'.", 3168 t_descr.id, 3169 ) 3170 return batch_size 3171 elif isinstance(a.size, int): 3172 if (t_descr.id, a.id) in ns: 3173 logger.warning( 3174 "Ignoring unexpected size increment factor (n) for fixed size" 3175 + " axis '{}' of tensor '{}'.", 3176 a.id, 3177 t_descr.id, 3178 ) 3179 return a.size 3180 elif isinstance(a.size, ParameterizedSize): 3181 if (t_descr.id, a.id) not in ns: 3182 raise ValueError( 3183 "Size increment factor (n) missing for parametrized axis" 3184 + f" '{a.id}' of tensor '{t_descr.id}'." 3185 ) 3186 n = ns[(t_descr.id, a.id)] 3187 s_max = max_input_shape.get((t_descr.id, a.id)) 3188 if s_max is not None: 3189 n = min(n, a.size.get_n(s_max)) 3190 3191 return a.size.get_size(n) 3192 3193 elif isinstance(a.size, SizeReference): 3194 if (t_descr.id, a.id) in ns: 3195 logger.warning( 3196 "Ignoring unexpected size increment factor (n) for axis '{}'" 3197 + " of tensor '{}' with size reference.", 3198 a.id, 3199 t_descr.id, 3200 ) 3201 assert not isinstance(a, BatchAxis) 3202 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3203 assert not isinstance(ref_axis, BatchAxis) 3204 ref_key = (a.size.tensor_id, a.size.axis_id) 3205 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3206 assert ref_size is not None, ref_key 3207 assert not isinstance(ref_size, _DataDepSize), ref_key 3208 return a.size.get_size( 3209 axis=a, 3210 ref_axis=ref_axis, 3211 ref_size=ref_size, 3212 ) 3213 elif isinstance(a.size, DataDependentSize): 3214 if (t_descr.id, a.id) in ns: 3215 logger.warning( 3216 "Ignoring unexpected increment factor (n) for data dependent" 3217 + " size axis '{}' of tensor '{}'.", 3218 a.id, 3219 t_descr.id, 3220 ) 3221 return _DataDepSize(a.size.min, a.size.max) 3222 else: 3223 assert_never(a.size) 3224 3225 # first resolve all , but the `SizeReference` input sizes 3226 for t_descr in self.inputs: 3227 for a in t_descr.axes: 3228 if not isinstance(a.size, SizeReference): 3229 s = get_axis_size(a) 3230 assert not isinstance(s, _DataDepSize) 3231 inputs[t_descr.id, a.id] = s 3232 3233 # resolve all other input axis sizes 3234 for t_descr in self.inputs: 3235 for a in t_descr.axes: 3236 if isinstance(a.size, SizeReference): 3237 s = get_axis_size(a) 3238 assert not isinstance(s, _DataDepSize) 3239 inputs[t_descr.id, a.id] = s 3240 3241 # resolve all output axis sizes 3242 for t_descr in self.outputs: 3243 for a in t_descr.axes: 3244 assert not isinstance(a.size, ParameterizedSize) 3245 s = get_axis_size(a) 3246 outputs[t_descr.id, a.id] = s 3247 3248 return _AxisSizes(inputs=inputs, outputs=outputs) 3249 3250 @model_validator(mode="before") 3251 @classmethod 3252 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3253 cls.convert_from_old_format_wo_validation(data) 3254 return data 3255 3256 @classmethod 3257 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3258 """Convert metadata following an older format version to this classes' format 3259 without validating the result. 3260 """ 3261 if ( 3262 data.get("type") == "model" 3263 and isinstance(fv := data.get("format_version"), str) 3264 and fv.count(".") == 2 3265 ): 3266 fv_parts = fv.split(".") 3267 if any(not p.isdigit() for p in fv_parts): 3268 return 3269 3270 fv_tuple = tuple(map(int, fv_parts)) 3271 3272 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3273 if fv_tuple[:2] in ((0, 3), (0, 4)): 3274 m04 = _ModelDescr_v0_4.load(data) 3275 if isinstance(m04, InvalidDescr): 3276 try: 3277 updated = _model_conv.convert_as_dict( 3278 m04 # pyright: ignore[reportArgumentType] 3279 ) 3280 except Exception as e: 3281 logger.error( 3282 "Failed to convert from invalid model 0.4 description." 3283 + f"\nerror: {e}" 3284 + "\nProceeding with model 0.5 validation without conversion." 3285 ) 3286 updated = None 3287 else: 3288 updated = _model_conv.convert_as_dict(m04) 3289 3290 if updated is not None: 3291 data.clear() 3292 data.update(updated) 3293 3294 elif fv_tuple[:2] == (0, 5): 3295 # bump patch version 3296 data["format_version"] = cls.implemented_format_version
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
URL or relative path to a markdown file with additional documentation.
The recommended documentation file name is README.md. An .md suffix is mandatory.
The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
The persons that have packaged and uploaded this model.
Only required if those persons differ from the authors.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
3048 @staticmethod 3049 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 3050 batch_size = 1 3051 tensor_with_batchsize: Optional[TensorId] = None 3052 for tid in tensor_sizes: 3053 for aid, s in tensor_sizes[tid].items(): 3054 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 3055 continue 3056 3057 if batch_size != 1: 3058 assert tensor_with_batchsize is not None 3059 raise ValueError( 3060 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 3061 ) 3062 3063 batch_size = s 3064 tensor_with_batchsize = tid 3065 3066 return batch_size
3068 def get_output_tensor_sizes( 3069 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 3070 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 3071 """Returns the tensor output sizes for given **input_sizes**. 3072 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 3073 Otherwise it might be larger than the actual (valid) output""" 3074 batch_size = self.get_batch_size(input_sizes) 3075 ns = self.get_ns(input_sizes) 3076 3077 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 3078 return tensor_sizes.outputs
Returns the tensor output sizes for given input_sizes. The tensor output size is exact only if input_sizes is a valid input shape; otherwise it might be larger than the actual (valid) output.
3080 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 3081 """get parameter `n` for each parameterized axis 3082 such that the valid input size is >= the given input size""" 3083 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 3084 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 3085 for tid in input_sizes: 3086 for aid, s in input_sizes[tid].items(): 3087 size_descr = axes[tid][aid].size 3088 if isinstance(size_descr, ParameterizedSize): 3089 ret[(tid, aid)] = size_descr.get_n(s) 3090 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 3091 pass 3092 else: 3093 assert_never(size_descr) 3094 3095 return ret
Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
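A small arithmetic sketch of this mapping, assuming a parameterized axis with min=16 and step=8 (so valid sizes are 16, 24, 32, ...) and assuming get_n rounds up as described above:

from bioimageio.spec.model.v0_5 import ParameterizedSize

size = ParameterizedSize(min=16, step=8)

# smallest n such that min + n * step >= 70
n = size.get_n(70)
assert n == 7
assert size.get_size(n) == 72  # the next valid size >= 70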
3097 def get_tensor_sizes( 3098 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 3099 ) -> _TensorSizes: 3100 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 3101 return _TensorSizes( 3102 { 3103 t: { 3104 aa: axis_sizes.inputs[(tt, aa)] 3105 for tt, aa in axis_sizes.inputs 3106 if tt == t 3107 } 3108 for t in {tt for tt, _ in axis_sizes.inputs} 3109 }, 3110 { 3111 t: { 3112 aa: axis_sizes.outputs[(tt, aa)] 3113 for tt, aa in axis_sizes.outputs 3114 if tt == t 3115 } 3116 for t in {tt for tt, _ in axis_sizes.outputs} 3117 }, 3118 )
```python
def get_axis_sizes(
    self,
    ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
    batch_size: Optional[int] = None,
    *,
    max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes:
    """Determine input and output block shape for scale factors **ns**
    of parameterized input sizes.

    Args:
        ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
            that is parameterized as `size = min + n * step`.
        batch_size: The desired size of the batch dimension.
            If given, **batch_size** overwrites any batch size present in
            **max_input_shape**. Default 1.
        max_input_shape: Limits the derived block shapes.
            Each axis for which the input size, parameterized by `n`, is larger
            than **max_input_shape** is set to the minimal value `n_min` for which
            this is still true.
            Use this for small input samples or large values of **ns**,
            or simply whenever you know the full input shape.

    Returns:
        Resolved axis sizes for model inputs and outputs.
    """
    max_input_shape = max_input_shape or {}
    if batch_size is None:
        for (_t_id, a_id), s in max_input_shape.items():
            if a_id == BATCH_AXIS_ID:
                batch_size = s
                break
        else:
            batch_size = 1

    all_axes = {
        t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
    }

    inputs: Dict[Tuple[TensorId, AxisId], int] = {}
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

    def get_axis_size(a: Union[InputAxis, OutputAxis]):
        if isinstance(a, BatchAxis):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected size increment factor (n) for batch axis"
                    + " of tensor '{}'.",
                    t_descr.id,
                )
            return batch_size
        elif isinstance(a.size, int):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected size increment factor (n) for fixed size"
                    + " axis '{}' of tensor '{}'.",
                    a.id,
                    t_descr.id,
                )
            return a.size
        elif isinstance(a.size, ParameterizedSize):
            if (t_descr.id, a.id) not in ns:
                raise ValueError(
                    "Size increment factor (n) missing for parametrized axis"
                    + f" '{a.id}' of tensor '{t_descr.id}'."
                )
            n = ns[(t_descr.id, a.id)]
            s_max = max_input_shape.get((t_descr.id, a.id))
            if s_max is not None:
                n = min(n, a.size.get_n(s_max))

            return a.size.get_size(n)

        elif isinstance(a.size, SizeReference):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected size increment factor (n) for axis '{}'"
                    + " of tensor '{}' with size reference.",
                    a.id,
                    t_descr.id,
                )
            assert not isinstance(a, BatchAxis)
            ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
            assert not isinstance(ref_axis, BatchAxis)
            ref_key = (a.size.tensor_id, a.size.axis_id)
            ref_size = inputs.get(ref_key, outputs.get(ref_key))
            assert ref_size is not None, ref_key
            assert not isinstance(ref_size, _DataDepSize), ref_key
            return a.size.get_size(
                axis=a,
                ref_axis=ref_axis,
                ref_size=ref_size,
            )
        elif isinstance(a.size, DataDependentSize):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected increment factor (n) for data dependent"
                    + " size axis '{}' of tensor '{}'.",
                    a.id,
                    t_descr.id,
                )
            return _DataDepSize(a.size.min, a.size.max)
        else:
            assert_never(a.size)

    # first resolve all but the `SizeReference` input sizes
    for t_descr in self.inputs:
        for a in t_descr.axes:
            if not isinstance(a.size, SizeReference):
                s = get_axis_size(a)
                assert not isinstance(s, _DataDepSize)
                inputs[t_descr.id, a.id] = s

    # resolve all other input axis sizes
    for t_descr in self.inputs:
        for a in t_descr.axes:
            if isinstance(a.size, SizeReference):
                s = get_axis_size(a)
                assert not isinstance(s, _DataDepSize)
                inputs[t_descr.id, a.id] = s

    # resolve all output axis sizes
    for t_descr in self.outputs:
        for a in t_descr.axes:
            assert not isinstance(a.size, ParameterizedSize)
            s = get_axis_size(a)
            outputs[t_descr.id, a.id] = s

    return _AxisSizes(inputs=inputs, outputs=outputs)
```
Determine input and output block shape for scale factors **ns** of parameterized input sizes.

Arguments:
- ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) that is parameterized as `size = min + n * step`.
- batch_size: The desired size of the batch dimension. If given, **batch_size** overwrites any batch size present in **max_input_shape**. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than **max_input_shape** is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of **ns**, or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
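Continuing the sketch above, the increment factors from `get_ns` can be resolved to concrete block shapes; `model` and `ns` are the same assumed names as before:

```python
# resolve concrete sizes for every input and output axis (batch size 1)
axis_sizes = model.get_axis_sizes(ns, batch_size=1)
print(axis_sizes.inputs)   # {(tensor_id, axis_id): int, ...}
print(axis_sizes.outputs)  # ints, or _DataDepSize(min, max) for data-dependent axes

# the same information grouped per tensor id
tensor_sizes = model.get_tensor_sizes(ns, batch_size=1)
```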
```python
@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
    """Convert metadata following an older format version to this class's format
    without validating the result.
    """
    if (
        data.get("type") == "model"
        and isinstance(fv := data.get("format_version"), str)
        and fv.count(".") == 2
    ):
        fv_parts = fv.split(".")
        if any(not p.isdigit() for p in fv_parts):
            return

        fv_tuple = tuple(map(int, fv_parts))

        assert cls.implemented_format_version_tuple[0:2] == (0, 5)
        if fv_tuple[:2] in ((0, 3), (0, 4)):
            m04 = _ModelDescr_v0_4.load(data)
            if isinstance(m04, InvalidDescr):
                try:
                    updated = _model_conv.convert_as_dict(
                        m04  # pyright: ignore[reportArgumentType]
                    )
                except Exception as e:
                    logger.error(
                        "Failed to convert from invalid model 0.4 description."
                        + f"\nerror: {e}"
                        + "\nProceeding with model 0.5 validation without conversion."
                    )
                    updated = None
            else:
                updated = _model_conv.convert_as_dict(m04)

            if updated is not None:
                data.clear()
                data.update(updated)

        elif fv_tuple[:2] == (0, 5):
            # bump patch version
            data["format_version"] = cls.implemented_format_version
```
Convert metadata following an older format version to this class's format without validating the result.
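A rough sketch of how this classmethod might be used on raw rdf.yaml content, assuming the enclosing class is the module's `ModelDescr`; the metadata fragment below is hypothetical and far from complete:

```python
from bioimageio.spec.model.v0_5 import ModelDescr

# hypothetical, heavily truncated model 0.4 metadata as read from an rdf.yaml
data = {"type": "model", "format_version": "0.4.10", "name": "example-model"}

# converts `data` in place to the 0.5 layout without validating it;
# if an invalid 0.4 description cannot be converted, `data` is left unchanged
# and regular 0.5 validation proceeds on the original content
ModelDescr.convert_from_old_format_wo_validation(data)
```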
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
```python
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
```python
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
        original_shape = data.shape
        data, axes = squeeze(data, axes)

        # take slice from any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape}"
            )

        if not has_c_axis:
            assert ndim == 2
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        axis_order.append(axis_order.pop(c))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # h, w = data.shape[:2]
        # if h / w in (1.0 or 2.0):
        #     pass
        # elif h / w < 2:
        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers
```
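A possible invocation, sketched under the assumption that `model` is a loaded `ModelDescr` and that the test tensors are available as numpy files (the file names below are placeholders):

```python
import numpy as np

from bioimageio.spec.model.v0_5 import generate_covers

# hypothetical: arrays whose shapes match the axes described by
# model.inputs[0] and model.outputs[0]
ipt_arr = np.load("test_input.npy")   # assumed file name
out_arr = np.load("test_output.npy")  # assumed file name

cover_paths = generate_covers(
    [(model.inputs[0], ipt_arr)],
    [(model.outputs[0], out_arr)],
)
# -> a single "cover.png" (diagonal split of input and output) if both 2D images
#    end up with the same shape, otherwise separate "input.png" and "output.png"
```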