bioimageio.spec.model.v0_5
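Usage sketch (illustrative only, using public names defined in this module): axis sizes are declared either as fixed integers, as a `ParameterizedSize`, or as a `SizeReference` to another axis.

>>> from bioimageio.spec.model.v0_5 import AxisId, ParameterizedSize, SpaceInputAxis
>>> x = SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=32))
>>> x.size.get_size(n=2)  # valid sizes are 64, 96, 128, ...
128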
1from __future__ import annotations 2 3import collections.abc 4import re 5import string 6import warnings 7from abc import ABC 8from copy import deepcopy 9from itertools import chain 10from math import ceil 11from pathlib import Path, PurePosixPath 12from tempfile import mkdtemp 13from typing import ( 14 TYPE_CHECKING, 15 Any, 16 Callable, 17 ClassVar, 18 Dict, 19 Generic, 20 List, 21 Literal, 22 Mapping, 23 NamedTuple, 24 Optional, 25 Sequence, 26 Set, 27 Tuple, 28 Type, 29 TypeVar, 30 Union, 31 cast, 32) 33 34import numpy as np 35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 37from loguru import logger 38from numpy.typing import NDArray 39from pydantic import ( 40 AfterValidator, 41 Discriminator, 42 Field, 43 RootModel, 44 SerializationInfo, 45 SerializerFunctionWrapHandler, 46 Tag, 47 ValidationInfo, 48 WrapSerializer, 49 field_validator, 50 model_serializer, 51 model_validator, 52) 53from typing_extensions import Annotated, Self, assert_never, get_args 54 55from .._internal.common_nodes import ( 56 InvalidDescr, 57 Node, 58 NodeWithExplicitlySetFields, 59) 60from .._internal.constants import DTYPE_LIMITS 61from .._internal.field_warning import issue_warning, warn 62from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 63from .._internal.io import FileDescr as FileDescr 64from .._internal.io import ( 65 FileSource, 66 WithSuffix, 67 YamlValue, 68 get_reader, 69 wo_special_file_name, 70) 71from .._internal.io_basics import Sha256 as Sha256 72from .._internal.io_packaging import ( 73 FileDescr_, 74 FileSource_, 75 package_file_descr_serializer, 76) 77from .._internal.io_utils import load_array 78from .._internal.node_converter import Converter 79from .._internal.types import ( 80 AbsoluteTolerance, 81 LowerCaseIdentifier, 82 LowerCaseIdentifierAnno, 83 MismatchedElementsPerMillion, 84 RelativeTolerance, 85) 86from .._internal.types import Datetime as Datetime 87from .._internal.types import Identifier as Identifier 88from .._internal.types import NotEmpty as NotEmpty 89from .._internal.types import SiUnit as SiUnit 90from .._internal.url import HttpUrl as HttpUrl 91from .._internal.validation_context import get_validation_context 92from .._internal.validator_annotations import RestrictCharacters 93from .._internal.version_type import Version as Version 94from .._internal.warning_levels import INFO 95from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 96from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 97from ..dataset.v0_3 import DatasetDescr as DatasetDescr 98from ..dataset.v0_3 import DatasetId as DatasetId 99from ..dataset.v0_3 import LinkedDataset as LinkedDataset 100from ..dataset.v0_3 import Uploader as Uploader 101from ..generic.v0_3 import ( 102 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 103) 104from ..generic.v0_3 import Author as Author 105from ..generic.v0_3 import BadgeDescr as BadgeDescr 106from ..generic.v0_3 import CiteEntry as CiteEntry 107from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 108from ..generic.v0_3 import Doi as Doi 109from ..generic.v0_3 import ( 110 FileSource_documentation, 111 GenericModelDescrBase, 112 LinkedResourceBase, 113 _author_conv, # pyright: ignore[reportPrivateUsage] 114 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 115) 116from ..generic.v0_3 import LicenseId as LicenseId 117from ..generic.v0_3 import LinkedResource as LinkedResource 118from 
..generic.v0_3 import Maintainer as Maintainer 119from ..generic.v0_3 import OrcidId as OrcidId 120from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 121from ..generic.v0_3 import ResourceId as ResourceId 122from .v0_4 import Author as _Author_v0_4 123from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 124from .v0_4 import CallableFromDepencency as CallableFromDepencency 125from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 126from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 127from .v0_4 import ClipDescr as _ClipDescr_v0_4 128from .v0_4 import ClipKwargs as ClipKwargs 129from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 130from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 131from .v0_4 import KnownRunMode as KnownRunMode 132from .v0_4 import ModelDescr as _ModelDescr_v0_4 133from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 134from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 135from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 136from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 137from .v0_4 import ProcessingKwargs as ProcessingKwargs 138from .v0_4 import RunMode as RunMode 139from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 140from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 141from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 142from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 143from .v0_4 import TensorName as _TensorName_v0_4 144from .v0_4 import WeightsFormat as WeightsFormat 145from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 146from .v0_4 import package_weights 147 148SpaceUnit = Literal[ 149 "attometer", 150 "angstrom", 151 "centimeter", 152 "decimeter", 153 "exameter", 154 "femtometer", 155 "foot", 156 "gigameter", 157 "hectometer", 158 "inch", 159 "kilometer", 160 "megameter", 161 "meter", 162 "micrometer", 163 "mile", 164 "millimeter", 165 "nanometer", 166 "parsec", 167 "petameter", 168 "picometer", 169 "terameter", 170 "yard", 171 "yoctometer", 172 "yottameter", 173 "zeptometer", 174 "zettameter", 175] 176"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 177 178TimeUnit = Literal[ 179 "attosecond", 180 "centisecond", 181 "day", 182 "decisecond", 183 "exasecond", 184 "femtosecond", 185 "gigasecond", 186 "hectosecond", 187 "hour", 188 "kilosecond", 189 "megasecond", 190 "microsecond", 191 "millisecond", 192 "minute", 193 "nanosecond", 194 "petasecond", 195 "picosecond", 196 "second", 197 "terasecond", 198 "yoctosecond", 199 "yottasecond", 200 "zeptosecond", 201 "zettasecond", 202] 203"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 204 205AxisType = Literal["batch", "channel", "index", "time", "space"] 206 207_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 208 "b": "batch", 209 "t": "time", 210 "i": "index", 211 "c": "channel", 212 "x": "space", 213 "y": "space", 214 "z": "space", 215} 216 217_AXIS_ID_MAP = { 218 "b": "batch", 219 "t": "time", 220 "i": "index", 221 "c": "channel", 222} 223 224 225class TensorId(LowerCaseIdentifier): 226 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 227 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 228 ] 229 230 231def _normalize_axis_id(a: str): 232 a = str(a) 233 normalized = _AXIS_ID_MAP.get(a, a) 234 if a != normalized: 235 logger.opt(depth=3).warning( 236 "Normalized axis id from 
'{}' to '{}'.", a, normalized 237 ) 238 return normalized 239 240 241class AxisId(LowerCaseIdentifier): 242 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 243 Annotated[ 244 LowerCaseIdentifierAnno, 245 MaxLen(16), 246 AfterValidator(_normalize_axis_id), 247 ] 248 ] 249 250 251def _is_batch(a: str) -> bool: 252 return str(a) == "batch" 253 254 255def _is_not_batch(a: str) -> bool: 256 return not _is_batch(a) 257 258 259NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 260 261PostprocessingId = Literal[ 262 "binarize", 263 "clip", 264 "ensure_dtype", 265 "fixed_zero_mean_unit_variance", 266 "scale_linear", 267 "scale_mean_variance", 268 "scale_range", 269 "sigmoid", 270 "zero_mean_unit_variance", 271] 272PreprocessingId = Literal[ 273 "binarize", 274 "clip", 275 "ensure_dtype", 276 "scale_linear", 277 "sigmoid", 278 "zero_mean_unit_variance", 279 "scale_range", 280] 281 282 283SAME_AS_TYPE = "<same as type>" 284 285 286ParameterizedSize_N = int 287""" 288Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 289""" 290 291 292class ParameterizedSize(Node): 293 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 294 295 - **min** and **step** are given by the model description. 296 - All blocksize paramters n = 0,1,2,... yield a valid `size`. 297 - A greater blocksize paramter n = 0,1,2,... results in a greater **size**. 298 This allows to adjust the axis size more generically. 299 """ 300 301 N: ClassVar[Type[int]] = ParameterizedSize_N 302 """Positive integer to parameterize this axis""" 303 304 min: Annotated[int, Gt(0)] 305 step: Annotated[int, Gt(0)] 306 307 def validate_size(self, size: int) -> int: 308 if size < self.min: 309 raise ValueError(f"size {size} < {self.min}") 310 if (size - self.min) % self.step != 0: 311 raise ValueError( 312 f"axis of size {size} is not parameterized by `min + n*step` =" 313 + f" `{self.min} + n*{self.step}`" 314 ) 315 316 return size 317 318 def get_size(self, n: ParameterizedSize_N) -> int: 319 return self.min + self.step * n 320 321 def get_n(self, s: int) -> ParameterizedSize_N: 322 """return smallest n parameterizing a size greater or equal than `s`""" 323 return ceil((s - self.min) / self.step) 324 325 326class DataDependentSize(Node): 327 min: Annotated[int, Gt(0)] = 1 328 max: Annotated[Optional[int], Gt(1)] = None 329 330 @model_validator(mode="after") 331 def _validate_max_gt_min(self): 332 if self.max is not None and self.min >= self.max: 333 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 334 335 return self 336 337 def validate_size(self, size: int) -> int: 338 if size < self.min: 339 raise ValueError(f"size {size} < {self.min}") 340 341 if self.max is not None and size > self.max: 342 raise ValueError(f"size {size} > {self.max}") 343 344 return size 345 346 347class SizeReference(Node): 348 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 349 350 `axis.size = reference.size * reference.scale / axis.scale + offset` 351 352 Note: 353 1. The axis and the referenced axis need to have the same unit (or no unit). 354 2. Batch axes may not be referenced. 355 3. Fractions are rounded down. 356 4. If the reference axis is `concatenable` the referencing axis is assumed to be 357 `concatenable` as well with the same block order. 358 359 Example: 360 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 
361 Let's assume that we want to express the image height h in relation to its width w 362 instead of only accepting input images of exactly 100*49 pixels 363 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 364 365 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 366 >>> h = SpaceInputAxis( 367 ... id=AxisId("h"), 368 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 369 ... unit="millimeter", 370 ... scale=4, 371 ... ) 372 >>> print(h.size.get_size(h, w)) 373 49 374 375 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 376 """ 377 378 tensor_id: TensorId 379 """tensor id of the reference axis""" 380 381 axis_id: AxisId 382 """axis id of the reference axis""" 383 384 offset: int = 0 385 386 def get_size( 387 self, 388 axis: Union[ 389 ChannelAxis, 390 IndexInputAxis, 391 IndexOutputAxis, 392 TimeInputAxis, 393 SpaceInputAxis, 394 TimeOutputAxis, 395 TimeOutputAxisWithHalo, 396 SpaceOutputAxis, 397 SpaceOutputAxisWithHalo, 398 ], 399 ref_axis: Union[ 400 ChannelAxis, 401 IndexInputAxis, 402 IndexOutputAxis, 403 TimeInputAxis, 404 SpaceInputAxis, 405 TimeOutputAxis, 406 TimeOutputAxisWithHalo, 407 SpaceOutputAxis, 408 SpaceOutputAxisWithHalo, 409 ], 410 n: ParameterizedSize_N = 0, 411 ref_size: Optional[int] = None, 412 ): 413 """Compute the concrete size for a given axis and its reference axis. 414 415 Args: 416 axis: The axis this `SizeReference` is the size of. 417 ref_axis: The reference axis to compute the size from. 418 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 419 and no fixed **ref_size** is given, 420 **n** is used to compute the size of the parameterized **ref_axis**. 421 ref_size: Overwrite the reference size instead of deriving it from 422 **ref_axis** 423 (**ref_axis.scale** is still used; any given **n** is ignored). 424 """ 425 assert ( 426 axis.size == self 427 ), "Given `axis.size` is not defined by this `SizeReference`" 428 429 assert ( 430 ref_axis.id == self.axis_id 431 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 432 433 assert axis.unit == ref_axis.unit, ( 434 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 435 f" but {axis.unit}!={ref_axis.unit}" 436 ) 437 if ref_size is None: 438 if isinstance(ref_axis.size, (int, float)): 439 ref_size = ref_axis.size 440 elif isinstance(ref_axis.size, ParameterizedSize): 441 ref_size = ref_axis.size.get_size(n) 442 elif isinstance(ref_axis.size, DataDependentSize): 443 raise ValueError( 444 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 445 ) 446 elif isinstance(ref_axis.size, SizeReference): 447 raise ValueError( 448 "Reference axis referenced in `SizeReference` may not be sized by a" 449 + " `SizeReference` itself." 
450 ) 451 else: 452 assert_never(ref_axis.size) 453 454 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 455 456 @staticmethod 457 def _get_unit( 458 axis: Union[ 459 ChannelAxis, 460 IndexInputAxis, 461 IndexOutputAxis, 462 TimeInputAxis, 463 SpaceInputAxis, 464 TimeOutputAxis, 465 TimeOutputAxisWithHalo, 466 SpaceOutputAxis, 467 SpaceOutputAxisWithHalo, 468 ], 469 ): 470 return axis.unit 471 472 473class AxisBase(NodeWithExplicitlySetFields): 474 id: AxisId 475 """An axis id unique across all axes of one tensor.""" 476 477 description: Annotated[str, MaxLen(128)] = "" 478 479 480class WithHalo(Node): 481 halo: Annotated[int, Ge(1)] 482 """The halo should be cropped from the output tensor to avoid boundary effects. 483 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 484 To document a halo that is already cropped by the model use `size.offset` instead.""" 485 486 size: Annotated[ 487 SizeReference, 488 Field( 489 examples=[ 490 10, 491 SizeReference( 492 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 493 ).model_dump(mode="json"), 494 ] 495 ), 496 ] 497 """reference to another axis with an optional offset (see `SizeReference`)""" 498 499 500BATCH_AXIS_ID = AxisId("batch") 501 502 503class BatchAxis(AxisBase): 504 implemented_type: ClassVar[Literal["batch"]] = "batch" 505 if TYPE_CHECKING: 506 type: Literal["batch"] = "batch" 507 else: 508 type: Literal["batch"] 509 510 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 511 size: Optional[Literal[1]] = None 512 """The batch size may be fixed to 1, 513 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 514 515 @property 516 def scale(self): 517 return 1.0 518 519 @property 520 def concatenable(self): 521 return True 522 523 @property 524 def unit(self): 525 return None 526 527 528class ChannelAxis(AxisBase): 529 implemented_type: ClassVar[Literal["channel"]] = "channel" 530 if TYPE_CHECKING: 531 type: Literal["channel"] = "channel" 532 else: 533 type: Literal["channel"] 534 535 id: NonBatchAxisId = AxisId("channel") 536 channel_names: NotEmpty[List[Identifier]] 537 538 @property 539 def size(self) -> int: 540 return len(self.channel_names) 541 542 @property 543 def concatenable(self): 544 return False 545 546 @property 547 def scale(self) -> float: 548 return 1.0 549 550 @property 551 def unit(self): 552 return None 553 554 555class IndexAxisBase(AxisBase): 556 implemented_type: ClassVar[Literal["index"]] = "index" 557 if TYPE_CHECKING: 558 type: Literal["index"] = "index" 559 else: 560 type: Literal["index"] 561 562 id: NonBatchAxisId = AxisId("index") 563 564 @property 565 def scale(self) -> float: 566 return 1.0 567 568 @property 569 def unit(self): 570 return None 571 572 573class _WithInputAxisSize(Node): 574 size: Annotated[ 575 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 576 Field( 577 examples=[ 578 10, 579 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 580 SizeReference( 581 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 582 ).model_dump(mode="json"), 583 ] 584 ), 585 ] 586 """The size/length of this axis can be specified as 587 - fixed integer 588 - parameterized series of valid sizes (`ParameterizedSize`) 589 - reference to another axis with an optional offset (`SizeReference`) 590 """ 591 592 593class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 594 concatenable: bool = False 595 """If a model has a `concatenable` input axis, it can be processed blockwise, 596 splitting a longer sample 
axis into blocks matching its input tensor description. 597 Output axes are concatenable if they have a `SizeReference` to a concatenable 598 input axis. 599 """ 600 601 602class IndexOutputAxis(IndexAxisBase): 603 size: Annotated[ 604 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 605 Field( 606 examples=[ 607 10, 608 SizeReference( 609 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 610 ).model_dump(mode="json"), 611 ] 612 ), 613 ] 614 """The size/length of this axis can be specified as 615 - fixed integer 616 - reference to another axis with an optional offset (`SizeReference`) 617 - data dependent size using `DataDependentSize` (size is only known after model inference) 618 """ 619 620 621class TimeAxisBase(AxisBase): 622 implemented_type: ClassVar[Literal["time"]] = "time" 623 if TYPE_CHECKING: 624 type: Literal["time"] = "time" 625 else: 626 type: Literal["time"] 627 628 id: NonBatchAxisId = AxisId("time") 629 unit: Optional[TimeUnit] = None 630 scale: Annotated[float, Gt(0)] = 1.0 631 632 633class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 634 concatenable: bool = False 635 """If a model has a `concatenable` input axis, it can be processed blockwise, 636 splitting a longer sample axis into blocks matching its input tensor description. 637 Output axes are concatenable if they have a `SizeReference` to a concatenable 638 input axis. 639 """ 640 641 642class SpaceAxisBase(AxisBase): 643 implemented_type: ClassVar[Literal["space"]] = "space" 644 if TYPE_CHECKING: 645 type: Literal["space"] = "space" 646 else: 647 type: Literal["space"] 648 649 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 650 unit: Optional[SpaceUnit] = None 651 scale: Annotated[float, Gt(0)] = 1.0 652 653 654class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 655 concatenable: bool = False 656 """If a model has a `concatenable` input axis, it can be processed blockwise, 657 splitting a longer sample axis into blocks matching its input tensor description. 658 Output axes are concatenable if they have a `SizeReference` to a concatenable 659 input axis. 
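
    Illustrative sketch (not normative): a concatenable space axis whose valid
    sizes are 64, 96, 128, ...

    >>> x = SpaceInputAxis(
    ...     id=AxisId("x"),
    ...     size=ParameterizedSize(min=64, step=32),
    ...     concatenable=True,
    ... )
    >>> x.size.validate_size(96)
    96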
660 """ 661 662 663INPUT_AXIS_TYPES = ( 664 BatchAxis, 665 ChannelAxis, 666 IndexInputAxis, 667 TimeInputAxis, 668 SpaceInputAxis, 669) 670"""intended for isinstance comparisons in py<3.10""" 671 672_InputAxisUnion = Union[ 673 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 674] 675InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 676 677 678class _WithOutputAxisSize(Node): 679 size: Annotated[ 680 Union[Annotated[int, Gt(0)], SizeReference], 681 Field( 682 examples=[ 683 10, 684 SizeReference( 685 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 686 ).model_dump(mode="json"), 687 ] 688 ), 689 ] 690 """The size/length of this axis can be specified as 691 - fixed integer 692 - reference to another axis with an optional offset (see `SizeReference`) 693 """ 694 695 696class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 697 pass 698 699 700class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 701 pass 702 703 704def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 705 if isinstance(v, dict): 706 return "with_halo" if "halo" in v else "wo_halo" 707 else: 708 return "with_halo" if hasattr(v, "halo") else "wo_halo" 709 710 711_TimeOutputAxisUnion = Annotated[ 712 Union[ 713 Annotated[TimeOutputAxis, Tag("wo_halo")], 714 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 715 ], 716 Discriminator(_get_halo_axis_discriminator_value), 717] 718 719 720class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 721 pass 722 723 724class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 725 pass 726 727 728_SpaceOutputAxisUnion = Annotated[ 729 Union[ 730 Annotated[SpaceOutputAxis, Tag("wo_halo")], 731 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 732 ], 733 Discriminator(_get_halo_axis_discriminator_value), 734] 735 736 737_OutputAxisUnion = Union[ 738 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 739] 740OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 741 742OUTPUT_AXIS_TYPES = ( 743 BatchAxis, 744 ChannelAxis, 745 IndexOutputAxis, 746 TimeOutputAxis, 747 TimeOutputAxisWithHalo, 748 SpaceOutputAxis, 749 SpaceOutputAxisWithHalo, 750) 751"""intended for isinstance comparisons in py<3.10""" 752 753 754AnyAxis = Union[InputAxis, OutputAxis] 755 756ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 757"""intended for isinstance comparisons in py<3.10""" 758 759TVs = Union[ 760 NotEmpty[List[int]], 761 NotEmpty[List[float]], 762 NotEmpty[List[bool]], 763 NotEmpty[List[str]], 764] 765 766 767NominalOrOrdinalDType = Literal[ 768 "float32", 769 "float64", 770 "uint8", 771 "int8", 772 "uint16", 773 "int16", 774 "uint32", 775 "int32", 776 "uint64", 777 "int64", 778 "bool", 779] 780 781 782class NominalOrOrdinalDataDescr(Node): 783 values: TVs 784 """A fixed set of nominal or an ascending sequence of ordinal values. 785 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 786 String `values` are interpreted as labels for tensor values 0, ..., N. 787 Note: as YAML 1.2 does not natively support a "set" datatype, 788 nominal values should be given as a sequence (aka list/array) as well. 
789 """ 790 791 type: Annotated[ 792 NominalOrOrdinalDType, 793 Field( 794 examples=[ 795 "float32", 796 "uint8", 797 "uint16", 798 "int64", 799 "bool", 800 ], 801 ), 802 ] = "uint8" 803 804 @model_validator(mode="after") 805 def _validate_values_match_type( 806 self, 807 ) -> Self: 808 incompatible: List[Any] = [] 809 for v in self.values: 810 if self.type == "bool": 811 if not isinstance(v, bool): 812 incompatible.append(v) 813 elif self.type in DTYPE_LIMITS: 814 if ( 815 isinstance(v, (int, float)) 816 and ( 817 v < DTYPE_LIMITS[self.type].min 818 or v > DTYPE_LIMITS[self.type].max 819 ) 820 or (isinstance(v, str) and "uint" not in self.type) 821 or (isinstance(v, float) and "int" in self.type) 822 ): 823 incompatible.append(v) 824 else: 825 incompatible.append(v) 826 827 if len(incompatible) == 5: 828 incompatible.append("...") 829 break 830 831 if incompatible: 832 raise ValueError( 833 f"data type '{self.type}' incompatible with values {incompatible}" 834 ) 835 836 return self 837 838 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 839 840 @property 841 def range(self): 842 if isinstance(self.values[0], str): 843 return 0, len(self.values) - 1 844 else: 845 return min(self.values), max(self.values) 846 847 848IntervalOrRatioDType = Literal[ 849 "float32", 850 "float64", 851 "uint8", 852 "int8", 853 "uint16", 854 "int16", 855 "uint32", 856 "int32", 857 "uint64", 858 "int64", 859] 860 861 862class IntervalOrRatioDataDescr(Node): 863 type: Annotated[ # todo: rename to dtype 864 IntervalOrRatioDType, 865 Field( 866 examples=["float32", "float64", "uint8", "uint16"], 867 ), 868 ] = "float32" 869 range: Tuple[Optional[float], Optional[float]] = ( 870 None, 871 None, 872 ) 873 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 874 `None` corresponds to min/max of what can be expressed by **type**.""" 875 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 876 scale: float = 1.0 877 """Scale for data on an interval (or ratio) scale.""" 878 offset: Optional[float] = None 879 """Offset for data on a ratio scale.""" 880 881 882TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 883 884 885class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 886 """processing base class""" 887 888 889class BinarizeKwargs(ProcessingKwargs): 890 """key word arguments for `BinarizeDescr`""" 891 892 threshold: float 893 """The fixed threshold""" 894 895 896class BinarizeAlongAxisKwargs(ProcessingKwargs): 897 """key word arguments for `BinarizeDescr`""" 898 899 threshold: NotEmpty[List[float]] 900 """The fixed threshold values along `axis`""" 901 902 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 903 """The `threshold` axis""" 904 905 906class BinarizeDescr(ProcessingDescrBase): 907 """Binarize the tensor with a fixed threshold. 908 909 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 910 will be set to one, values below the threshold to zero. 911 912 Examples: 913 - in YAML 914 ```yaml 915 postprocessing: 916 - id: binarize 917 kwargs: 918 axis: 'channel' 919 threshold: [0.25, 0.5, 0.75] 920 ``` 921 - in Python: 922 >>> postprocessing = [BinarizeDescr( 923 ... kwargs=BinarizeAlongAxisKwargs( 924 ... axis=AxisId('channel'), 925 ... threshold=[0.25, 0.5, 0.75], 926 ... ) 927 ... 
)] 928 """ 929 930 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 931 if TYPE_CHECKING: 932 id: Literal["binarize"] = "binarize" 933 else: 934 id: Literal["binarize"] 935 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 936 937 938class ClipDescr(ProcessingDescrBase): 939 """Set tensor values below min to min and above max to max. 940 941 See `ScaleRangeDescr` for examples. 942 """ 943 944 implemented_id: ClassVar[Literal["clip"]] = "clip" 945 if TYPE_CHECKING: 946 id: Literal["clip"] = "clip" 947 else: 948 id: Literal["clip"] 949 950 kwargs: ClipKwargs 951 952 953class EnsureDtypeKwargs(ProcessingKwargs): 954 """key word arguments for `EnsureDtypeDescr`""" 955 956 dtype: Literal[ 957 "float32", 958 "float64", 959 "uint8", 960 "int8", 961 "uint16", 962 "int16", 963 "uint32", 964 "int32", 965 "uint64", 966 "int64", 967 "bool", 968 ] 969 970 971class EnsureDtypeDescr(ProcessingDescrBase): 972 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 973 974 This can for example be used to ensure the inner neural network model gets a 975 different input tensor data type than the fully described bioimage.io model does. 976 977 Examples: 978 The described bioimage.io model (incl. preprocessing) accepts any 979 float32-compatible tensor, normalizes it with percentiles and clipping and then 980 casts it to uint8, which is what the neural network in this example expects. 981 - in YAML 982 ```yaml 983 inputs: 984 - data: 985 type: float32 # described bioimage.io model is compatible with any float32 input tensor 986 preprocessing: 987 - id: scale_range 988 kwargs: 989 axes: ['y', 'x'] 990 max_percentile: 99.8 991 min_percentile: 5.0 992 - id: clip 993 kwargs: 994 min: 0.0 995 max: 1.0 996 - id: ensure_dtype # the neural network of the model requires uint8 997 kwargs: 998 dtype: uint8 999 ``` 1000 - in Python: 1001 >>> preprocessing = [ 1002 ... ScaleRangeDescr( 1003 ... kwargs=ScaleRangeKwargs( 1004 ... axes= (AxisId('y'), AxisId('x')), 1005 ... max_percentile= 99.8, 1006 ... min_percentile= 5.0, 1007 ... ) 1008 ... ), 1009 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1010 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1011 ... ] 1012 """ 1013 1014 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1015 if TYPE_CHECKING: 1016 id: Literal["ensure_dtype"] = "ensure_dtype" 1017 else: 1018 id: Literal["ensure_dtype"] 1019 1020 kwargs: EnsureDtypeKwargs 1021 1022 1023class ScaleLinearKwargs(ProcessingKwargs): 1024 """Key word arguments for `ScaleLinearDescr`""" 1025 1026 gain: float = 1.0 1027 """multiplicative factor""" 1028 1029 offset: float = 0.0 1030 """additive term""" 1031 1032 @model_validator(mode="after") 1033 def _validate(self) -> Self: 1034 if self.gain == 1.0 and self.offset == 0.0: 1035 raise ValueError( 1036 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1037 + " != 0.0." 
            )

        return self


class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
    """Key word arguments for `ScaleLinearDescr`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:

        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self


class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]


class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function.
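    Computes `out = 1 / (1 + exp(-tensor))` elementwise and takes no kwargs.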
1131 1132 Examples: 1133 - in YAML 1134 ```yaml 1135 postprocessing: 1136 - id: sigmoid 1137 ``` 1138 - in Python: 1139 >>> postprocessing = [SigmoidDescr()] 1140 """ 1141 1142 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1143 if TYPE_CHECKING: 1144 id: Literal["sigmoid"] = "sigmoid" 1145 else: 1146 id: Literal["sigmoid"] 1147 1148 @property 1149 def kwargs(self) -> ProcessingKwargs: 1150 """empty kwargs""" 1151 return ProcessingKwargs() 1152 1153 1154class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1155 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1156 1157 mean: float 1158 """The mean value to normalize with.""" 1159 1160 std: Annotated[float, Ge(1e-6)] 1161 """The standard deviation value to normalize with.""" 1162 1163 1164class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1165 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1166 1167 mean: NotEmpty[List[float]] 1168 """The mean value(s) to normalize with.""" 1169 1170 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1171 """The standard deviation value(s) to normalize with. 1172 Size must match `mean` values.""" 1173 1174 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1175 """The axis of the mean/std values to normalize each entry along that dimension 1176 separately.""" 1177 1178 @model_validator(mode="after") 1179 def _mean_and_std_match(self) -> Self: 1180 if len(self.mean) != len(self.std): 1181 raise ValueError( 1182 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1183 + " must match." 1184 ) 1185 1186 return self 1187 1188 1189class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1190 """Subtract a given mean and divide by the standard deviation. 1191 1192 Normalize with fixed, precomputed values for 1193 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1194 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1195 axes. 1196 1197 Examples: 1198 1. scalar value for whole tensor 1199 - in YAML 1200 ```yaml 1201 preprocessing: 1202 - id: fixed_zero_mean_unit_variance 1203 kwargs: 1204 mean: 103.5 1205 std: 13.7 1206 ``` 1207 - in Python 1208 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1209 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1210 ... )] 1211 1212 2. independently along an axis 1213 - in YAML 1214 ```yaml 1215 preprocessing: 1216 - id: fixed_zero_mean_unit_variance 1217 kwargs: 1218 axis: channel 1219 mean: [101.5, 102.5, 103.5] 1220 std: [11.7, 12.7, 13.7] 1221 ``` 1222 - in Python 1223 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1224 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1225 ... axis=AxisId("channel"), 1226 ... mean=[101.5, 102.5, 103.5], 1227 ... std=[11.7, 12.7, 13.7], 1228 ... ) 1229 ... )] 1230 """ 1231 1232 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1233 "fixed_zero_mean_unit_variance" 1234 ) 1235 if TYPE_CHECKING: 1236 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1237 else: 1238 id: Literal["fixed_zero_mean_unit_variance"] 1239 1240 kwargs: Union[ 1241 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1242 ] 1243 1244 1245class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1246 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1247 1248 axes: Annotated[ 1249 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1250 ] = None 1251 """The subset of axes to normalize jointly, i.e. 
axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""


class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract the mean and divide by the standard deviation.

    Examples:
      Normalize a whole tensor to zero mean and unit variance
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs
    )


class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value


class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles.
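    As documented for `ScaleRangeKwargs.eps`, this computes
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`,
    with `v_lower`/`v_upper` taken at `min_percentile`/`max_percentile`.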
1336 1337 Examples: 1338 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1339 - in YAML 1340 ```yaml 1341 preprocessing: 1342 - id: scale_range 1343 kwargs: 1344 axes: ['y', 'x'] 1345 max_percentile: 99.8 1346 min_percentile: 5.0 1347 ``` 1348 - in Python 1349 >>> preprocessing = [ 1350 ... ScaleRangeDescr( 1351 ... kwargs=ScaleRangeKwargs( 1352 ... axes= (AxisId('y'), AxisId('x')), 1353 ... max_percentile= 99.8, 1354 ... min_percentile= 5.0, 1355 ... ) 1356 ... ), 1357 ... ClipDescr( 1358 ... kwargs=ClipKwargs( 1359 ... min=0.0, 1360 ... max=1.0, 1361 ... ) 1362 ... ), 1363 ... ] 1364 1365 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1366 - in YAML 1367 ```yaml 1368 preprocessing: 1369 - id: scale_range 1370 kwargs: 1371 axes: ['y', 'x'] 1372 max_percentile: 99.8 1373 min_percentile: 5.0 1374 - id: scale_range 1375 - id: clip 1376 kwargs: 1377 min: 0.0 1378 max: 1.0 1379 ``` 1380 - in Python 1381 >>> preprocessing = [ScaleRangeDescr( 1382 ... kwargs=ScaleRangeKwargs( 1383 ... axes= (AxisId('y'), AxisId('x')), 1384 ... max_percentile= 99.8, 1385 ... min_percentile= 5.0, 1386 ... ) 1387 ... )] 1388 1389 """ 1390 1391 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1392 if TYPE_CHECKING: 1393 id: Literal["scale_range"] = "scale_range" 1394 else: 1395 id: Literal["scale_range"] 1396 kwargs: ScaleRangeKwargs 1397 1398 1399class ScaleMeanVarianceKwargs(ProcessingKwargs): 1400 """key word arguments for `ScaleMeanVarianceKwargs`""" 1401 1402 reference_tensor: TensorId 1403 """Name of tensor to match.""" 1404 1405 axes: Annotated[ 1406 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1407 ] = None 1408 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1409 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1410 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1411 To normalize samples independently, leave out the 'batch' axis. 1412 Default: Scale all axes jointly.""" 1413 1414 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1415 """Epsilon for numeric stability: 1416 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 1417 1418 1419class ScaleMeanVarianceDescr(ProcessingDescrBase): 1420 """Scale a tensor's data distribution to match another tensor's mean/std. 
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs


PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        FixedZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("id"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        FixedZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("id"),
]

IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)


class TensorDescrBase(Node, Generic[IO_AxisT]):
    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FileDescr_
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: Optional[FileDescr_] = None
    """A sample tensor to illustrate a possible input/output for the model.
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
1510 """ 1511 1512 @model_validator(mode="after") 1513 def _validate_sample_tensor(self) -> Self: 1514 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1515 return self 1516 1517 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1518 tensor: NDArray[Any] = imread( 1519 reader.read(), 1520 extension=PurePosixPath(reader.original_file_name).suffix, 1521 ) 1522 n_dims = len(tensor.squeeze().shape) 1523 n_dims_min = n_dims_max = len(self.axes) 1524 1525 for a in self.axes: 1526 if isinstance(a, BatchAxis): 1527 n_dims_min -= 1 1528 elif isinstance(a.size, int): 1529 if a.size == 1: 1530 n_dims_min -= 1 1531 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1532 if a.size.min == 1: 1533 n_dims_min -= 1 1534 elif isinstance(a.size, SizeReference): 1535 if a.size.offset < 2: 1536 # size reference may result in singleton axis 1537 n_dims_min -= 1 1538 else: 1539 assert_never(a.size) 1540 1541 n_dims_min = max(0, n_dims_min) 1542 if n_dims < n_dims_min or n_dims > n_dims_max: 1543 raise ValueError( 1544 f"Expected sample tensor to have {n_dims_min} to" 1545 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1546 ) 1547 1548 return self 1549 1550 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1551 IntervalOrRatioDataDescr() 1552 ) 1553 """Description of the tensor's data values, optionally per channel. 1554 If specified per channel, the data `type` needs to match across channels.""" 1555 1556 @property 1557 def dtype( 1558 self, 1559 ) -> Literal[ 1560 "float32", 1561 "float64", 1562 "uint8", 1563 "int8", 1564 "uint16", 1565 "int16", 1566 "uint32", 1567 "int32", 1568 "uint64", 1569 "int64", 1570 "bool", 1571 ]: 1572 """dtype as specified under `data.type` or `data[i].type`""" 1573 if isinstance(self.data, collections.abc.Sequence): 1574 return self.data[0].type 1575 else: 1576 return self.data.type 1577 1578 @field_validator("data", mode="after") 1579 @classmethod 1580 def _check_data_type_across_channels( 1581 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1582 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1583 if not isinstance(value, list): 1584 return value 1585 1586 dtypes = {t.type for t in value} 1587 if len(dtypes) > 1: 1588 raise ValueError( 1589 "Tensor data descriptions per channel need to agree in their data" 1590 + f" `type`, but found {dtypes}." 1591 ) 1592 1593 return value 1594 1595 @model_validator(mode="after") 1596 def _check_data_matches_channelaxis(self) -> Self: 1597 if not isinstance(self.data, (list, tuple)): 1598 return self 1599 1600 for a in self.axes: 1601 if isinstance(a, ChannelAxis): 1602 size = a.size 1603 assert isinstance(size, int) 1604 break 1605 else: 1606 return self 1607 1608 if len(self.data) != size: 1609 raise ValueError( 1610 f"Got tensor data descriptions for {len(self.data)} channels, but" 1611 + f" '{a.id}' axis has size {size}." 1612 ) 1613 1614 return self 1615 1616 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1617 if len(array.shape) != len(self.axes): 1618 raise ValueError( 1619 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1620 + f" incompatible with {len(self.axes)} axes." 1621 ) 1622 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 1623 1624 1625class InputTensorDescr(TensorDescrBase[InputAxis]): 1626 id: TensorId = TensorId("input") 1627 """Input tensor id. 
1628 No duplicates are allowed across all inputs and outputs.""" 1629 1630 optional: bool = False 1631 """indicates that this tensor may be `None`""" 1632 1633 preprocessing: List[PreprocessingDescr] = Field( 1634 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1635 ) 1636 1637 """Description of how this input should be preprocessed. 1638 1639 notes: 1640 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1641 to ensure an input tensor's data type matches the input tensor's data description. 1642 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1643 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1644 changing the data type. 1645 """ 1646 1647 @model_validator(mode="after") 1648 def _validate_preprocessing_kwargs(self) -> Self: 1649 axes_ids = [a.id for a in self.axes] 1650 for p in self.preprocessing: 1651 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1652 if kwargs_axes is None: 1653 continue 1654 1655 if not isinstance(kwargs_axes, collections.abc.Sequence): 1656 raise ValueError( 1657 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1658 ) 1659 1660 if any(a not in axes_ids for a in kwargs_axes): 1661 raise ValueError( 1662 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1663 ) 1664 1665 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1666 dtype = self.data.type 1667 else: 1668 dtype = self.data[0].type 1669 1670 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1671 if not self.preprocessing or not isinstance( 1672 self.preprocessing[0], EnsureDtypeDescr 1673 ): 1674 self.preprocessing.insert( 1675 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1676 ) 1677 1678 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1679 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1680 self.preprocessing.append( 1681 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1682 ) 1683 1684 return self 1685 1686 1687def convert_axes( 1688 axes: str, 1689 *, 1690 shape: Union[ 1691 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1692 ], 1693 tensor_type: Literal["input", "output"], 1694 halo: Optional[Sequence[int]], 1695 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1696): 1697 ret: List[AnyAxis] = [] 1698 for i, a in enumerate(axes): 1699 axis_type = _AXIS_TYPE_MAP.get(a, a) 1700 if axis_type == "batch": 1701 ret.append(BatchAxis()) 1702 continue 1703 1704 scale = 1.0 1705 if isinstance(shape, _ParameterizedInputShape_v0_4): 1706 if shape.step[i] == 0: 1707 size = shape.min[i] 1708 else: 1709 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1710 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1711 ref_t = str(shape.reference_tensor) 1712 if ref_t.count(".") == 1: 1713 t_id, orig_a_id = ref_t.split(".") 1714 else: 1715 t_id = ref_t 1716 orig_a_id = a 1717 1718 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1719 if not (orig_scale := shape.scale[i]): 1720 # old way to insert a new axis dimension 1721 size = int(2 * shape.offset[i]) 1722 else: 1723 scale = 1 / orig_scale 1724 if axis_type in ("channel", "index"): 1725 # these axes no longer have a scale 1726 offset_from_scale = orig_scale * size_refs.get( 1727 _TensorName_v0_4(t_id), {} 1728 ).get(orig_a_id, 0) 1729 else: 1730 offset_from_scale = 0 1731 size = SizeReference( 1732 tensor_id=TensorId(t_id), 1733 axis_id=AxisId(a_id), 1734 
offset=int(offset_from_scale + 2 * shape.offset[i]), 1735 ) 1736 else: 1737 size = shape[i] 1738 1739 if axis_type == "time": 1740 if tensor_type == "input": 1741 ret.append(TimeInputAxis(size=size, scale=scale)) 1742 else: 1743 assert not isinstance(size, ParameterizedSize) 1744 if halo is None: 1745 ret.append(TimeOutputAxis(size=size, scale=scale)) 1746 else: 1747 assert not isinstance(size, int) 1748 ret.append( 1749 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1750 ) 1751 1752 elif axis_type == "index": 1753 if tensor_type == "input": 1754 ret.append(IndexInputAxis(size=size)) 1755 else: 1756 if isinstance(size, ParameterizedSize): 1757 size = DataDependentSize(min=size.min) 1758 1759 ret.append(IndexOutputAxis(size=size)) 1760 elif axis_type == "channel": 1761 assert not isinstance(size, ParameterizedSize) 1762 if isinstance(size, SizeReference): 1763 warnings.warn( 1764 "Conversion of channel size from an implicit output shape may be" 1765 + " wrong" 1766 ) 1767 ret.append( 1768 ChannelAxis( 1769 channel_names=[ 1770 Identifier(f"channel{i}") for i in range(size.offset) 1771 ] 1772 ) 1773 ) 1774 else: 1775 ret.append( 1776 ChannelAxis( 1777 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1778 ) 1779 ) 1780 elif axis_type == "space": 1781 if tensor_type == "input": 1782 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1783 else: 1784 assert not isinstance(size, ParameterizedSize) 1785 if halo is None or halo[i] == 0: 1786 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1787 elif isinstance(size, int): 1788 raise NotImplementedError( 1789 f"output axis with halo and fixed size (here {size}) not allowed" 1790 ) 1791 else: 1792 ret.append( 1793 SpaceOutputAxisWithHalo( 1794 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1795 ) 1796 ) 1797 1798 return ret 1799 1800 1801def _axes_letters_to_ids( 1802 axes: Optional[str], 1803) -> Optional[List[AxisId]]: 1804 if axes is None: 1805 return None 1806 1807 return [AxisId(a) for a in axes] 1808 1809 1810def _get_complement_v04_axis( 1811 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 1812) -> Optional[AxisId]: 1813 if axes is None: 1814 return None 1815 1816 non_complement_axes = set(axes) | {"b"} 1817 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 1818 if len(complement_axes) > 1: 1819 raise ValueError( 1820 f"Expected none or a single complement axis, but axes '{axes}' " 1821 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 
1822 ) 1823 1824 return None if not complement_axes else AxisId(complement_axes[0]) 1825 1826 1827def _convert_proc( 1828 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 1829 tensor_axes: Sequence[str], 1830) -> Union[PreprocessingDescr, PostprocessingDescr]: 1831 if isinstance(p, _BinarizeDescr_v0_4): 1832 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 1833 elif isinstance(p, _ClipDescr_v0_4): 1834 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 1835 elif isinstance(p, _SigmoidDescr_v0_4): 1836 return SigmoidDescr() 1837 elif isinstance(p, _ScaleLinearDescr_v0_4): 1838 axes = _axes_letters_to_ids(p.kwargs.axes) 1839 if p.kwargs.axes is None: 1840 axis = None 1841 else: 1842 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1843 1844 if axis is None: 1845 assert not isinstance(p.kwargs.gain, list) 1846 assert not isinstance(p.kwargs.offset, list) 1847 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 1848 else: 1849 kwargs = ScaleLinearAlongAxisKwargs( 1850 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 1851 ) 1852 return ScaleLinearDescr(kwargs=kwargs) 1853 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 1854 return ScaleMeanVarianceDescr( 1855 kwargs=ScaleMeanVarianceKwargs( 1856 axes=_axes_letters_to_ids(p.kwargs.axes), 1857 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 1858 eps=p.kwargs.eps, 1859 ) 1860 ) 1861 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 1862 if p.kwargs.mode == "fixed": 1863 mean = p.kwargs.mean 1864 std = p.kwargs.std 1865 assert mean is not None 1866 assert std is not None 1867 1868 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 1869 1870 if axis is None: 1871 return FixedZeroMeanUnitVarianceDescr( 1872 kwargs=FixedZeroMeanUnitVarianceKwargs( 1873 mean=mean, std=std # pyright: ignore[reportArgumentType] 1874 ) 1875 ) 1876 else: 1877 if not isinstance(mean, list): 1878 mean = [float(mean)] 1879 if not isinstance(std, list): 1880 std = [float(std)] 1881 1882 return FixedZeroMeanUnitVarianceDescr( 1883 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1884 axis=axis, mean=mean, std=std 1885 ) 1886 ) 1887 1888 else: 1889 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 1890 if p.kwargs.mode == "per_dataset": 1891 axes = [AxisId("batch")] + axes 1892 if not axes: 1893 axes = None 1894 return ZeroMeanUnitVarianceDescr( 1895 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 1896 ) 1897 1898 elif isinstance(p, _ScaleRangeDescr_v0_4): 1899 return ScaleRangeDescr( 1900 kwargs=ScaleRangeKwargs( 1901 axes=_axes_letters_to_ids(p.kwargs.axes), 1902 min_percentile=p.kwargs.min_percentile, 1903 max_percentile=p.kwargs.max_percentile, 1904 eps=p.kwargs.eps, 1905 ) 1906 ) 1907 else: 1908 assert_never(p) 1909 1910 1911class _InputTensorConv( 1912 Converter[ 1913 _InputTensorDescr_v0_4, 1914 InputTensorDescr, 1915 FileSource_, 1916 Optional[FileSource_], 1917 Mapping[_TensorName_v0_4, Mapping[str, int]], 1918 ] 1919): 1920 def _convert( 1921 self, 1922 src: _InputTensorDescr_v0_4, 1923 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 1924 test_tensor: FileSource_, 1925 sample_tensor: Optional[FileSource_], 1926 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1927 ) -> "InputTensorDescr | dict[str, Any]": 1928 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 1929 src.axes, 1930 shape=src.shape, 1931 tensor_type="input", 1932 halo=None, 1933 size_refs=size_refs, 1934 ) 1935 prep: 
List[PreprocessingDescr] = [] 1936 for p in src.preprocessing: 1937 cp = _convert_proc(p, src.axes) 1938 assert not isinstance(cp, ScaleMeanVarianceDescr) 1939 prep.append(cp) 1940 1941 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 1942 1943 return tgt( 1944 axes=axes, 1945 id=TensorId(str(src.name)), 1946 test_tensor=FileDescr(source=test_tensor), 1947 sample_tensor=( 1948 None if sample_tensor is None else FileDescr(source=sample_tensor) 1949 ), 1950 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 1951 preprocessing=prep, 1952 ) 1953 1954 1955_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 1956 1957 1958class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1959 id: TensorId = TensorId("output") 1960 """Output tensor id. 1961 No duplicates are allowed across all inputs and outputs.""" 1962 1963 postprocessing: List[PostprocessingDescr] = Field( 1964 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 1965 ) 1966 """Description of how this output should be postprocessed. 1967 1968 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1969 If not given this is added to cast to this tensor's `data.type`. 1970 """ 1971 1972 @model_validator(mode="after") 1973 def _validate_postprocessing_kwargs(self) -> Self: 1974 axes_ids = [a.id for a in self.axes] 1975 for p in self.postprocessing: 1976 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1977 if kwargs_axes is None: 1978 continue 1979 1980 if not isinstance(kwargs_axes, collections.abc.Sequence): 1981 raise ValueError( 1982 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1983 ) 1984 1985 if any(a not in axes_ids for a in kwargs_axes): 1986 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1987 1988 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1989 dtype = self.data.type 1990 else: 1991 dtype = self.data[0].type 1992 1993 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1994 if not self.postprocessing or not isinstance( 1995 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1996 ): 1997 self.postprocessing.append( 1998 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1999 ) 2000 return self 2001 2002 2003class _OutputTensorConv( 2004 Converter[ 2005 _OutputTensorDescr_v0_4, 2006 OutputTensorDescr, 2007 FileSource_, 2008 Optional[FileSource_], 2009 Mapping[_TensorName_v0_4, Mapping[str, int]], 2010 ] 2011): 2012 def _convert( 2013 self, 2014 src: _OutputTensorDescr_v0_4, 2015 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 2016 test_tensor: FileSource_, 2017 sample_tensor: Optional[FileSource_], 2018 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 2019 ) -> "OutputTensorDescr | dict[str, Any]": 2020 # TODO: split convert_axes into convert_output_axes and convert_input_axes 2021 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 2022 src.axes, 2023 shape=src.shape, 2024 tensor_type="output", 2025 halo=src.halo, 2026 size_refs=size_refs, 2027 ) 2028 data_descr: Dict[str, Any] = dict(type=src.data_type) 2029 if data_descr["type"] == "bool": 2030 data_descr["values"] = [False, True] 2031 2032 return tgt( 2033 axes=axes, 2034 id=TensorId(str(src.name)), 2035 test_tensor=FileDescr(source=test_tensor), 2036 sample_tensor=( 2037 None if sample_tensor is None else FileDescr(source=sample_tensor) 2038 ), 2039 data=data_descr, # pyright: ignore[reportArgumentType] 2040 
postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 2041 ) 2042 2043 2044_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 2045 2046 2047TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 2048 2049 2050def validate_tensors( 2051 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 2052 tensor_origin: Literal[ 2053 "test_tensor" 2054 ], # for more precise error messages, e.g. 'test_tensor' 2055): 2056 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 2057 2058 def e_msg(d: TensorDescr): 2059 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2060 2061 for descr, array in tensors.values(): 2062 try: 2063 axis_sizes = descr.get_axis_sizes_for_array(array) 2064 except ValueError as e: 2065 raise ValueError(f"{e_msg(descr)} {e}") 2066 else: 2067 all_tensor_axes[descr.id] = { 2068 a.id: (a, axis_sizes[a.id]) for a in descr.axes 2069 } 2070 2071 for descr, array in tensors.values(): 2072 if descr.dtype in ("float32", "float64"): 2073 invalid_test_tensor_dtype = array.dtype.name not in ( 2074 "float32", 2075 "float64", 2076 "uint8", 2077 "int8", 2078 "uint16", 2079 "int16", 2080 "uint32", 2081 "int32", 2082 "uint64", 2083 "int64", 2084 ) 2085 else: 2086 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2087 2088 if invalid_test_tensor_dtype: 2089 raise ValueError( 2090 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2091 + f" match described dtype '{descr.dtype}'" 2092 ) 2093 2094 if array.min() > -1e-4 and array.max() < 1e-4: 2095 raise ValueError( 2096 "Output values are too small for reliable testing." 2097 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2098 ) 2099 2100 for a in descr.axes: 2101 actual_size = all_tensor_axes[descr.id][a.id][1] 2102 if a.size is None: 2103 continue 2104 2105 if isinstance(a.size, int): 2106 if actual_size != a.size: 2107 raise ValueError( 2108 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2109 + f"has incompatible size {actual_size}, expected {a.size}" 2110 ) 2111 elif isinstance(a.size, ParameterizedSize): 2112 _ = a.size.validate_size(actual_size) 2113 elif isinstance(a.size, DataDependentSize): 2114 _ = a.size.validate_size(actual_size) 2115 elif isinstance(a.size, SizeReference): 2116 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2117 if ref_tensor_axes is None: 2118 raise ValueError( 2119 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2120 + f" reference '{a.size.tensor_id}'" 2121 ) 2122 2123 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2124 if ref_axis is None or ref_size is None: 2125 raise ValueError( 2126 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2127 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2128 ) 2129 2130 if a.unit != ref_axis.unit: 2131 raise ValueError( 2132 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2133 + " axis and reference axis to have the same `unit`, but" 2134 + f" {a.unit}!={ref_axis.unit}" 2135 ) 2136 2137 if actual_size != ( 2138 expected_size := ( 2139 ref_size * ref_axis.scale / a.scale + a.size.offset 2140 ) 2141 ): 2142 raise ValueError( 2143 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2144 + f" {actual_size} invalid for referenced size {ref_size};" 2145 + f" expected {expected_size}" 2146 ) 2147 else: 2148 assert_never(a.size) 2149 2150 2151FileDescr_dependencies = Annotated[ 2152 FileDescr_, 2153 WithSuffix((".yaml", ".yml"), case_sensitive=True), 
2154 Field(examples=[dict(source="environment.yaml")]), 2155] 2156 2157 2158class _ArchitectureCallableDescr(Node): 2159 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 2160 """Identifier of the callable that returns a torch.nn.Module instance.""" 2161 2162 kwargs: Dict[str, YamlValue] = Field( 2163 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 2164 ) 2165 """key word arguments for the `callable`""" 2166 2167 2168class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2169 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2170 """Architecture source file""" 2171 2172 @model_serializer(mode="wrap", when_used="unless-none") 2173 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2174 return package_file_descr_serializer(self, nxt, info) 2175 2176 2177class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2178 import_from: str 2179 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 2180 2181 2182class _ArchFileConv( 2183 Converter[ 2184 _CallableFromFile_v0_4, 2185 ArchitectureFromFileDescr, 2186 Optional[Sha256], 2187 Dict[str, Any], 2188 ] 2189): 2190 def _convert( 2191 self, 2192 src: _CallableFromFile_v0_4, 2193 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 2194 sha256: Optional[Sha256], 2195 kwargs: Dict[str, Any], 2196 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 2197 if src.startswith("http") and src.count(":") == 2: 2198 http, source, callable_ = src.split(":") 2199 source = ":".join((http, source)) 2200 elif not src.startswith("http") and src.count(":") == 1: 2201 source, callable_ = src.split(":") 2202 else: 2203 source = str(src) 2204 callable_ = str(src) 2205 return tgt( 2206 callable=Identifier(callable_), 2207 source=cast(FileSource_, source), 2208 sha256=sha256, 2209 kwargs=kwargs, 2210 ) 2211 2212 2213_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 2214 2215 2216class _ArchLibConv( 2217 Converter[ 2218 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 2219 ] 2220): 2221 def _convert( 2222 self, 2223 src: _CallableFromDepencency_v0_4, 2224 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 2225 kwargs: Dict[str, Any], 2226 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 2227 *mods, callable_ = src.split(".") 2228 import_from = ".".join(mods) 2229 return tgt( 2230 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 2231 ) 2232 2233 2234_arch_lib_conv = _ArchLibConv( 2235 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 2236) 2237 2238 2239class WeightsEntryDescrBase(FileDescr): 2240 type: ClassVar[WeightsFormat] 2241 weights_format_name: ClassVar[str] # human readable 2242 2243 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2244 """Source of the weights file.""" 2245 2246 authors: Optional[List[Author]] = None 2247 """Authors 2248 Either the person(s) that have trained this model resulting in the original weights file. 2249 (If this is the initial weights entry, i.e. it does not have a `parent`) 2250 Or the person(s) who have converted the weights to this weights format. 2251 (If this is a child weight, i.e. it has a `parent` field) 2252 """ 2253 2254 parent: Annotated[ 2255 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2256 ] = None 2257 """The source weights these weights were converted from. 
2258 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2259 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2260 All weight entries except one (the initial set of weights resulting from training the model), 2261 need to have this field.""" 2262 2263 comment: str = "" 2264 """A comment about this weights entry, for example how these weights were created.""" 2265 2266 @model_validator(mode="after") 2267 def _validate(self) -> Self: 2268 if self.type == self.parent: 2269 raise ValueError("Weights entry can't be it's own parent.") 2270 2271 return self 2272 2273 @model_serializer(mode="wrap", when_used="unless-none") 2274 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2275 return package_file_descr_serializer(self, nxt, info) 2276 2277 2278class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2279 type = "keras_hdf5" 2280 weights_format_name: ClassVar[str] = "Keras HDF5" 2281 tensorflow_version: Version 2282 """TensorFlow version used to create these weights.""" 2283 2284 2285class OnnxWeightsDescr(WeightsEntryDescrBase): 2286 type = "onnx" 2287 weights_format_name: ClassVar[str] = "ONNX" 2288 opset_version: Annotated[int, Ge(7)] 2289 """ONNX opset version""" 2290 2291 2292class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2293 type = "pytorch_state_dict" 2294 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2295 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2296 pytorch_version: Version 2297 """Version of the PyTorch library used. 2298 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2299 """ 2300 dependencies: Optional[FileDescr_dependencies] = None 2301 """Custom depencies beyond pytorch described in a Conda environment file. 2302 Allows to specify custom dependencies, see conda docs: 2303 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2304 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2305 2306 The conda environment file should include pytorch and any version pinning has to be compatible with 2307 **pytorch_version**. 2308 """ 2309 2310 2311class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2312 type = "tensorflow_js" 2313 weights_format_name: ClassVar[str] = "Tensorflow.js" 2314 tensorflow_version: Version 2315 """Version of the TensorFlow library used.""" 2316 2317 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2318 """The multi-file weights. 2319 All required files/folders should be a zip archive.""" 2320 2321 2322class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2323 type = "tensorflow_saved_model_bundle" 2324 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2325 tensorflow_version: Version 2326 """Version of the TensorFlow library used.""" 2327 2328 dependencies: Optional[FileDescr_dependencies] = None 2329 """Custom dependencies beyond tensorflow. 2330 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2331 2332 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2333 """The multi-file weights. 
2334 All required files/folders should be a zip archive.""" 2335 2336 2337class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2338 type = "torchscript" 2339 weights_format_name: ClassVar[str] = "TorchScript" 2340 pytorch_version: Version 2341 """Version of the PyTorch library used.""" 2342 2343 2344class WeightsDescr(Node): 2345 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2346 onnx: Optional[OnnxWeightsDescr] = None 2347 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2348 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2349 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2350 None 2351 ) 2352 torchscript: Optional[TorchscriptWeightsDescr] = None 2353 2354 @model_validator(mode="after") 2355 def check_entries(self) -> Self: 2356 entries = {wtype for wtype, entry in self if entry is not None} 2357 2358 if not entries: 2359 raise ValueError("Missing weights entry") 2360 2361 entries_wo_parent = { 2362 wtype 2363 for wtype, entry in self 2364 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2365 } 2366 if len(entries_wo_parent) != 1: 2367 issue_warning( 2368 "Exactly one weights entry may not specify the `parent` field (got" 2369 + " {value}). That entry is considered the original set of model weights." 2370 + " Other weight formats are created through conversion of the orignal or" 2371 + " already converted weights. They have to reference the weights format" 2372 + " they were converted from as their `parent`.", 2373 value=len(entries_wo_parent), 2374 field="weights", 2375 ) 2376 2377 for wtype, entry in self: 2378 if entry is None: 2379 continue 2380 2381 assert hasattr(entry, "type") 2382 assert hasattr(entry, "parent") 2383 assert wtype == entry.type 2384 if ( 2385 entry.parent is not None and entry.parent not in entries 2386 ): # self reference checked for `parent` field 2387 raise ValueError( 2388 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2389 + f" formats: {entries}" 2390 ) 2391 2392 return self 2393 2394 def __getitem__( 2395 self, 2396 key: Literal[ 2397 "keras_hdf5", 2398 "onnx", 2399 "pytorch_state_dict", 2400 "tensorflow_js", 2401 "tensorflow_saved_model_bundle", 2402 "torchscript", 2403 ], 2404 ): 2405 if key == "keras_hdf5": 2406 ret = self.keras_hdf5 2407 elif key == "onnx": 2408 ret = self.onnx 2409 elif key == "pytorch_state_dict": 2410 ret = self.pytorch_state_dict 2411 elif key == "tensorflow_js": 2412 ret = self.tensorflow_js 2413 elif key == "tensorflow_saved_model_bundle": 2414 ret = self.tensorflow_saved_model_bundle 2415 elif key == "torchscript": 2416 ret = self.torchscript 2417 else: 2418 raise KeyError(key) 2419 2420 if ret is None: 2421 raise KeyError(key) 2422 2423 return ret 2424 2425 @property 2426 def available_formats(self): 2427 return { 2428 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2429 **({} if self.onnx is None else {"onnx": self.onnx}), 2430 **( 2431 {} 2432 if self.pytorch_state_dict is None 2433 else {"pytorch_state_dict": self.pytorch_state_dict} 2434 ), 2435 **( 2436 {} 2437 if self.tensorflow_js is None 2438 else {"tensorflow_js": self.tensorflow_js} 2439 ), 2440 **( 2441 {} 2442 if self.tensorflow_saved_model_bundle is None 2443 else { 2444 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2445 } 2446 ), 2447 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2448 } 2449 2450 @property 2451 def missing_formats(self): 2452 return { 2453 wf for wf in 
get_args(WeightsFormat) if wf not in self.available_formats 2454 } 2455 2456 2457class ModelId(ResourceId): 2458 pass 2459 2460 2461class LinkedModel(LinkedResourceBase): 2462 """Reference to a bioimage.io model.""" 2463 2464 id: ModelId 2465 """A valid model `id` from the bioimage.io collection.""" 2466 2467 2468class _DataDepSize(NamedTuple): 2469 min: int 2470 max: Optional[int] 2471 2472 2473class _AxisSizes(NamedTuple): 2474 """the lenghts of all axes of model inputs and outputs""" 2475 2476 inputs: Dict[Tuple[TensorId, AxisId], int] 2477 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 2478 2479 2480class _TensorSizes(NamedTuple): 2481 """_AxisSizes as nested dicts""" 2482 2483 inputs: Dict[TensorId, Dict[AxisId, int]] 2484 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 2485 2486 2487class ReproducibilityTolerance(Node, extra="allow"): 2488 """Describes what small numerical differences -- if any -- may be tolerated 2489 in the generated output when executing in different environments. 2490 2491 A tensor element *output* is considered mismatched to the **test_tensor** if 2492 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2493 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2494 2495 Motivation: 2496 For testing we can request the respective deep learning frameworks to be as 2497 reproducible as possible by setting seeds and chosing deterministic algorithms, 2498 but differences in operating systems, available hardware and installed drivers 2499 may still lead to numerical differences. 2500 """ 2501 2502 relative_tolerance: RelativeTolerance = 1e-3 2503 """Maximum relative tolerance of reproduced test tensor.""" 2504 2505 absolute_tolerance: AbsoluteTolerance = 1e-4 2506 """Maximum absolute tolerance of reproduced test tensor.""" 2507 2508 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2509 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2510 2511 output_ids: Sequence[TensorId] = () 2512 """Limits the output tensor IDs these reproducibility details apply to.""" 2513 2514 weights_formats: Sequence[WeightsFormat] = () 2515 """Limits the weights formats these details apply to.""" 2516 2517 2518class BioimageioConfig(Node, extra="allow"): 2519 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2520 """Tolerances to allow when reproducing the model's test outputs 2521 from the model's test inputs. 2522 Only the first entry matching tensor id and weights format is considered. 2523 """ 2524 2525 2526class Config(Node, extra="allow"): 2527 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig) 2528 2529 2530class ModelDescr(GenericModelDescrBase): 2531 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2532 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2533 """ 2534 2535 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 2536 if TYPE_CHECKING: 2537 format_version: Literal["0.5.4"] = "0.5.4" 2538 else: 2539 format_version: Literal["0.5.4"] 2540 """Version of the bioimage.io model description specification used. 2541 When creating a new model always use the latest micro/patch version described here. 
2542 The `format_version` is important for any consumer software to understand how to parse the fields. 2543 """ 2544 2545 implemented_type: ClassVar[Literal["model"]] = "model" 2546 if TYPE_CHECKING: 2547 type: Literal["model"] = "model" 2548 else: 2549 type: Literal["model"] 2550 """Specialized resource type 'model'""" 2551 2552 id: Optional[ModelId] = None 2553 """bioimage.io-wide unique resource identifier 2554 assigned by bioimage.io; version **un**specific.""" 2555 2556 authors: NotEmpty[List[Author]] 2557 """The authors are the creators of the model RDF and the primary points of contact.""" 2558 2559 documentation: FileSource_documentation 2560 """URL or relative path to a markdown file with additional documentation. 2561 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2562 The documentation should include a '#[#] Validation' (sub)section 2563 with details on how to quantitatively validate the model on unseen data.""" 2564 2565 @field_validator("documentation", mode="after") 2566 @classmethod 2567 def _validate_documentation( 2568 cls, value: FileSource_documentation 2569 ) -> FileSource_documentation: 2570 if not get_validation_context().perform_io_checks: 2571 return value 2572 2573 doc_reader = get_reader(value) 2574 doc_content = doc_reader.read().decode(encoding="utf-8") 2575 if not re.search("#.*[vV]alidation", doc_content): 2576 issue_warning( 2577 "No '# Validation' (sub)section found in {value}.", 2578 value=value, 2579 field="documentation", 2580 ) 2581 2582 return value 2583 2584 inputs: NotEmpty[Sequence[InputTensorDescr]] 2585 """Describes the input tensors expected by this model.""" 2586 2587 @field_validator("inputs", mode="after") 2588 @classmethod 2589 def _validate_input_axes( 2590 cls, inputs: Sequence[InputTensorDescr] 2591 ) -> Sequence[InputTensorDescr]: 2592 input_size_refs = cls._get_axes_with_independent_size(inputs) 2593 2594 for i, ipt in enumerate(inputs): 2595 valid_independent_refs: Dict[ 2596 Tuple[TensorId, AxisId], 2597 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2598 ] = { 2599 **{ 2600 (ipt.id, a.id): (ipt, a, a.size) 2601 for a in ipt.axes 2602 if not isinstance(a, BatchAxis) 2603 and isinstance(a.size, (int, ParameterizedSize)) 2604 }, 2605 **input_size_refs, 2606 } 2607 for a, ax in enumerate(ipt.axes): 2608 cls._validate_axis( 2609 "inputs", 2610 i=i, 2611 tensor_id=ipt.id, 2612 a=a, 2613 axis=ax, 2614 valid_independent_refs=valid_independent_refs, 2615 ) 2616 return inputs 2617 2618 @staticmethod 2619 def _validate_axis( 2620 field_name: str, 2621 i: int, 2622 tensor_id: TensorId, 2623 a: int, 2624 axis: AnyAxis, 2625 valid_independent_refs: Dict[ 2626 Tuple[TensorId, AxisId], 2627 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2628 ], 2629 ): 2630 if isinstance(axis, BatchAxis) or isinstance( 2631 axis.size, (int, ParameterizedSize, DataDependentSize) 2632 ): 2633 return 2634 elif not isinstance(axis.size, SizeReference): 2635 assert_never(axis.size) 2636 2637 # validate axis.size SizeReference 2638 ref = (axis.size.tensor_id, axis.size.axis_id) 2639 if ref not in valid_independent_refs: 2640 raise ValueError( 2641 "Invalid tensor axis reference at" 2642 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 
2643 ) 2644 if ref == (tensor_id, axis.id): 2645 raise ValueError( 2646 "Self-referencing not allowed for" 2647 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2648 ) 2649 if axis.type == "channel": 2650 if valid_independent_refs[ref][1].type != "channel": 2651 raise ValueError( 2652 "A channel axis' size may only reference another fixed size" 2653 + " channel axis." 2654 ) 2655 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2656 ref_size = valid_independent_refs[ref][2] 2657 assert isinstance(ref_size, int), ( 2658 "channel axis ref (another channel axis) has to specify fixed" 2659 + " size" 2660 ) 2661 generated_channel_names = [ 2662 Identifier(axis.channel_names.format(i=i)) 2663 for i in range(1, ref_size + 1) 2664 ] 2665 axis.channel_names = generated_channel_names 2666 2667 if (ax_unit := getattr(axis, "unit", None)) != ( 2668 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2669 ): 2670 raise ValueError( 2671 "The units of an axis and its reference axis need to match, but" 2672 + f" '{ax_unit}' != '{ref_unit}'." 2673 ) 2674 ref_axis = valid_independent_refs[ref][1] 2675 if isinstance(ref_axis, BatchAxis): 2676 raise ValueError( 2677 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2678 + " (a batch axis is not allowed as reference)." 2679 ) 2680 2681 if isinstance(axis, WithHalo): 2682 min_size = axis.size.get_size(axis, ref_axis, n=0) 2683 if (min_size - 2 * axis.halo) < 1: 2684 raise ValueError( 2685 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2686 + f" {axis.halo}." 2687 ) 2688 2689 input_halo = axis.halo * axis.scale / ref_axis.scale 2690 if input_halo != int(input_halo) or input_halo % 2 == 1: 2691 raise ValueError( 2692 f"input_halo {input_halo} (output_halo {axis.halo} *" 2693 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2694 + f" {tensor_id}.{axis.id}." 
2695 ) 2696 2697 @model_validator(mode="after") 2698 def _validate_test_tensors(self) -> Self: 2699 if not get_validation_context().perform_io_checks: 2700 return self 2701 2702 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs] 2703 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs] 2704 2705 tensors = { 2706 descr.id: (descr, array) 2707 for descr, array in zip( 2708 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2709 ) 2710 } 2711 validate_tensors(tensors, tensor_origin="test_tensor") 2712 2713 output_arrays = { 2714 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2715 } 2716 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2717 if not rep_tol.absolute_tolerance: 2718 continue 2719 2720 if rep_tol.output_ids: 2721 out_arrays = { 2722 oid: a 2723 for oid, a in output_arrays.items() 2724 if oid in rep_tol.output_ids 2725 } 2726 else: 2727 out_arrays = output_arrays 2728 2729 for out_id, array in out_arrays.items(): 2730 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2731 raise ValueError( 2732 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2733 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2734 + f" (1% of the maximum value of the test tensor '{out_id}')" 2735 ) 2736 2737 return self 2738 2739 @model_validator(mode="after") 2740 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2741 ipt_refs = {t.id for t in self.inputs} 2742 out_refs = {t.id for t in self.outputs} 2743 for ipt in self.inputs: 2744 for p in ipt.preprocessing: 2745 ref = p.kwargs.get("reference_tensor") 2746 if ref is None: 2747 continue 2748 if ref not in ipt_refs: 2749 raise ValueError( 2750 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2751 + f" references are: {ipt_refs}." 2752 ) 2753 2754 for out in self.outputs: 2755 for p in out.postprocessing: 2756 ref = p.kwargs.get("reference_tensor") 2757 if ref is None: 2758 continue 2759 2760 if ref not in ipt_refs and ref not in out_refs: 2761 raise ValueError( 2762 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2763 + f" are: {ipt_refs | out_refs}." 2764 ) 2765 2766 return self 2767 2768 # TODO: use validate funcs in validate_test_tensors 2769 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2770 2771 name: Annotated[ 2772 Annotated[ 2773 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 2774 ], 2775 MinLen(5), 2776 MaxLen(128), 2777 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2778 ] 2779 """A human-readable name of this model. 2780 It should be no longer than 64 characters 2781 and may only contain letter, number, underscore, minus, parentheses and spaces. 2782 We recommend to chose a name that refers to the model's task and image modality. 
2783 """ 2784 2785 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2786 """Describes the output tensors.""" 2787 2788 @field_validator("outputs", mode="after") 2789 @classmethod 2790 def _validate_tensor_ids( 2791 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2792 ) -> Sequence[OutputTensorDescr]: 2793 tensor_ids = [ 2794 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2795 ] 2796 duplicate_tensor_ids: List[str] = [] 2797 seen: Set[str] = set() 2798 for t in tensor_ids: 2799 if t in seen: 2800 duplicate_tensor_ids.append(t) 2801 2802 seen.add(t) 2803 2804 if duplicate_tensor_ids: 2805 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2806 2807 return outputs 2808 2809 @staticmethod 2810 def _get_axes_with_parameterized_size( 2811 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2812 ): 2813 return { 2814 f"{t.id}.{a.id}": (t, a, a.size) 2815 for t in io 2816 for a in t.axes 2817 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2818 } 2819 2820 @staticmethod 2821 def _get_axes_with_independent_size( 2822 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2823 ): 2824 return { 2825 (t.id, a.id): (t, a, a.size) 2826 for t in io 2827 for a in t.axes 2828 if not isinstance(a, BatchAxis) 2829 and isinstance(a.size, (int, ParameterizedSize)) 2830 } 2831 2832 @field_validator("outputs", mode="after") 2833 @classmethod 2834 def _validate_output_axes( 2835 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2836 ) -> List[OutputTensorDescr]: 2837 input_size_refs = cls._get_axes_with_independent_size( 2838 info.data.get("inputs", []) 2839 ) 2840 output_size_refs = cls._get_axes_with_independent_size(outputs) 2841 2842 for i, out in enumerate(outputs): 2843 valid_independent_refs: Dict[ 2844 Tuple[TensorId, AxisId], 2845 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2846 ] = { 2847 **{ 2848 (out.id, a.id): (out, a, a.size) 2849 for a in out.axes 2850 if not isinstance(a, BatchAxis) 2851 and isinstance(a.size, (int, ParameterizedSize)) 2852 }, 2853 **input_size_refs, 2854 **output_size_refs, 2855 } 2856 for a, ax in enumerate(out.axes): 2857 cls._validate_axis( 2858 "outputs", 2859 i, 2860 out.id, 2861 a, 2862 ax, 2863 valid_independent_refs=valid_independent_refs, 2864 ) 2865 2866 return outputs 2867 2868 packaged_by: List[Author] = Field( 2869 default_factory=cast(Callable[[], List[Author]], list) 2870 ) 2871 """The persons that have packaged and uploaded this model. 2872 Only required if those persons differ from the `authors`.""" 2873 2874 parent: Optional[LinkedModel] = None 2875 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2876 2877 @model_validator(mode="after") 2878 def _validate_parent_is_not_self(self) -> Self: 2879 if self.parent is not None and self.parent.id == self.id: 2880 raise ValueError("A model description may not reference itself as parent.") 2881 2882 return self 2883 2884 run_mode: Annotated[ 2885 Optional[RunMode], 2886 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2887 ] = None 2888 """Custom run mode for this model: for more complex prediction procedures like test time 2889 data augmentation that currently cannot be expressed in the specification. 
2890 No standard run modes are defined yet.""" 2891 2892 timestamp: Datetime = Field(default_factory=Datetime.now) 2893 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2894 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2895 (In Python a datetime object is valid, too).""" 2896 2897 training_data: Annotated[ 2898 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2899 Field(union_mode="left_to_right"), 2900 ] = None 2901 """The dataset used to train this model""" 2902 2903 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2904 """The weights for this model. 2905 Weights can be given for different formats, but should otherwise be equivalent. 2906 The available weight formats determine which consumers can use this model.""" 2907 2908 config: Config = Field(default_factory=Config) 2909 2910 @model_validator(mode="after") 2911 def _add_default_cover(self) -> Self: 2912 if not get_validation_context().perform_io_checks or self.covers: 2913 return self 2914 2915 try: 2916 generated_covers = generate_covers( 2917 [(t, load_array(t.test_tensor)) for t in self.inputs], 2918 [(t, load_array(t.test_tensor)) for t in self.outputs], 2919 ) 2920 except Exception as e: 2921 issue_warning( 2922 "Failed to generate cover image(s): {e}", 2923 value=self.covers, 2924 msg_context=dict(e=e), 2925 field="covers", 2926 ) 2927 else: 2928 self.covers.extend(generated_covers) 2929 2930 return self 2931 2932 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2933 data = [load_array(ipt.test_tensor) for ipt in self.inputs] 2934 assert all(isinstance(d, np.ndarray) for d in data) 2935 return data 2936 2937 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2938 data = [load_array(out.test_tensor) for out in self.outputs] 2939 assert all(isinstance(d, np.ndarray) for d in data) 2940 return data 2941 2942 @staticmethod 2943 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2944 batch_size = 1 2945 tensor_with_batchsize: Optional[TensorId] = None 2946 for tid in tensor_sizes: 2947 for aid, s in tensor_sizes[tid].items(): 2948 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2949 continue 2950 2951 if batch_size != 1: 2952 assert tensor_with_batchsize is not None 2953 raise ValueError( 2954 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2955 ) 2956 2957 batch_size = s 2958 tensor_with_batchsize = tid 2959 2960 return batch_size 2961 2962 def get_output_tensor_sizes( 2963 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2964 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2965 """Returns the tensor output sizes for given **input_sizes**. 2966 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2967 Otherwise it might be larger than the actual (valid) output""" 2968 batch_size = self.get_batch_size(input_sizes) 2969 ns = self.get_ns(input_sizes) 2970 2971 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2972 return tensor_sizes.outputs 2973 2974 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2975 """get parameter `n` for each parameterized axis 2976 such that the valid input size is >= the given input size""" 2977 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2978 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2979 for tid in input_sizes: 2980 for aid, s in input_sizes[tid].items(): 2981 size_descr = axes[tid][aid].size 2982 if isinstance(size_descr, ParameterizedSize): 2983 ret[(tid, aid)] = size_descr.get_n(s) 2984 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2985 pass 2986 else: 2987 assert_never(size_descr) 2988 2989 return ret 2990 2991 def get_tensor_sizes( 2992 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2993 ) -> _TensorSizes: 2994 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2995 return _TensorSizes( 2996 { 2997 t: { 2998 aa: axis_sizes.inputs[(tt, aa)] 2999 for tt, aa in axis_sizes.inputs 3000 if tt == t 3001 } 3002 for t in {tt for tt, _ in axis_sizes.inputs} 3003 }, 3004 { 3005 t: { 3006 aa: axis_sizes.outputs[(tt, aa)] 3007 for tt, aa in axis_sizes.outputs 3008 if tt == t 3009 } 3010 for t in {tt for tt, _ in axis_sizes.outputs} 3011 }, 3012 ) 3013 3014 def get_axis_sizes( 3015 self, 3016 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3017 batch_size: Optional[int] = None, 3018 *, 3019 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3020 ) -> _AxisSizes: 3021 """Determine input and output block shape for scale factors **ns** 3022 of parameterized input sizes. 3023 3024 Args: 3025 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3026 that is parameterized as `size = min + n * step`. 3027 batch_size: The desired size of the batch dimension. 3028 If given **batch_size** overwrites any batch size present in 3029 **max_input_shape**. Default 1. 3030 max_input_shape: Limits the derived block shapes. 3031 Each axis for which the input size, parameterized by `n`, is larger 3032 than **max_input_shape** is set to the minimal value `n_min` for which 3033 this is still true. 3034 Use this for small input samples or large values of **ns**. 3035 Or simply whenever you know the full input shape. 3036 3037 Returns: 3038 Resolved axis sizes for model inputs and outputs. 
3039 """ 3040 max_input_shape = max_input_shape or {} 3041 if batch_size is None: 3042 for (_t_id, a_id), s in max_input_shape.items(): 3043 if a_id == BATCH_AXIS_ID: 3044 batch_size = s 3045 break 3046 else: 3047 batch_size = 1 3048 3049 all_axes = { 3050 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3051 } 3052 3053 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3054 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3055 3056 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3057 if isinstance(a, BatchAxis): 3058 if (t_descr.id, a.id) in ns: 3059 logger.warning( 3060 "Ignoring unexpected size increment factor (n) for batch axis" 3061 + " of tensor '{}'.", 3062 t_descr.id, 3063 ) 3064 return batch_size 3065 elif isinstance(a.size, int): 3066 if (t_descr.id, a.id) in ns: 3067 logger.warning( 3068 "Ignoring unexpected size increment factor (n) for fixed size" 3069 + " axis '{}' of tensor '{}'.", 3070 a.id, 3071 t_descr.id, 3072 ) 3073 return a.size 3074 elif isinstance(a.size, ParameterizedSize): 3075 if (t_descr.id, a.id) not in ns: 3076 raise ValueError( 3077 "Size increment factor (n) missing for parametrized axis" 3078 + f" '{a.id}' of tensor '{t_descr.id}'." 3079 ) 3080 n = ns[(t_descr.id, a.id)] 3081 s_max = max_input_shape.get((t_descr.id, a.id)) 3082 if s_max is not None: 3083 n = min(n, a.size.get_n(s_max)) 3084 3085 return a.size.get_size(n) 3086 3087 elif isinstance(a.size, SizeReference): 3088 if (t_descr.id, a.id) in ns: 3089 logger.warning( 3090 "Ignoring unexpected size increment factor (n) for axis '{}'" 3091 + " of tensor '{}' with size reference.", 3092 a.id, 3093 t_descr.id, 3094 ) 3095 assert not isinstance(a, BatchAxis) 3096 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3097 assert not isinstance(ref_axis, BatchAxis) 3098 ref_key = (a.size.tensor_id, a.size.axis_id) 3099 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3100 assert ref_size is not None, ref_key 3101 assert not isinstance(ref_size, _DataDepSize), ref_key 3102 return a.size.get_size( 3103 axis=a, 3104 ref_axis=ref_axis, 3105 ref_size=ref_size, 3106 ) 3107 elif isinstance(a.size, DataDependentSize): 3108 if (t_descr.id, a.id) in ns: 3109 logger.warning( 3110 "Ignoring unexpected increment factor (n) for data dependent" 3111 + " size axis '{}' of tensor '{}'.", 3112 a.id, 3113 t_descr.id, 3114 ) 3115 return _DataDepSize(a.size.min, a.size.max) 3116 else: 3117 assert_never(a.size) 3118 3119 # first resolve all , but the `SizeReference` input sizes 3120 for t_descr in self.inputs: 3121 for a in t_descr.axes: 3122 if not isinstance(a.size, SizeReference): 3123 s = get_axis_size(a) 3124 assert not isinstance(s, _DataDepSize) 3125 inputs[t_descr.id, a.id] = s 3126 3127 # resolve all other input axis sizes 3128 for t_descr in self.inputs: 3129 for a in t_descr.axes: 3130 if isinstance(a.size, SizeReference): 3131 s = get_axis_size(a) 3132 assert not isinstance(s, _DataDepSize) 3133 inputs[t_descr.id, a.id] = s 3134 3135 # resolve all output axis sizes 3136 for t_descr in self.outputs: 3137 for a in t_descr.axes: 3138 assert not isinstance(a.size, ParameterizedSize) 3139 s = get_axis_size(a) 3140 outputs[t_descr.id, a.id] = s 3141 3142 return _AxisSizes(inputs=inputs, outputs=outputs) 3143 3144 @model_validator(mode="before") 3145 @classmethod 3146 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3147 cls.convert_from_old_format_wo_validation(data) 3148 return data 3149 3150 @classmethod 3151 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3152 """Convert metadata following an older format version to this classes' format 3153 without validating the result. 3154 """ 3155 if ( 3156 data.get("type") == "model" 3157 and isinstance(fv := data.get("format_version"), str) 3158 and fv.count(".") == 2 3159 ): 3160 fv_parts = fv.split(".") 3161 if any(not p.isdigit() for p in fv_parts): 3162 return 3163 3164 fv_tuple = tuple(map(int, fv_parts)) 3165 3166 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3167 if fv_tuple[:2] in ((0, 3), (0, 4)): 3168 m04 = _ModelDescr_v0_4.load(data) 3169 if isinstance(m04, InvalidDescr): 3170 try: 3171 updated = _model_conv.convert_as_dict( 3172 m04 # pyright: ignore[reportArgumentType] 3173 ) 3174 except Exception as e: 3175 logger.error( 3176 "Failed to convert from invalid model 0.4 description." 3177 + f"\nerror: {e}" 3178 + "\nProceeding with model 0.5 validation without conversion." 3179 ) 3180 updated = None 3181 else: 3182 updated = _model_conv.convert_as_dict(m04) 3183 3184 if updated is not None: 3185 data.clear() 3186 data.update(updated) 3187 3188 elif fv_tuple[:2] == (0, 5): 3189 # bump patch version 3190 data["format_version"] = cls.implemented_format_version 3191 3192 3193class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 3194 def _convert( 3195 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 3196 ) -> "ModelDescr | dict[str, Any]": 3197 name = "".join( 3198 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 3199 for c in src.name 3200 ) 3201 3202 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 3203 conv = ( 3204 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 3205 ) 3206 return None if auths is None else [conv(a) for a in auths] 3207 3208 if TYPE_CHECKING: 3209 arch_file_conv = _arch_file_conv.convert 3210 arch_lib_conv = _arch_lib_conv.convert 3211 else: 3212 arch_file_conv = _arch_file_conv.convert_as_dict 3213 arch_lib_conv = _arch_lib_conv.convert_as_dict 3214 3215 input_size_refs = { 3216 ipt.name: { 3217 a: s 3218 for a, s in zip( 3219 ipt.axes, 3220 ( 3221 ipt.shape.min 3222 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 3223 else ipt.shape 3224 ), 3225 ) 3226 } 3227 for ipt in src.inputs 3228 if ipt.shape 3229 } 3230 output_size_refs = { 3231 **{ 3232 out.name: {a: s for a, s in zip(out.axes, out.shape)} 3233 for out in src.outputs 3234 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 3235 }, 3236 **input_size_refs, 3237 } 3238 3239 return tgt( 3240 attachments=( 3241 [] 3242 if src.attachments is None 3243 else [FileDescr(source=f) for f in src.attachments.files] 3244 ), 3245 authors=[ 3246 _author_conv.convert_as_dict(a) for a in src.authors 3247 ], # pyright: ignore[reportArgumentType] 3248 cite=[ 3249 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 3250 ], # pyright: ignore[reportArgumentType] 3251 config=src.config, # pyright: ignore[reportArgumentType] 3252 covers=src.covers, 3253 description=src.description, 3254 documentation=src.documentation, 3255 format_version="0.5.4", 3256 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 3257 icon=src.icon, 3258 id=None if src.id is None else ModelId(src.id), 3259 id_emoji=src.id_emoji, 3260 license=src.license, # type: ignore 3261 links=src.links, 3262 maintainers=[ 3263 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 3264 ], # pyright: ignore[reportArgumentType] 3265 name=name, 3266 tags=src.tags, 3267 type=src.type, 3268 uploader=src.uploader, 3269 
version=src.version, 3270 inputs=[ # pyright: ignore[reportArgumentType] 3271 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 3272 for ipt, tt, st, in zip( 3273 src.inputs, 3274 src.test_inputs, 3275 src.sample_inputs or [None] * len(src.test_inputs), 3276 ) 3277 ], 3278 outputs=[ # pyright: ignore[reportArgumentType] 3279 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 3280 for out, tt, st, in zip( 3281 src.outputs, 3282 src.test_outputs, 3283 src.sample_outputs or [None] * len(src.test_outputs), 3284 ) 3285 ], 3286 parent=( 3287 None 3288 if src.parent is None 3289 else LinkedModel( 3290 id=ModelId( 3291 str(src.parent.id) 3292 + ( 3293 "" 3294 if src.parent.version_number is None 3295 else f"/{src.parent.version_number}" 3296 ) 3297 ) 3298 ) 3299 ), 3300 training_data=( 3301 None 3302 if src.training_data is None 3303 else ( 3304 LinkedDataset( 3305 id=DatasetId( 3306 str(src.training_data.id) 3307 + ( 3308 "" 3309 if src.training_data.version_number is None 3310 else f"/{src.training_data.version_number}" 3311 ) 3312 ) 3313 ) 3314 if isinstance(src.training_data, LinkedDataset02) 3315 else src.training_data 3316 ) 3317 ), 3318 packaged_by=[ 3319 _author_conv.convert_as_dict(a) for a in src.packaged_by 3320 ], # pyright: ignore[reportArgumentType] 3321 run_mode=src.run_mode, 3322 timestamp=src.timestamp, 3323 weights=(WeightsDescr if TYPE_CHECKING else dict)( 3324 keras_hdf5=(w := src.weights.keras_hdf5) 3325 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 3326 authors=conv_authors(w.authors), 3327 source=w.source, 3328 tensorflow_version=w.tensorflow_version or Version("1.15"), 3329 parent=w.parent, 3330 ), 3331 onnx=(w := src.weights.onnx) 3332 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 3333 source=w.source, 3334 authors=conv_authors(w.authors), 3335 parent=w.parent, 3336 opset_version=w.opset_version or 15, 3337 ), 3338 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 3339 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 3340 source=w.source, 3341 authors=conv_authors(w.authors), 3342 parent=w.parent, 3343 architecture=( 3344 arch_file_conv( 3345 w.architecture, 3346 w.architecture_sha256, 3347 w.kwargs, 3348 ) 3349 if isinstance(w.architecture, _CallableFromFile_v0_4) 3350 else arch_lib_conv(w.architecture, w.kwargs) 3351 ), 3352 pytorch_version=w.pytorch_version or Version("1.10"), 3353 dependencies=( 3354 None 3355 if w.dependencies is None 3356 else (FileDescr if TYPE_CHECKING else dict)( 3357 source=cast( 3358 FileSource, 3359 str(deps := w.dependencies)[ 3360 ( 3361 len("conda:") 3362 if str(deps).startswith("conda:") 3363 else 0 3364 ) : 3365 ], 3366 ) 3367 ) 3368 ), 3369 ), 3370 tensorflow_js=(w := src.weights.tensorflow_js) 3371 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 3372 source=w.source, 3373 authors=conv_authors(w.authors), 3374 parent=w.parent, 3375 tensorflow_version=w.tensorflow_version or Version("1.15"), 3376 ), 3377 tensorflow_saved_model_bundle=( 3378 w := src.weights.tensorflow_saved_model_bundle 3379 ) 3380 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 3381 authors=conv_authors(w.authors), 3382 parent=w.parent, 3383 source=w.source, 3384 tensorflow_version=w.tensorflow_version or Version("1.15"), 3385 dependencies=( 3386 None 3387 if w.dependencies is None 3388 else (FileDescr if TYPE_CHECKING else dict)( 3389 source=cast( 3390 FileSource, 3391 ( 3392 str(w.dependencies)[len("conda:") :] 3393 if str(w.dependencies).startswith("conda:") 3394 else 
str(w.dependencies) 3395 ), 3396 ) 3397 ) 3398 ), 3399 ), 3400 torchscript=(w := src.weights.torchscript) 3401 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 3402 source=w.source, 3403 authors=conv_authors(w.authors), 3404 parent=w.parent, 3405 pytorch_version=w.pytorch_version or Version("1.10"), 3406 ), 3407 ), 3408 ) 3409 3410 3411_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 3412 3413 3414# create better cover images for 3d data and non-image outputs 3415def generate_covers( 3416 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3417 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3418) -> List[Path]: 3419 def squeeze( 3420 data: NDArray[Any], axes: Sequence[AnyAxis] 3421 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3422 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3423 if data.ndim != len(axes): 3424 raise ValueError( 3425 f"tensor shape {data.shape} does not match described axes" 3426 + f" {[a.id for a in axes]}" 3427 ) 3428 3429 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3430 return data.squeeze(), axes 3431 3432 def normalize( 3433 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3434 ) -> NDArray[np.float32]: 3435 data = data.astype("float32") 3436 data -= data.min(axis=axis, keepdims=True) 3437 data /= data.max(axis=axis, keepdims=True) + eps 3438 return data 3439 3440 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3441 original_shape = data.shape 3442 data, axes = squeeze(data, axes) 3443 3444 # take slice fom any batch or index axis if needed 3445 # and convert the first channel axis and take a slice from any additional channel axes 3446 slices: Tuple[slice, ...] = () 3447 ndim = data.ndim 3448 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3449 has_c_axis = False 3450 for i, a in enumerate(axes): 3451 s = data.shape[i] 3452 assert s > 1 3453 if ( 3454 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3455 and ndim > ndim_need 3456 ): 3457 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3458 ndim -= 1 3459 elif isinstance(a, ChannelAxis): 3460 if has_c_axis: 3461 # second channel axis 3462 data = data[slices + (slice(0, 1),)] 3463 ndim -= 1 3464 else: 3465 has_c_axis = True 3466 if s == 2: 3467 # visualize two channels with cyan and magenta 3468 data = np.concatenate( 3469 [ 3470 data[slices + (slice(1, 2),)], 3471 data[slices + (slice(0, 1),)], 3472 ( 3473 data[slices + (slice(0, 1),)] 3474 + data[slices + (slice(1, 2),)] 3475 ) 3476 / 2, # TODO: take maximum instead? 
3477 ], 3478 axis=i, 3479 ) 3480 elif data.shape[i] == 3: 3481 pass # visualize 3 channels as RGB 3482 else: 3483 # visualize first 3 channels as RGB 3484 data = data[slices + (slice(3),)] 3485 3486 assert data.shape[i] == 3 3487 3488 slices += (slice(None),) 3489 3490 data, axes = squeeze(data, axes) 3491 assert len(axes) == ndim 3492 # take slice from z axis if needed 3493 slices = () 3494 if ndim > ndim_need: 3495 for i, a in enumerate(axes): 3496 s = data.shape[i] 3497 if a.id == AxisId("z"): 3498 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3499 data, axes = squeeze(data, axes) 3500 ndim -= 1 3501 break 3502 3503 slices += (slice(None),) 3504 3505 # take slice from any space or time axis 3506 slices = () 3507 3508 for i, a in enumerate(axes): 3509 if ndim <= ndim_need: 3510 break 3511 3512 s = data.shape[i] 3513 assert s > 1 3514 if isinstance( 3515 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3516 ): 3517 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3518 ndim -= 1 3519 3520 slices += (slice(None),) 3521 3522 del slices 3523 data, axes = squeeze(data, axes) 3524 assert len(axes) == ndim 3525 3526 if (has_c_axis and ndim != 3) or ndim != 2: 3527 raise ValueError( 3528 f"Failed to construct cover image from shape {original_shape}" 3529 ) 3530 3531 if not has_c_axis: 3532 assert ndim == 2 3533 data = np.repeat(data[:, :, None], 3, axis=2) 3534 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3535 ndim += 1 3536 3537 assert ndim == 3 3538 3539 # transpose axis order such that longest axis comes first... 3540 axis_order: List[int] = list(np.argsort(list(data.shape))) 3541 axis_order.reverse() 3542 # ... and channel axis is last 3543 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3544 axis_order.append(axis_order.pop(c)) 3545 axes = [axes[ao] for ao in axis_order] 3546 data = data.transpose(axis_order) 3547 3548 # h, w = data.shape[:2] 3549 # if h / w in (1.0 or 2.0): 3550 # pass 3551 # elif h / w < 2: 3552 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3553 3554 norm_along = ( 3555 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3556 ) 3557 # normalize the data and map to 8 bit 3558 data = normalize(data, norm_along) 3559 data = (data * 255).astype("uint8") 3560 3561 return data 3562 3563 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3564 assert im0.dtype == im1.dtype == np.uint8 3565 assert im0.shape == im1.shape 3566 assert im0.ndim == 3 3567 N, M, C = im0.shape 3568 assert C == 3 3569 out = np.ones((N, M, C), dtype="uint8") 3570 for c in range(C): 3571 outc = np.tril(im0[..., c]) 3572 mask = outc == 0 3573 outc[mask] = np.triu(im1[..., c])[mask] 3574 out[..., c] = outc 3575 3576 return out 3577 3578 ipt_descr, ipt = inputs[0] 3579 out_descr, out = outputs[0] 3580 3581 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3582 out_img = to_2d_image(out, out_descr.axes) 3583 3584 cover_folder = Path(mkdtemp()) 3585 if ipt_img.shape == out_img.shape: 3586 covers = [cover_folder / "cover.png"] 3587 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3588 else: 3589 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3590 imwrite(covers[0], ipt_img) 3591 imwrite(covers[1], out_img) 3592 3593 return covers
Space unit compatible with the OME-Zarr axes specification 0.5
Time unit compatible with the OME-Zarr axes specification 0.5
226class TensorId(LowerCaseIdentifier): 227 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 228 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 229 ]
242class AxisId(LowerCaseIdentifier): 243 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 244 Annotated[ 245 LowerCaseIdentifierAnno, 246 MaxLen(16), 247 AfterValidator(_normalize_axis_id), 248 ] 249 ]
Annotates an integer to calculate a concrete axis size from a ParameterizedSize.
293class ParameterizedSize(Node): 294 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 295 296 - **min** and **step** are given by the model description. 297 - All blocksize paramters n = 0,1,2,... yield a valid `size`. 298 - A greater blocksize paramter n = 0,1,2,... results in a greater **size**. 299 This allows to adjust the axis size more generically. 300 """ 301 302 N: ClassVar[Type[int]] = ParameterizedSize_N 303 """Positive integer to parameterize this axis""" 304 305 min: Annotated[int, Gt(0)] 306 step: Annotated[int, Gt(0)] 307 308 def validate_size(self, size: int) -> int: 309 if size < self.min: 310 raise ValueError(f"size {size} < {self.min}") 311 if (size - self.min) % self.step != 0: 312 raise ValueError( 313 f"axis of size {size} is not parameterized by `min + n*step` =" 314 + f" `{self.min} + n*{self.step}`" 315 ) 316 317 return size 318 319 def get_size(self, n: ParameterizedSize_N) -> int: 320 return self.min + self.step * n 321 322 def get_n(self, s: int) -> ParameterizedSize_N: 323 """return smallest n parameterizing a size greater or equal than `s`""" 324 return ceil((s - self.min) / self.step)
Describes a range of valid tensor axis sizes as size = min + n*step.
- min and step are given by the model description.
- All blocksize parameters n = 0, 1, 2, ... yield a valid size.
- A greater blocksize parameter n results in a greater size. This allows the axis size to be adjusted more generically (see the usage sketch after the method listings below).
308 def validate_size(self, size: int) -> int: 309 if size < self.min: 310 raise ValueError(f"size {size} < {self.min}") 311 if (size - self.min) % self.step != 0: 312 raise ValueError( 313 f"axis of size {size} is not parameterized by `min + n*step` =" 314 + f" `{self.min} + n*{self.step}`" 315 ) 316 317 return size
322 def get_n(self, s: int) -> ParameterizedSize_N: 323 """return smallest n parameterizing a size greater or equal than `s`""" 324 return ceil((s - self.min) / self.step)
Return the smallest n parameterizing a size greater than or equal to s.
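A minimal usage sketch (illustrative values; assumes ParameterizedSize has been imported from this module):

>>> size = ParameterizedSize(min=64, step=16)
>>> size.get_size(n=2)   # 64 + 2*16
96
>>> size.get_n(s=70)     # smallest n such that 64 + n*16 >= 70
1
>>> size.validate_size(96)
96

Calling size.validate_size(70) would raise a ValueError, since 70 cannot be written as 64 + n*16.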
327class DataDependentSize(Node): 328 min: Annotated[int, Gt(0)] = 1 329 max: Annotated[Optional[int], Gt(1)] = None 330 331 @model_validator(mode="after") 332 def _validate_max_gt_min(self): 333 if self.max is not None and self.min >= self.max: 334 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 335 336 return self 337 338 def validate_size(self, size: int) -> int: 339 if size < self.min: 340 raise ValueError(f"size {size} < {self.min}") 341 342 if self.max is not None and size > self.max: 343 raise ValueError(f"size {size} > {self.max}") 344 345 return size
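A small illustrative sketch of DataDependentSize.validate_size (assumes the class has been imported from this module):

>>> dds = DataDependentSize(min=1, max=512)
>>> dds.validate_size(100)
100

dds.validate_size(600) would raise a ValueError because 600 exceeds max=512; sizes below min are rejected analogously.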
348class SizeReference(Node): 349 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 350 351 `axis.size = reference.size * reference.scale / axis.scale + offset` 352 353 Note: 354 1. The axis and the referenced axis need to have the same unit (or no unit). 355 2. Batch axes may not be referenced. 356 3. Fractions are rounded down. 357 4. If the reference axis is `concatenable` the referencing axis is assumed to be 358 `concatenable` as well with the same block order. 359 360 Example: 361 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm². 362 Let's assume that we want to express the image height h in relation to its width w 363 instead of only accepting input images of exactly 100*49 pixels 364 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 365 366 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 367 >>> h = SpaceInputAxis( 368 ... id=AxisId("h"), 369 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 370 ... unit="millimeter", 371 ... scale=4, 372 ... ) 373 >>> print(h.size.get_size(h, w)) 374 49 375 376 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 377 """ 378 379 tensor_id: TensorId 380 """tensor id of the reference axis""" 381 382 axis_id: AxisId 383 """axis id of the reference axis""" 384 385 offset: int = 0 386 387 def get_size( 388 self, 389 axis: Union[ 390 ChannelAxis, 391 IndexInputAxis, 392 IndexOutputAxis, 393 TimeInputAxis, 394 SpaceInputAxis, 395 TimeOutputAxis, 396 TimeOutputAxisWithHalo, 397 SpaceOutputAxis, 398 SpaceOutputAxisWithHalo, 399 ], 400 ref_axis: Union[ 401 ChannelAxis, 402 IndexInputAxis, 403 IndexOutputAxis, 404 TimeInputAxis, 405 SpaceInputAxis, 406 TimeOutputAxis, 407 TimeOutputAxisWithHalo, 408 SpaceOutputAxis, 409 SpaceOutputAxisWithHalo, 410 ], 411 n: ParameterizedSize_N = 0, 412 ref_size: Optional[int] = None, 413 ): 414 """Compute the concrete size for a given axis and its reference axis. 415 416 Args: 417 axis: The axis this `SizeReference` is the size of. 418 ref_axis: The reference axis to compute the size from. 419 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 420 and no fixed **ref_size** is given, 421 **n** is used to compute the size of the parameterized **ref_axis**. 422 ref_size: Overwrite the reference size instead of deriving it from 423 **ref_axis** 424 (**ref_axis.scale** is still used; any given **n** is ignored). 425 """ 426 assert ( 427 axis.size == self 428 ), "Given `axis.size` is not defined by this `SizeReference`" 429 430 assert ( 431 ref_axis.id == self.axis_id 432 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 433 434 assert axis.unit == ref_axis.unit, ( 435 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 436 f" but {axis.unit}!={ref_axis.unit}" 437 ) 438 if ref_size is None: 439 if isinstance(ref_axis.size, (int, float)): 440 ref_size = ref_axis.size 441 elif isinstance(ref_axis.size, ParameterizedSize): 442 ref_size = ref_axis.size.get_size(n) 443 elif isinstance(ref_axis.size, DataDependentSize): 444 raise ValueError( 445 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 446 ) 447 elif isinstance(ref_axis.size, SizeReference): 448 raise ValueError( 449 "Reference axis referenced in `SizeReference` may not be sized by a" 450 + " `SizeReference` itself." 
451 ) 452 else: 453 assert_never(ref_axis.size) 454 455 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 456 457 @staticmethod 458 def _get_unit( 459 axis: Union[ 460 ChannelAxis, 461 IndexInputAxis, 462 IndexOutputAxis, 463 TimeInputAxis, 464 SpaceInputAxis, 465 TimeOutputAxis, 466 TimeOutputAxisWithHalo, 467 SpaceOutputAxis, 468 SpaceOutputAxisWithHalo, 469 ], 470 ): 471 return axis.unit
A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
axis.size = reference.size * reference.scale / axis.scale + offset
Note:
- The axis and the referenced axis need to have the same unit (or no unit).
- Batch axes may not be referenced.
- Fractions are rounded down.
- If the reference axis is `concatenable` the referencing axis is assumed to be `concatenable` as well with the same block order.
Example:
An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².
Let's assume that we want to express the image height h in relation to its width w
instead of only accepting input images of exactly 100*49 pixels
(for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
>>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
>>> h = SpaceInputAxis(
... id=AxisId("h"),
... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
... unit="millimeter",
... scale=4,
... )
>>> print(h.size.get_size(h, w))
49
⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
387 def get_size( 388 self, 389 axis: Union[ 390 ChannelAxis, 391 IndexInputAxis, 392 IndexOutputAxis, 393 TimeInputAxis, 394 SpaceInputAxis, 395 TimeOutputAxis, 396 TimeOutputAxisWithHalo, 397 SpaceOutputAxis, 398 SpaceOutputAxisWithHalo, 399 ], 400 ref_axis: Union[ 401 ChannelAxis, 402 IndexInputAxis, 403 IndexOutputAxis, 404 TimeInputAxis, 405 SpaceInputAxis, 406 TimeOutputAxis, 407 TimeOutputAxisWithHalo, 408 SpaceOutputAxis, 409 SpaceOutputAxisWithHalo, 410 ], 411 n: ParameterizedSize_N = 0, 412 ref_size: Optional[int] = None, 413 ): 414 """Compute the concrete size for a given axis and its reference axis. 415 416 Args: 417 axis: The axis this `SizeReference` is the size of. 418 ref_axis: The reference axis to compute the size from. 419 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 420 and no fixed **ref_size** is given, 421 **n** is used to compute the size of the parameterized **ref_axis**. 422 ref_size: Overwrite the reference size instead of deriving it from 423 **ref_axis** 424 (**ref_axis.scale** is still used; any given **n** is ignored). 425 """ 426 assert ( 427 axis.size == self 428 ), "Given `axis.size` is not defined by this `SizeReference`" 429 430 assert ( 431 ref_axis.id == self.axis_id 432 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 433 434 assert axis.unit == ref_axis.unit, ( 435 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 436 f" but {axis.unit}!={ref_axis.unit}" 437 ) 438 if ref_size is None: 439 if isinstance(ref_axis.size, (int, float)): 440 ref_size = ref_axis.size 441 elif isinstance(ref_axis.size, ParameterizedSize): 442 ref_size = ref_axis.size.get_size(n) 443 elif isinstance(ref_axis.size, DataDependentSize): 444 raise ValueError( 445 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 446 ) 447 elif isinstance(ref_axis.size, SizeReference): 448 raise ValueError( 449 "Reference axis referenced in `SizeReference` may not be sized by a" 450 + " `SizeReference` itself." 451 ) 452 else: 453 assert_never(ref_axis.size) 454 455 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
Compute the concrete size for a given axis and its reference axis.
Arguments:
- axis: The axis this `SizeReference` is the size of.
- ref_axis: The reference axis to compute the size from.
- n: If the ref_axis is parameterized (of type `ParameterizedSize`) and no fixed ref_size is given, n is used to compute the size of the parameterized ref_axis.
- ref_size: Overwrite the reference size instead of deriving it from ref_axis (ref_axis.scale is still used; any given n is ignored).
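To make the interplay of n and a parameterized reference concrete, here is a minimal sketch (not part of the original docstring; it assumes `ParameterizedSize` resolves to `min + n * step`, as defined earlier in this module):

```python
from bioimageio.spec.model.v0_5 import (
    AxisId, ParameterizedSize, SizeReference, SpaceInputAxis, TensorId
)

# Hypothetical axes: the height follows the parameterized width of tensor "input".
w = SpaceInputAxis(
    id=AxisId("w"), size=ParameterizedSize(min=32, step=16), unit="millimeter", scale=2
)
h = SpaceInputAxis(
    id=AxisId("h"),
    size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w")),
    unit="millimeter",
    scale=4,
)
# ref_size = 32 + 3 * 16 = 80, so h = 80 * 2 / 4 + 0 = 40
print(h.size.get_size(h, w, n=3))  # 40
```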
class AxisBase(NodeWithExplicitlySetFields):
    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
481class WithHalo(Node): 482 halo: Annotated[int, Ge(1)] 483 """The halo should be cropped from the output tensor to avoid boundary effects. 484 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 485 To document a halo that is already cropped by the model use `size.offset` instead.""" 486 487 size: Annotated[ 488 SizeReference, 489 Field( 490 examples=[ 491 10, 492 SizeReference( 493 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 494 ).model_dump(mode="json"), 495 ] 496 ), 497 ] 498 """reference to another axis with an optional offset (see `SizeReference`)"""
The halo should be cropped from the output tensor to avoid boundary effects.
It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
To document a halo that is already cropped by the model use `size.offset` instead.
reference to another axis with an optional offset (see `SizeReference`)
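As a quick numeric illustration of the cropping rule above (values are made up):

```python
# an output axis of length 64 described with halo 8:
size, halo = 64, 8
size_after_crop = size - 2 * halo  # 48 elements remain after cropping both sides
```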
504class BatchAxis(AxisBase): 505 implemented_type: ClassVar[Literal["batch"]] = "batch" 506 if TYPE_CHECKING: 507 type: Literal["batch"] = "batch" 508 else: 509 type: Literal["batch"] 510 511 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 512 size: Optional[Literal[1]] = None 513 """The batch size may be fixed to 1, 514 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 515 516 @property 517 def scale(self): 518 return 1.0 519 520 @property 521 def concatenable(self): 522 return True 523 524 @property 525 def unit(self): 526 return None
The batch size may be fixed to 1, otherwise (the default) it may be chosen arbitrarily depending on available memory
529class ChannelAxis(AxisBase): 530 implemented_type: ClassVar[Literal["channel"]] = "channel" 531 if TYPE_CHECKING: 532 type: Literal["channel"] = "channel" 533 else: 534 type: Literal["channel"] 535 536 id: NonBatchAxisId = AxisId("channel") 537 channel_names: NotEmpty[List[Identifier]] 538 539 @property 540 def size(self) -> int: 541 return len(self.channel_names) 542 543 @property 544 def concatenable(self): 545 return False 546 547 @property 548 def scale(self) -> float: 549 return 1.0 550 551 @property 552 def unit(self): 553 return None
556class IndexAxisBase(AxisBase): 557 implemented_type: ClassVar[Literal["index"]] = "index" 558 if TYPE_CHECKING: 559 type: Literal["index"] = "index" 560 else: 561 type: Literal["index"] 562 563 id: NonBatchAxisId = AxisId("index") 564 565 @property 566 def scale(self) -> float: 567 return 1.0 568 569 @property 570 def unit(self): 571 return None
594class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 595 concatenable: bool = False 596 """If a model has a `concatenable` input axis, it can be processed blockwise, 597 splitting a longer sample axis into blocks matching its input tensor description. 598 Output axes are concatenable if they have a `SizeReference` to a concatenable 599 input axis. 600 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
603class IndexOutputAxis(IndexAxisBase): 604 size: Annotated[ 605 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 606 Field( 607 examples=[ 608 10, 609 SizeReference( 610 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 611 ).model_dump(mode="json"), 612 ] 613 ), 614 ] 615 """The size/length of this axis can be specified as 616 - fixed integer 617 - reference to another axis with an optional offset (`SizeReference`) 618 - data dependent size using `DataDependentSize` (size is only known after model inference) 619 """
The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
- data dependent size using `DataDependentSize` (size is only known after model inference)
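A brief sketch of the three variants (values are illustrative; `DataDependentSize(min=1)` assumes the default lower bound):

```python
from bioimageio.spec.model.v0_5 import (
    AxisId, DataDependentSize, IndexOutputAxis, SizeReference, TensorId
)

fixed = IndexOutputAxis(size=10)
referenced = IndexOutputAxis(
    size=SizeReference(tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5)
)
data_dependent = IndexOutputAxis(size=DataDependentSize(min=1))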
622class TimeAxisBase(AxisBase): 623 implemented_type: ClassVar[Literal["time"]] = "time" 624 if TYPE_CHECKING: 625 type: Literal["time"] = "time" 626 else: 627 type: Literal["time"] 628 629 id: NonBatchAxisId = AxisId("time") 630 unit: Optional[TimeUnit] = None 631 scale: Annotated[float, Gt(0)] = 1.0
634class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 635 concatenable: bool = False 636 """If a model has a `concatenable` input axis, it can be processed blockwise, 637 splitting a longer sample axis into blocks matching its input tensor description. 638 Output axes are concatenable if they have a `SizeReference` to a concatenable 639 input axis. 640 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
643class SpaceAxisBase(AxisBase): 644 implemented_type: ClassVar[Literal["space"]] = "space" 645 if TYPE_CHECKING: 646 type: Literal["space"] = "space" 647 else: 648 type: Literal["space"] 649 650 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 651 unit: Optional[SpaceUnit] = None 652 scale: Annotated[float, Gt(0)] = 1.0
An axis id unique across all axes of one tensor.
655class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 656 concatenable: bool = False 657 """If a model has a `concatenable` input axis, it can be processed blockwise, 658 splitting a longer sample axis into blocks matching its input tensor description. 659 Output axes are concatenable if they have a `SizeReference` to a concatenable 660 input axis. 661 """
If a model has a `concatenable` input axis, it can be processed blockwise,
splitting a longer sample axis into blocks matching its input tensor description.
Output axes are concatenable if they have a `SizeReference` to a concatenable input axis.
intended for isinstance comparisons in py<3.10
783class NominalOrOrdinalDataDescr(Node): 784 values: TVs 785 """A fixed set of nominal or an ascending sequence of ordinal values. 786 In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'. 787 String `values` are interpreted as labels for tensor values 0, ..., N. 788 Note: as YAML 1.2 does not natively support a "set" datatype, 789 nominal values should be given as a sequence (aka list/array) as well. 790 """ 791 792 type: Annotated[ 793 NominalOrOrdinalDType, 794 Field( 795 examples=[ 796 "float32", 797 "uint8", 798 "uint16", 799 "int64", 800 "bool", 801 ], 802 ), 803 ] = "uint8" 804 805 @model_validator(mode="after") 806 def _validate_values_match_type( 807 self, 808 ) -> Self: 809 incompatible: List[Any] = [] 810 for v in self.values: 811 if self.type == "bool": 812 if not isinstance(v, bool): 813 incompatible.append(v) 814 elif self.type in DTYPE_LIMITS: 815 if ( 816 isinstance(v, (int, float)) 817 and ( 818 v < DTYPE_LIMITS[self.type].min 819 or v > DTYPE_LIMITS[self.type].max 820 ) 821 or (isinstance(v, str) and "uint" not in self.type) 822 or (isinstance(v, float) and "int" in self.type) 823 ): 824 incompatible.append(v) 825 else: 826 incompatible.append(v) 827 828 if len(incompatible) == 5: 829 incompatible.append("...") 830 break 831 832 if incompatible: 833 raise ValueError( 834 f"data type '{self.type}' incompatible with values {incompatible}" 835 ) 836 837 return self 838 839 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 840 841 @property 842 def range(self): 843 if isinstance(self.values[0], str): 844 return 0, len(self.values) - 1 845 else: 846 return min(self.values), max(self.values)
A fixed set of nominal or an ascending sequence of ordinal values.
In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
String `values` are interpreted as labels for tensor values 0, ..., N.
Note: as YAML 1.2 does not natively support a "set" datatype,
nominal values should be given as a sequence (aka list/array) as well.
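For instance, a small sketch with made-up label names mapping tensor values 0, 1 and 2 to class labels:

```python
from bioimageio.spec.model.v0_5 import NominalOrOrdinalDataDescr

labels = NominalOrOrdinalDataDescr(
    values=["background", "cell", "membrane"],  # interpreted as labels for 0, 1, 2
    type="uint8",
)
print(labels.range)  # (0, 2)
```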
863class IntervalOrRatioDataDescr(Node): 864 type: Annotated[ # todo: rename to dtype 865 IntervalOrRatioDType, 866 Field( 867 examples=["float32", "float64", "uint8", "uint16"], 868 ), 869 ] = "float32" 870 range: Tuple[Optional[float], Optional[float]] = ( 871 None, 872 None, 873 ) 874 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 875 `None` corresponds to min/max of what can be expressed by **type**.""" 876 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 877 scale: float = 1.0 878 """Scale for data on an interval (or ratio) scale.""" 879 offset: Optional[float] = None 880 """Offset for data on a ratio scale."""
Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
`None` corresponds to min/max of what can be expressed by `type`.
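For example (made-up values), uint8 data covering the full representable range:

```python
from bioimageio.spec.model.v0_5 import IntervalOrRatioDataDescr

data = IntervalOrRatioDataDescr(type="uint8", range=(0, 255))
```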
processing base class
class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""
keyword arguments for `BinarizeDescr`
class BinarizeAlongAxisKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""
keyword arguments for `BinarizeDescr`
The `threshold` axis
907class BinarizeDescr(ProcessingDescrBase): 908 """Binarize the tensor with a fixed threshold. 909 910 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 911 will be set to one, values below the threshold to zero. 912 913 Examples: 914 - in YAML 915 ```yaml 916 postprocessing: 917 - id: binarize 918 kwargs: 919 axis: 'channel' 920 threshold: [0.25, 0.5, 0.75] 921 ``` 922 - in Python: 923 >>> postprocessing = [BinarizeDescr( 924 ... kwargs=BinarizeAlongAxisKwargs( 925 ... axis=AxisId('channel'), 926 ... threshold=[0.25, 0.5, 0.75], 927 ... ) 928 ... )] 929 """ 930 931 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 932 if TYPE_CHECKING: 933 id: Literal["binarize"] = "binarize" 934 else: 935 id: Literal["binarize"] 936 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
Binarize the tensor with a fixed threshold.
Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
will be set to one, values below the threshold to zero.
Examples:
- in YAML
  ```yaml
  postprocessing:
  - id: binarize
    kwargs:
      axis: 'channel'
      threshold: [0.25, 0.5, 0.75]
  ```
- in Python:
  >>> postprocessing = [BinarizeDescr(
  ...     kwargs=BinarizeAlongAxisKwargs(
  ...         axis=AxisId('channel'),
  ...         threshold=[0.25, 0.5, 0.75],
  ...     )
  ... )]
939class ClipDescr(ProcessingDescrBase): 940 """Set tensor values below min to min and above max to max. 941 942 See `ScaleRangeDescr` for examples. 943 """ 944 945 implemented_id: ClassVar[Literal["clip"]] = "clip" 946 if TYPE_CHECKING: 947 id: Literal["clip"] = "clip" 948 else: 949 id: Literal["clip"] 950 951 kwargs: ClipKwargs
Set tensor values below min to min and above max to max.
See `ScaleRangeDescr` for examples.
954class EnsureDtypeKwargs(ProcessingKwargs): 955 """key word arguments for `EnsureDtypeDescr`""" 956 957 dtype: Literal[ 958 "float32", 959 "float64", 960 "uint8", 961 "int8", 962 "uint16", 963 "int16", 964 "uint32", 965 "int32", 966 "uint64", 967 "int64", 968 "bool", 969 ]
keyword arguments for `EnsureDtypeDescr`
972class EnsureDtypeDescr(ProcessingDescrBase): 973 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 974 975 This can for example be used to ensure the inner neural network model gets a 976 different input tensor data type than the fully described bioimage.io model does. 977 978 Examples: 979 The described bioimage.io model (incl. preprocessing) accepts any 980 float32-compatible tensor, normalizes it with percentiles and clipping and then 981 casts it to uint8, which is what the neural network in this example expects. 982 - in YAML 983 ```yaml 984 inputs: 985 - data: 986 type: float32 # described bioimage.io model is compatible with any float32 input tensor 987 preprocessing: 988 - id: scale_range 989 kwargs: 990 axes: ['y', 'x'] 991 max_percentile: 99.8 992 min_percentile: 5.0 993 - id: clip 994 kwargs: 995 min: 0.0 996 max: 1.0 997 - id: ensure_dtype # the neural network of the model requires uint8 998 kwargs: 999 dtype: uint8 1000 ``` 1001 - in Python: 1002 >>> preprocessing = [ 1003 ... ScaleRangeDescr( 1004 ... kwargs=ScaleRangeKwargs( 1005 ... axes= (AxisId('y'), AxisId('x')), 1006 ... max_percentile= 99.8, 1007 ... min_percentile= 5.0, 1008 ... ) 1009 ... ), 1010 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 1011 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 1012 ... ] 1013 """ 1014 1015 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 1016 if TYPE_CHECKING: 1017 id: Literal["ensure_dtype"] = "ensure_dtype" 1018 else: 1019 id: Literal["ensure_dtype"] 1020 1021 kwargs: EnsureDtypeKwargs
Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
This can for example be used to ensure the inner neural network model gets a different input tensor data type than the fully described bioimage.io model does.
Examples:
The described bioimage.io model (incl. preprocessing) accepts any float32-compatible tensor, normalizes it with percentiles and clipping and then casts it to uint8, which is what the neural network in this example expects.
- in YAML
  ```yaml
  inputs:
  - data:
      type: float32  # described bioimage.io model is compatible with any float32 input tensor
    preprocessing:
    - id: scale_range
      kwargs:
        axes: ['y', 'x']
        max_percentile: 99.8
        min_percentile: 5.0
    - id: clip
      kwargs:
        min: 0.0
        max: 1.0
    - id: ensure_dtype  # the neural network of the model requires uint8
      kwargs:
        dtype: uint8
  ```
- in Python:
  >>> preprocessing = [
  ...     ScaleRangeDescr(
  ...         kwargs=ScaleRangeKwargs(
  ...             axes=(AxisId('y'), AxisId('x')),
  ...             max_percentile=99.8,
  ...             min_percentile=5.0,
  ...         )
  ...     ),
  ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
  ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
  ... ]
1024class ScaleLinearKwargs(ProcessingKwargs): 1025 """Key word arguments for `ScaleLinearDescr`""" 1026 1027 gain: float = 1.0 1028 """multiplicative factor""" 1029 1030 offset: float = 0.0 1031 """additive term""" 1032 1033 @model_validator(mode="after") 1034 def _validate(self) -> Self: 1035 if self.gain == 1.0 and self.offset == 0.0: 1036 raise ValueError( 1037 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1038 + " != 0.0." 1039 ) 1040 1041 return self
Keyword arguments for `ScaleLinearDescr`
1044class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 1045 """Key word arguments for `ScaleLinearDescr`""" 1046 1047 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 1048 """The axis of gain and offset values.""" 1049 1050 gain: Union[float, NotEmpty[List[float]]] = 1.0 1051 """multiplicative factor""" 1052 1053 offset: Union[float, NotEmpty[List[float]]] = 0.0 1054 """additive term""" 1055 1056 @model_validator(mode="after") 1057 def _validate(self) -> Self: 1058 1059 if isinstance(self.gain, list): 1060 if isinstance(self.offset, list): 1061 if len(self.gain) != len(self.offset): 1062 raise ValueError( 1063 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 1064 ) 1065 else: 1066 self.offset = [float(self.offset)] * len(self.gain) 1067 elif isinstance(self.offset, list): 1068 self.gain = [float(self.gain)] * len(self.offset) 1069 else: 1070 raise ValueError( 1071 "Do not specify an `axis` for scalar gain and offset values." 1072 ) 1073 1074 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 1075 raise ValueError( 1076 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 1077 + " != 0.0." 1078 ) 1079 1080 return self
Keyword arguments for `ScaleLinearDescr`
The axis of gain and offset values.
1083class ScaleLinearDescr(ProcessingDescrBase): 1084 """Fixed linear scaling. 1085 1086 Examples: 1087 1. Scale with scalar gain and offset 1088 - in YAML 1089 ```yaml 1090 preprocessing: 1091 - id: scale_linear 1092 kwargs: 1093 gain: 2.0 1094 offset: 3.0 1095 ``` 1096 - in Python: 1097 >>> preprocessing = [ 1098 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 1099 ... ] 1100 1101 2. Independent scaling along an axis 1102 - in YAML 1103 ```yaml 1104 preprocessing: 1105 - id: scale_linear 1106 kwargs: 1107 axis: 'channel' 1108 gain: [1.0, 2.0, 3.0] 1109 ``` 1110 - in Python: 1111 >>> preprocessing = [ 1112 ... ScaleLinearDescr( 1113 ... kwargs=ScaleLinearAlongAxisKwargs( 1114 ... axis=AxisId("channel"), 1115 ... gain=[1.0, 2.0, 3.0], 1116 ... ) 1117 ... ) 1118 ... ] 1119 1120 """ 1121 1122 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 1123 if TYPE_CHECKING: 1124 id: Literal["scale_linear"] = "scale_linear" 1125 else: 1126 id: Literal["scale_linear"] 1127 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
Fixed linear scaling.
Examples:
1. Scale with scalar gain and offset
   - in YAML
     ```yaml
     preprocessing:
     - id: scale_linear
       kwargs:
         gain: 2.0
         offset: 3.0
     ```
   - in Python:
     >>> preprocessing = [
     ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
     ... ]
2. Independent scaling along an axis
   - in YAML
     ```yaml
     preprocessing:
     - id: scale_linear
       kwargs:
         axis: 'channel'
         gain: [1.0, 2.0, 3.0]
     ```
   - in Python:
     >>> preprocessing = [
     ...     ScaleLinearDescr(
     ...         kwargs=ScaleLinearAlongAxisKwargs(
     ...             axis=AxisId("channel"),
     ...             gain=[1.0, 2.0, 3.0],
     ...         )
     ...     )
     ... ]
1130class SigmoidDescr(ProcessingDescrBase): 1131 """The logistic sigmoid funciton, a.k.a. expit function. 1132 1133 Examples: 1134 - in YAML 1135 ```yaml 1136 postprocessing: 1137 - id: sigmoid 1138 ``` 1139 - in Python: 1140 >>> postprocessing = [SigmoidDescr()] 1141 """ 1142 1143 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 1144 if TYPE_CHECKING: 1145 id: Literal["sigmoid"] = "sigmoid" 1146 else: 1147 id: Literal["sigmoid"] 1148 1149 @property 1150 def kwargs(self) -> ProcessingKwargs: 1151 """empty kwargs""" 1152 return ProcessingKwargs()
The logistic sigmoid function, a.k.a. expit function.
Examples:
- in YAML
  ```yaml
  postprocessing:
  - id: sigmoid
  ```
- in Python:
>>> postprocessing = [SigmoidDescr()]
@property
def kwargs(self) -> ProcessingKwargs:
    """empty kwargs"""
    return ProcessingKwargs()
empty kwargs
1155class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1156 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1157 1158 mean: float 1159 """The mean value to normalize with.""" 1160 1161 std: Annotated[float, Ge(1e-6)] 1162 """The standard deviation value to normalize with."""
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
1165class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 1166 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 1167 1168 mean: NotEmpty[List[float]] 1169 """The mean value(s) to normalize with.""" 1170 1171 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 1172 """The standard deviation value(s) to normalize with. 1173 Size must match `mean` values.""" 1174 1175 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 1176 """The axis of the mean/std values to normalize each entry along that dimension 1177 separately.""" 1178 1179 @model_validator(mode="after") 1180 def _mean_and_std_match(self) -> Self: 1181 if len(self.mean) != len(self.std): 1182 raise ValueError( 1183 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 1184 + " must match." 1185 ) 1186 1187 return self
keyword arguments for `FixedZeroMeanUnitVarianceDescr`
The standard deviation value(s) to normalize with.
Size must match `mean` values.
The axis of the mean/std values to normalize each entry along that dimension separately.
1190class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1191 """Subtract a given mean and divide by the standard deviation. 1192 1193 Normalize with fixed, precomputed values for 1194 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 1195 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 1196 axes. 1197 1198 Examples: 1199 1. scalar value for whole tensor 1200 - in YAML 1201 ```yaml 1202 preprocessing: 1203 - id: fixed_zero_mean_unit_variance 1204 kwargs: 1205 mean: 103.5 1206 std: 13.7 1207 ``` 1208 - in Python 1209 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1210 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 1211 ... )] 1212 1213 2. independently along an axis 1214 - in YAML 1215 ```yaml 1216 preprocessing: 1217 - id: fixed_zero_mean_unit_variance 1218 kwargs: 1219 axis: channel 1220 mean: [101.5, 102.5, 103.5] 1221 std: [11.7, 12.7, 13.7] 1222 ``` 1223 - in Python 1224 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 1225 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 1226 ... axis=AxisId("channel"), 1227 ... mean=[101.5, 102.5, 103.5], 1228 ... std=[11.7, 12.7, 13.7], 1229 ... ) 1230 ... )] 1231 """ 1232 1233 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 1234 "fixed_zero_mean_unit_variance" 1235 ) 1236 if TYPE_CHECKING: 1237 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 1238 else: 1239 id: Literal["fixed_zero_mean_unit_variance"] 1240 1241 kwargs: Union[ 1242 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 1243 ]
Subtract a given mean and divide by the standard deviation.
Normalize with fixed, precomputed values for `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given axes.
Examples:
1. scalar value for whole tensor
   - in YAML
     ```yaml
     preprocessing:
     - id: fixed_zero_mean_unit_variance
       kwargs:
         mean: 103.5
         std: 13.7
     ```
   - in Python
     >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
     ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
     ... )]
2. independently along an axis
   - in YAML
     ```yaml
     preprocessing:
     - id: fixed_zero_mean_unit_variance
       kwargs:
         axis: channel
         mean: [101.5, 102.5, 103.5]
         std: [11.7, 12.7, 13.7]
     ```
   - in Python
     >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
     ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
     ...         axis=AxisId("channel"),
     ...         mean=[101.5, 102.5, 103.5],
     ...         std=[11.7, 12.7, 13.7],
     ...     )
     ... )]
1246class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 1247 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 1248 1249 axes: Annotated[ 1250 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1251 ] = None 1252 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1253 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1254 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1255 To normalize each sample independently leave out the 'batch' axis. 1256 Default: Scale all axes jointly.""" 1257 1258 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1259 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
keyword arguments for `ZeroMeanUnitVarianceDescr`
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize each sample independently leave out the 'batch' axis.
Default: Scale all axes jointly.
epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.
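A minimal numpy sketch of the normalization described above (not the library's implementation, just the formula, with made-up data):

```python
import numpy as np

# normalize a ('batch', 'channel', 'y', 'x') tensor jointly over batch, y and x,
# i.e. per channel
tensor = np.random.rand(2, 3, 64, 64).astype("float32")
axes = (0, 2, 3)  # 'batch', 'y', 'x'
mean = tensor.mean(axis=axes, keepdims=True)
std = tensor.std(axis=axes, keepdims=True)
eps = 1e-6
out = (tensor - mean) / (std + eps)
```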
1262class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 1263 """Subtract mean and divide by variance. 1264 1265 Examples: 1266 Subtract tensor mean and variance 1267 - in YAML 1268 ```yaml 1269 preprocessing: 1270 - id: zero_mean_unit_variance 1271 ``` 1272 - in Python 1273 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 1274 """ 1275 1276 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 1277 "zero_mean_unit_variance" 1278 ) 1279 if TYPE_CHECKING: 1280 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 1281 else: 1282 id: Literal["zero_mean_unit_variance"] 1283 1284 kwargs: ZeroMeanUnitVarianceKwargs = Field( 1285 default_factory=ZeroMeanUnitVarianceKwargs 1286 )
Subtract mean and divide by variance.
Examples:
Subtract tensor mean and variance
- in YAML
  ```yaml
  preprocessing:
  - id: zero_mean_unit_variance
  ```
- in Python
>>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1289class ScaleRangeKwargs(ProcessingKwargs): 1290 """key word arguments for `ScaleRangeDescr` 1291 1292 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 1293 this processing step normalizes data to the [0, 1] intervall. 1294 For other percentiles the normalized values will partially be outside the [0, 1] 1295 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 1296 normalized values to a range. 1297 """ 1298 1299 axes: Annotated[ 1300 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1301 ] = None 1302 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 1303 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1304 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1305 To normalize samples independently, leave out the "batch" axis. 1306 Default: Scale all axes jointly.""" 1307 1308 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 1309 """The lower percentile used to determine the value to align with zero.""" 1310 1311 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 1312 """The upper percentile used to determine the value to align with one. 1313 Has to be bigger than `min_percentile`. 1314 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 1315 accepting percentiles specified in the range 0.0 to 1.0.""" 1316 1317 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1318 """Epsilon for numeric stability. 1319 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 1320 with `v_lower,v_upper` values at the respective percentiles.""" 1321 1322 reference_tensor: Optional[TensorId] = None 1323 """Tensor ID to compute the percentiles from. Default: The tensor itself. 1324 For any tensor in `inputs` only input tensor references are allowed.""" 1325 1326 @field_validator("max_percentile", mode="after") 1327 @classmethod 1328 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 1329 if (min_p := info.data["min_percentile"]) >= value: 1330 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 1331 1332 return value
keyword arguments for `ScaleRangeDescr`
For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
this processing step normalizes data to the [0, 1] interval.
For other percentiles the normalized values will partially be outside the [0, 1]
interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
normalized values to a range.
The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the "batch" axis.
Default: Scale all axes jointly.
The lower percentile used to determine the value to align with zero.
The upper percentile used to determine the value to align with one.
Has to be bigger than `min_percentile`.
The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
accepting percentiles specified in the range 0.0 to 1.0.
Epsilon for numeric stability.
`out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
with `v_lower,v_upper` values at the respective percentiles.
Tensor ID to compute the percentiles from. Default: The tensor itself.
For any tensor in `inputs` only input tensor references are allowed.
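A minimal numpy sketch of the percentile formula above (illustrative only, not the library's implementation): map the 5th percentile to 0 and the 99.8th percentile to 1.

```python
import numpy as np

tensor = np.random.rand(1, 1, 64, 64).astype("float32")
v_lower = np.percentile(tensor, 5.0)
v_upper = np.percentile(tensor, 99.8)
eps = 1e-6
out = (tensor - v_lower) / (v_upper - v_lower + eps)
```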
@field_validator("max_percentile", mode="after")
@classmethod
def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
    if (min_p := info.data["min_percentile"]) >= value:
        raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

    return value
1335class ScaleRangeDescr(ProcessingDescrBase): 1336 """Scale with percentiles. 1337 1338 Examples: 1339 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 1340 - in YAML 1341 ```yaml 1342 preprocessing: 1343 - id: scale_range 1344 kwargs: 1345 axes: ['y', 'x'] 1346 max_percentile: 99.8 1347 min_percentile: 5.0 1348 ``` 1349 - in Python 1350 >>> preprocessing = [ 1351 ... ScaleRangeDescr( 1352 ... kwargs=ScaleRangeKwargs( 1353 ... axes= (AxisId('y'), AxisId('x')), 1354 ... max_percentile= 99.8, 1355 ... min_percentile= 5.0, 1356 ... ) 1357 ... ), 1358 ... ClipDescr( 1359 ... kwargs=ClipKwargs( 1360 ... min=0.0, 1361 ... max=1.0, 1362 ... ) 1363 ... ), 1364 ... ] 1365 1366 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 1367 - in YAML 1368 ```yaml 1369 preprocessing: 1370 - id: scale_range 1371 kwargs: 1372 axes: ['y', 'x'] 1373 max_percentile: 99.8 1374 min_percentile: 5.0 1375 - id: scale_range 1376 - id: clip 1377 kwargs: 1378 min: 0.0 1379 max: 1.0 1380 ``` 1381 - in Python 1382 >>> preprocessing = [ScaleRangeDescr( 1383 ... kwargs=ScaleRangeKwargs( 1384 ... axes= (AxisId('y'), AxisId('x')), 1385 ... max_percentile= 99.8, 1386 ... min_percentile= 5.0, 1387 ... ) 1388 ... )] 1389 1390 """ 1391 1392 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 1393 if TYPE_CHECKING: 1394 id: Literal["scale_range"] = "scale_range" 1395 else: 1396 id: Literal["scale_range"] 1397 kwargs: ScaleRangeKwargs
Scale with percentiles.
Examples:
1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
   - in YAML
     ```yaml
     preprocessing:
     - id: scale_range
       kwargs:
         axes: ['y', 'x']
         max_percentile: 99.8
         min_percentile: 5.0
     ```
   - in Python
     >>> preprocessing = [
     ...     ScaleRangeDescr(
     ...         kwargs=ScaleRangeKwargs(
     ...             axes=(AxisId('y'), AxisId('x')),
     ...             max_percentile=99.8,
     ...             min_percentile=5.0,
     ...         )
     ...     ),
     ...     ClipDescr(
     ...         kwargs=ClipKwargs(
     ...             min=0.0,
     ...             max=1.0,
     ...         )
     ...     ),
     ... ]
2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
   - in YAML
     ```yaml
     preprocessing:
     - id: scale_range
       kwargs:
         axes: ['y', 'x']
         max_percentile: 99.8
         min_percentile: 5.0
     - id: clip
       kwargs:
         min: 0.0
         max: 1.0
     ```
   - in Python
     >>> preprocessing = [ScaleRangeDescr(
     ...     kwargs=ScaleRangeKwargs(
     ...         axes=(AxisId('y'), AxisId('x')),
     ...         max_percentile=99.8,
     ...         min_percentile=5.0,
     ...     )
     ... )]
1400class ScaleMeanVarianceKwargs(ProcessingKwargs): 1401 """key word arguments for `ScaleMeanVarianceKwargs`""" 1402 1403 reference_tensor: TensorId 1404 """Name of tensor to match.""" 1405 1406 axes: Annotated[ 1407 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 1408 ] = None 1409 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 1410 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 1411 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 1412 To normalize samples independently, leave out the 'batch' axis. 1413 Default: Scale all axes jointly.""" 1414 1415 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 1416 """Epsilon for numeric stability: 1417 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
keyword arguments for `ScaleMeanVarianceDescr`
Name of tensor to match.
The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
To normalize samples independently, leave out the 'batch' axis.
Default: Scale all axes jointly.
Epsilon for numeric stability:
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`.
1420class ScaleMeanVarianceDescr(ProcessingDescrBase): 1421 """Scale a tensor's data distribution to match another tensor's mean/std. 1422 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 1423 """ 1424 1425 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 1426 if TYPE_CHECKING: 1427 id: Literal["scale_mean_variance"] = "scale_mean_variance" 1428 else: 1429 id: Literal["scale_mean_variance"] 1430 kwargs: ScaleMeanVarianceKwargs
Scale a tensor's data distribution to match another tensor's mean/std.
`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`
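A minimal numpy sketch of this formula (illustrative only, with made-up data): match a tensor's mean/std to those of a reference tensor.

```python
import numpy as np

tensor = np.random.rand(1, 64, 64)
reference = 5.0 + 2.0 * np.random.rand(1, 64, 64)
eps = 1e-6
out = (tensor - tensor.mean()) / (tensor.std() + eps) * (reference.std() + eps) + reference.mean()
```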
1464class TensorDescrBase(Node, Generic[IO_AxisT]): 1465 id: TensorId 1466 """Tensor id. No duplicates are allowed.""" 1467 1468 description: Annotated[str, MaxLen(128)] = "" 1469 """free text description""" 1470 1471 axes: NotEmpty[Sequence[IO_AxisT]] 1472 """tensor axes""" 1473 1474 @property 1475 def shape(self): 1476 return tuple(a.size for a in self.axes) 1477 1478 @field_validator("axes", mode="after", check_fields=False) 1479 @classmethod 1480 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 1481 batch_axes = [a for a in axes if a.type == "batch"] 1482 if len(batch_axes) > 1: 1483 raise ValueError( 1484 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 1485 ) 1486 1487 seen_ids: Set[AxisId] = set() 1488 duplicate_axes_ids: Set[AxisId] = set() 1489 for a in axes: 1490 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 1491 1492 if duplicate_axes_ids: 1493 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 1494 1495 return axes 1496 1497 test_tensor: FileDescr_ 1498 """An example tensor to use for testing. 1499 Using the model with the test input tensors is expected to yield the test output tensors. 1500 Each test tensor has be a an ndarray in the 1501 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 1502 The file extension must be '.npy'.""" 1503 1504 sample_tensor: Optional[FileDescr_] = None 1505 """A sample tensor to illustrate a possible input/output for the model, 1506 The sample image primarily serves to inform a human user about an example use case 1507 and is typically stored as .hdf5, .png or .tiff. 1508 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 1509 (numpy's `.npy` format is not supported). 1510 The image dimensionality has to match the number of axes specified in this tensor description. 1511 """ 1512 1513 @model_validator(mode="after") 1514 def _validate_sample_tensor(self) -> Self: 1515 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 1516 return self 1517 1518 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 1519 tensor: NDArray[Any] = imread( 1520 reader.read(), 1521 extension=PurePosixPath(reader.original_file_name).suffix, 1522 ) 1523 n_dims = len(tensor.squeeze().shape) 1524 n_dims_min = n_dims_max = len(self.axes) 1525 1526 for a in self.axes: 1527 if isinstance(a, BatchAxis): 1528 n_dims_min -= 1 1529 elif isinstance(a.size, int): 1530 if a.size == 1: 1531 n_dims_min -= 1 1532 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 1533 if a.size.min == 1: 1534 n_dims_min -= 1 1535 elif isinstance(a.size, SizeReference): 1536 if a.size.offset < 2: 1537 # size reference may result in singleton axis 1538 n_dims_min -= 1 1539 else: 1540 assert_never(a.size) 1541 1542 n_dims_min = max(0, n_dims_min) 1543 if n_dims < n_dims_min or n_dims > n_dims_max: 1544 raise ValueError( 1545 f"Expected sample tensor to have {n_dims_min} to" 1546 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 1547 ) 1548 1549 return self 1550 1551 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 1552 IntervalOrRatioDataDescr() 1553 ) 1554 """Description of the tensor's data values, optionally per channel. 
1555 If specified per channel, the data `type` needs to match across channels.""" 1556 1557 @property 1558 def dtype( 1559 self, 1560 ) -> Literal[ 1561 "float32", 1562 "float64", 1563 "uint8", 1564 "int8", 1565 "uint16", 1566 "int16", 1567 "uint32", 1568 "int32", 1569 "uint64", 1570 "int64", 1571 "bool", 1572 ]: 1573 """dtype as specified under `data.type` or `data[i].type`""" 1574 if isinstance(self.data, collections.abc.Sequence): 1575 return self.data[0].type 1576 else: 1577 return self.data.type 1578 1579 @field_validator("data", mode="after") 1580 @classmethod 1581 def _check_data_type_across_channels( 1582 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 1583 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 1584 if not isinstance(value, list): 1585 return value 1586 1587 dtypes = {t.type for t in value} 1588 if len(dtypes) > 1: 1589 raise ValueError( 1590 "Tensor data descriptions per channel need to agree in their data" 1591 + f" `type`, but found {dtypes}." 1592 ) 1593 1594 return value 1595 1596 @model_validator(mode="after") 1597 def _check_data_matches_channelaxis(self) -> Self: 1598 if not isinstance(self.data, (list, tuple)): 1599 return self 1600 1601 for a in self.axes: 1602 if isinstance(a, ChannelAxis): 1603 size = a.size 1604 assert isinstance(size, int) 1605 break 1606 else: 1607 return self 1608 1609 if len(self.data) != size: 1610 raise ValueError( 1611 f"Got tensor data descriptions for {len(self.data)} channels, but" 1612 + f" '{a.id}' axis has size {size}." 1613 ) 1614 1615 return self 1616 1617 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 1618 if len(array.shape) != len(self.axes): 1619 raise ValueError( 1620 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 1621 + f" incompatible with {len(self.axes)} axes." 1622 ) 1623 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
An example tensor to use for testing. Using the model with the test input tensors is expected to yield the test output tensors. Each test tensor has to be an ndarray in the numpy.lib file format. The file extension must be '.npy'.
A sample tensor to illustrate a possible input/output for the model.
The sample image primarily serves to inform a human user about an example use case
and is typically stored as .hdf5, .png or .tiff.
It has to be readable by the imageio library
(numpy's `.npy` format is not supported).
The image dimensionality has to match the number of axes specified in this tensor description.
Description of the tensor's data values, optionally per channel.
If specified per channel, the data `type` needs to match across channels.
1557 @property 1558 def dtype( 1559 self, 1560 ) -> Literal[ 1561 "float32", 1562 "float64", 1563 "uint8", 1564 "int8", 1565 "uint16", 1566 "int16", 1567 "uint32", 1568 "int32", 1569 "uint64", 1570 "int64", 1571 "bool", 1572 ]: 1573 """dtype as specified under `data.type` or `data[i].type`""" 1574 if isinstance(self.data, collections.abc.Sequence): 1575 return self.data[0].type 1576 else: 1577 return self.data.type
dtype as specified under `data.type` or `data[i].type`
def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
    if len(array.shape) != len(self.axes):
        raise ValueError(
            f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
            + f" incompatible with {len(self.axes)} axes."
        )
    return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1626class InputTensorDescr(TensorDescrBase[InputAxis]): 1627 id: TensorId = TensorId("input") 1628 """Input tensor id. 1629 No duplicates are allowed across all inputs and outputs.""" 1630 1631 optional: bool = False 1632 """indicates that this tensor may be `None`""" 1633 1634 preprocessing: List[PreprocessingDescr] = Field( 1635 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 1636 ) 1637 1638 """Description of how this input should be preprocessed. 1639 1640 notes: 1641 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 1642 to ensure an input tensor's data type matches the input tensor's data description. 1643 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 1644 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 1645 changing the data type. 1646 """ 1647 1648 @model_validator(mode="after") 1649 def _validate_preprocessing_kwargs(self) -> Self: 1650 axes_ids = [a.id for a in self.axes] 1651 for p in self.preprocessing: 1652 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1653 if kwargs_axes is None: 1654 continue 1655 1656 if not isinstance(kwargs_axes, collections.abc.Sequence): 1657 raise ValueError( 1658 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 1659 ) 1660 1661 if any(a not in axes_ids for a in kwargs_axes): 1662 raise ValueError( 1663 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 1664 ) 1665 1666 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1667 dtype = self.data.type 1668 else: 1669 dtype = self.data[0].type 1670 1671 # ensure `preprocessing` begins with `EnsureDtypeDescr` 1672 if not self.preprocessing or not isinstance( 1673 self.preprocessing[0], EnsureDtypeDescr 1674 ): 1675 self.preprocessing.insert( 1676 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1677 ) 1678 1679 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1680 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 1681 self.preprocessing.append( 1682 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 1683 ) 1684 1685 return self
Description of how this input should be preprocessed.
notes:
- If preprocessing does not start with an 'ensure_dtype' entry, it is added to ensure an input tensor's data type matches the input tensor's data description.
- If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally changing the data type.
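For illustration, a minimal sketch of the explicit preprocessing chain these notes describe (dtype and steps are made-up values); writing it out by hand is equivalent to what the validator inserts automatically:

```python
from bioimageio.spec.model.v0_5 import (
    AxisId, EnsureDtypeDescr, EnsureDtypeKwargs, ScaleRangeDescr, ScaleRangeKwargs
)

preprocessing = [
    # leading ensure_dtype: match the tensor's data description
    EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")),
    ScaleRangeDescr(
        kwargs=ScaleRangeKwargs(
            axes=(AxisId("y"), AxisId("x")),
            max_percentile=99.8,
            min_percentile=5.0,
        )
    ),
    # trailing ensure_dtype: keep the data type from drifting
    EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")),
]
```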
1688def convert_axes( 1689 axes: str, 1690 *, 1691 shape: Union[ 1692 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 1693 ], 1694 tensor_type: Literal["input", "output"], 1695 halo: Optional[Sequence[int]], 1696 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 1697): 1698 ret: List[AnyAxis] = [] 1699 for i, a in enumerate(axes): 1700 axis_type = _AXIS_TYPE_MAP.get(a, a) 1701 if axis_type == "batch": 1702 ret.append(BatchAxis()) 1703 continue 1704 1705 scale = 1.0 1706 if isinstance(shape, _ParameterizedInputShape_v0_4): 1707 if shape.step[i] == 0: 1708 size = shape.min[i] 1709 else: 1710 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 1711 elif isinstance(shape, _ImplicitOutputShape_v0_4): 1712 ref_t = str(shape.reference_tensor) 1713 if ref_t.count(".") == 1: 1714 t_id, orig_a_id = ref_t.split(".") 1715 else: 1716 t_id = ref_t 1717 orig_a_id = a 1718 1719 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 1720 if not (orig_scale := shape.scale[i]): 1721 # old way to insert a new axis dimension 1722 size = int(2 * shape.offset[i]) 1723 else: 1724 scale = 1 / orig_scale 1725 if axis_type in ("channel", "index"): 1726 # these axes no longer have a scale 1727 offset_from_scale = orig_scale * size_refs.get( 1728 _TensorName_v0_4(t_id), {} 1729 ).get(orig_a_id, 0) 1730 else: 1731 offset_from_scale = 0 1732 size = SizeReference( 1733 tensor_id=TensorId(t_id), 1734 axis_id=AxisId(a_id), 1735 offset=int(offset_from_scale + 2 * shape.offset[i]), 1736 ) 1737 else: 1738 size = shape[i] 1739 1740 if axis_type == "time": 1741 if tensor_type == "input": 1742 ret.append(TimeInputAxis(size=size, scale=scale)) 1743 else: 1744 assert not isinstance(size, ParameterizedSize) 1745 if halo is None: 1746 ret.append(TimeOutputAxis(size=size, scale=scale)) 1747 else: 1748 assert not isinstance(size, int) 1749 ret.append( 1750 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 1751 ) 1752 1753 elif axis_type == "index": 1754 if tensor_type == "input": 1755 ret.append(IndexInputAxis(size=size)) 1756 else: 1757 if isinstance(size, ParameterizedSize): 1758 size = DataDependentSize(min=size.min) 1759 1760 ret.append(IndexOutputAxis(size=size)) 1761 elif axis_type == "channel": 1762 assert not isinstance(size, ParameterizedSize) 1763 if isinstance(size, SizeReference): 1764 warnings.warn( 1765 "Conversion of channel size from an implicit output shape may be" 1766 + " wrong" 1767 ) 1768 ret.append( 1769 ChannelAxis( 1770 channel_names=[ 1771 Identifier(f"channel{i}") for i in range(size.offset) 1772 ] 1773 ) 1774 ) 1775 else: 1776 ret.append( 1777 ChannelAxis( 1778 channel_names=[Identifier(f"channel{i}") for i in range(size)] 1779 ) 1780 ) 1781 elif axis_type == "space": 1782 if tensor_type == "input": 1783 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 1784 else: 1785 assert not isinstance(size, ParameterizedSize) 1786 if halo is None or halo[i] == 0: 1787 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 1788 elif isinstance(size, int): 1789 raise NotImplementedError( 1790 f"output axis with halo and fixed size (here {size}) not allowed" 1791 ) 1792 else: 1793 ret.append( 1794 SpaceOutputAxisWithHalo( 1795 id=AxisId(a), size=size, scale=scale, halo=halo[i] 1796 ) 1797 ) 1798 1799 return ret
1959class OutputTensorDescr(TensorDescrBase[OutputAxis]): 1960 id: TensorId = TensorId("output") 1961 """Output tensor id. 1962 No duplicates are allowed across all inputs and outputs.""" 1963 1964 postprocessing: List[PostprocessingDescr] = Field( 1965 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 1966 ) 1967 """Description of how this output should be postprocessed. 1968 1969 note: `postprocessing` always ends with an 'ensure_dtype' operation. 1970 If not given this is added to cast to this tensor's `data.type`. 1971 """ 1972 1973 @model_validator(mode="after") 1974 def _validate_postprocessing_kwargs(self) -> Self: 1975 axes_ids = [a.id for a in self.axes] 1976 for p in self.postprocessing: 1977 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 1978 if kwargs_axes is None: 1979 continue 1980 1981 if not isinstance(kwargs_axes, collections.abc.Sequence): 1982 raise ValueError( 1983 f"expected `axes` sequence, but got {type(kwargs_axes)}" 1984 ) 1985 1986 if any(a not in axes_ids for a in kwargs_axes): 1987 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 1988 1989 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 1990 dtype = self.data.type 1991 else: 1992 dtype = self.data[0].type 1993 1994 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 1995 if not self.postprocessing or not isinstance( 1996 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 1997 ): 1998 self.postprocessing.append( 1999 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 2000 ) 2001 return self
Description of how this output should be postprocessed.
note: `postprocessing` always ends with an 'ensure_dtype' operation.
If not given this is added to cast to this tensor's `data.type`.
2051def validate_tensors( 2052 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 2053 tensor_origin: Literal[ 2054 "test_tensor" 2055 ], # for more precise error messages, e.g. 'test_tensor' 2056): 2057 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 2058 2059 def e_msg(d: TensorDescr): 2060 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 2061 2062 for descr, array in tensors.values(): 2063 try: 2064 axis_sizes = descr.get_axis_sizes_for_array(array) 2065 except ValueError as e: 2066 raise ValueError(f"{e_msg(descr)} {e}") 2067 else: 2068 all_tensor_axes[descr.id] = { 2069 a.id: (a, axis_sizes[a.id]) for a in descr.axes 2070 } 2071 2072 for descr, array in tensors.values(): 2073 if descr.dtype in ("float32", "float64"): 2074 invalid_test_tensor_dtype = array.dtype.name not in ( 2075 "float32", 2076 "float64", 2077 "uint8", 2078 "int8", 2079 "uint16", 2080 "int16", 2081 "uint32", 2082 "int32", 2083 "uint64", 2084 "int64", 2085 ) 2086 else: 2087 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 2088 2089 if invalid_test_tensor_dtype: 2090 raise ValueError( 2091 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 2092 + f" match described dtype '{descr.dtype}'" 2093 ) 2094 2095 if array.min() > -1e-4 and array.max() < 1e-4: 2096 raise ValueError( 2097 "Output values are too small for reliable testing." 2098 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 2099 ) 2100 2101 for a in descr.axes: 2102 actual_size = all_tensor_axes[descr.id][a.id][1] 2103 if a.size is None: 2104 continue 2105 2106 if isinstance(a.size, int): 2107 if actual_size != a.size: 2108 raise ValueError( 2109 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 2110 + f"has incompatible size {actual_size}, expected {a.size}" 2111 ) 2112 elif isinstance(a.size, ParameterizedSize): 2113 _ = a.size.validate_size(actual_size) 2114 elif isinstance(a.size, DataDependentSize): 2115 _ = a.size.validate_size(actual_size) 2116 elif isinstance(a.size, SizeReference): 2117 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 2118 if ref_tensor_axes is None: 2119 raise ValueError( 2120 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 2121 + f" reference '{a.size.tensor_id}'" 2122 ) 2123 2124 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 2125 if ref_axis is None or ref_size is None: 2126 raise ValueError( 2127 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 2128 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 2129 ) 2130 2131 if a.unit != ref_axis.unit: 2132 raise ValueError( 2133 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 2134 + " axis and reference axis to have the same `unit`, but" 2135 + f" {a.unit}!={ref_axis.unit}" 2136 ) 2137 2138 if actual_size != ( 2139 expected_size := ( 2140 ref_size * ref_axis.scale / a.scale + a.size.offset 2141 ) 2142 ): 2143 raise ValueError( 2144 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 2145 + f" {actual_size} invalid for referenced size {ref_size};" 2146 + f" expected {expected_size}" 2147 ) 2148 else: 2149 assert_never(a.size)
2169class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 2170 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2171 """Architecture source file""" 2172 2173 @model_serializer(mode="wrap", when_used="unless-none") 2174 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2175 return package_file_descr_serializer(self, nxt, info)
A file description
Architecture source file
2178class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 2179 import_from: str 2180 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2240class WeightsEntryDescrBase(FileDescr): 2241 type: ClassVar[WeightsFormat] 2242 weights_format_name: ClassVar[str] # human readable 2243 2244 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2245 """Source of the weights file.""" 2246 2247 authors: Optional[List[Author]] = None 2248 """Authors 2249 Either the person(s) that have trained this model resulting in the original weights file. 2250 (If this is the initial weights entry, i.e. it does not have a `parent`) 2251 Or the person(s) who have converted the weights to this weights format. 2252 (If this is a child weight, i.e. it has a `parent` field) 2253 """ 2254 2255 parent: Annotated[ 2256 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 2257 ] = None 2258 """The source weights these weights were converted from. 2259 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 2260 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 2261 All weight entries except one (the initial set of weights resulting from training the model), 2262 need to have this field.""" 2263 2264 comment: str = "" 2265 """A comment about this weights entry, for example how these weights were created.""" 2266 2267 @model_validator(mode="after") 2268 def _validate(self) -> Self: 2269 if self.type == self.parent: 2270 raise ValueError("Weights entry can't be it's own parent.") 2271 2272 return self 2273 2274 @model_serializer(mode="wrap", when_used="unless-none") 2275 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 2276 return package_file_descr_serializer(self, nxt, info)
A file description
Source of the weights file.
The source weights these weights were converted from. For example, if a model's weights were converted from the pytorch_state_dict format to torchscript, the pytorch_state_dict weights entry has no parent and is the parent of the torchscript weights. All weight entries except one (the initial set of weights resulting from training the model) need to have this field.
2279class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 2280 type = "keras_hdf5" 2281 weights_format_name: ClassVar[str] = "Keras HDF5" 2282 tensorflow_version: Version 2283 """TensorFlow version used to create these weights."""
A file description
TensorFlow version used to create these weights.
2286class OnnxWeightsDescr(WeightsEntryDescrBase): 2287 type = "onnx" 2288 weights_format_name: ClassVar[str] = "ONNX" 2289 opset_version: Annotated[int, Ge(7)] 2290 """ONNX opset version"""
A file description
2293class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 2294 type = "pytorch_state_dict" 2295 weights_format_name: ClassVar[str] = "Pytorch State Dict" 2296 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 2297 pytorch_version: Version 2298 """Version of the PyTorch library used. 2299 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible. 2300 """ 2301 dependencies: Optional[FileDescr_dependencies] = None 2302 """Custom depencies beyond pytorch described in a Conda environment file. 2303 Allows to specify custom dependencies, see conda docs: 2304 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 2305 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 2306 2307 The conda environment file should include pytorch and any version pinning has to be compatible with 2308 **pytorch_version**. 2309 """
A file description
Version of the PyTorch library used.
If architecture.dependencies is specified it has to include pytorch and any version pinning has to be compatible.
Custom dependencies beyond pytorch, described in a Conda environment file. Allows specifying custom dependencies; see the conda docs on exporting an environment file across platforms (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) and creating an environment file manually (https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually).
The conda environment file should include pytorch and any version pinning has to be compatible with pytorch_version.
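A hedged sketch of a pytorch_state_dict weights entry; the file path, module path and the callable/import_from/kwargs values are made up for illustration, and io checks are skipped so the weights file does not need to exist:

from bioimageio.spec import ValidationContext
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromLibraryDescr,
    PytorchStateDictWeightsDescr,
)

with ValidationContext(perform_io_checks=False):  # skip resolving "weights.pt" here
    psd_entry = PytorchStateDictWeightsDescr(
        source="weights.pt",
        architecture=ArchitectureFromLibraryDescr(
            callable="UNet2d",                 # i.e. `from my_package.models import UNet2d`
            import_from="my_package.models",
            kwargs={"in_channels": 1, "out_channels": 1},
        ),
        pytorch_version="2.1",                 # coerced to a Version instance
    )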
2312class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 2313 type = "tensorflow_js" 2314 weights_format_name: ClassVar[str] = "Tensorflow.js" 2315 tensorflow_version: Version 2316 """Version of the TensorFlow library used.""" 2317 2318 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2319 """The multi-file weights. 2320 All required files/folders should be a zip archive."""
A file description
Version of the TensorFlow library used.
The multi-file weights. All required files/folders should be a zip archive.
2323class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 2324 type = "tensorflow_saved_model_bundle" 2325 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 2326 tensorflow_version: Version 2327 """Version of the TensorFlow library used.""" 2328 2329 dependencies: Optional[FileDescr_dependencies] = None 2330 """Custom dependencies beyond tensorflow. 2331 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 2332 2333 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 2334 """The multi-file weights. 2335 All required files/folders should be a zip archive."""
A file description
Version of the TensorFlow library used.
Custom dependencies beyond tensorflow. Should include tensorflow and any version pinning has to be compatible with tensorflow_version.
The multi-file weights. All required files/folders should be a zip archive.
2338class TorchscriptWeightsDescr(WeightsEntryDescrBase): 2339 type = "torchscript" 2340 weights_format_name: ClassVar[str] = "TorchScript" 2341 pytorch_version: Version 2342 """Version of the PyTorch library used."""
A file description
Version of the PyTorch library used.
2345class WeightsDescr(Node): 2346 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 2347 onnx: Optional[OnnxWeightsDescr] = None 2348 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 2349 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 2350 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 2351 None 2352 ) 2353 torchscript: Optional[TorchscriptWeightsDescr] = None 2354 2355 @model_validator(mode="after") 2356 def check_entries(self) -> Self: 2357 entries = {wtype for wtype, entry in self if entry is not None} 2358 2359 if not entries: 2360 raise ValueError("Missing weights entry") 2361 2362 entries_wo_parent = { 2363 wtype 2364 for wtype, entry in self 2365 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2366 } 2367 if len(entries_wo_parent) != 1: 2368 issue_warning( 2369 "Exactly one weights entry may not specify the `parent` field (got" 2370 + " {value}). That entry is considered the original set of model weights." 2371 + " Other weight formats are created through conversion of the orignal or" 2372 + " already converted weights. They have to reference the weights format" 2373 + " they were converted from as their `parent`.", 2374 value=len(entries_wo_parent), 2375 field="weights", 2376 ) 2377 2378 for wtype, entry in self: 2379 if entry is None: 2380 continue 2381 2382 assert hasattr(entry, "type") 2383 assert hasattr(entry, "parent") 2384 assert wtype == entry.type 2385 if ( 2386 entry.parent is not None and entry.parent not in entries 2387 ): # self reference checked for `parent` field 2388 raise ValueError( 2389 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2390 + f" formats: {entries}" 2391 ) 2392 2393 return self 2394 2395 def __getitem__( 2396 self, 2397 key: Literal[ 2398 "keras_hdf5", 2399 "onnx", 2400 "pytorch_state_dict", 2401 "tensorflow_js", 2402 "tensorflow_saved_model_bundle", 2403 "torchscript", 2404 ], 2405 ): 2406 if key == "keras_hdf5": 2407 ret = self.keras_hdf5 2408 elif key == "onnx": 2409 ret = self.onnx 2410 elif key == "pytorch_state_dict": 2411 ret = self.pytorch_state_dict 2412 elif key == "tensorflow_js": 2413 ret = self.tensorflow_js 2414 elif key == "tensorflow_saved_model_bundle": 2415 ret = self.tensorflow_saved_model_bundle 2416 elif key == "torchscript": 2417 ret = self.torchscript 2418 else: 2419 raise KeyError(key) 2420 2421 if ret is None: 2422 raise KeyError(key) 2423 2424 return ret 2425 2426 @property 2427 def available_formats(self): 2428 return { 2429 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2430 **({} if self.onnx is None else {"onnx": self.onnx}), 2431 **( 2432 {} 2433 if self.pytorch_state_dict is None 2434 else {"pytorch_state_dict": self.pytorch_state_dict} 2435 ), 2436 **( 2437 {} 2438 if self.tensorflow_js is None 2439 else {"tensorflow_js": self.tensorflow_js} 2440 ), 2441 **( 2442 {} 2443 if self.tensorflow_saved_model_bundle is None 2444 else { 2445 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2446 } 2447 ), 2448 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2449 } 2450 2451 @property 2452 def missing_formats(self): 2453 return { 2454 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 2455 }
2355 @model_validator(mode="after") 2356 def check_entries(self) -> Self: 2357 entries = {wtype for wtype, entry in self if entry is not None} 2358 2359 if not entries: 2360 raise ValueError("Missing weights entry") 2361 2362 entries_wo_parent = { 2363 wtype 2364 for wtype, entry in self 2365 if entry is not None and hasattr(entry, "parent") and entry.parent is None 2366 } 2367 if len(entries_wo_parent) != 1: 2368 issue_warning( 2369 "Exactly one weights entry may not specify the `parent` field (got" 2370 + " {value}). That entry is considered the original set of model weights." 2371 + " Other weight formats are created through conversion of the orignal or" 2372 + " already converted weights. They have to reference the weights format" 2373 + " they were converted from as their `parent`.", 2374 value=len(entries_wo_parent), 2375 field="weights", 2376 ) 2377 2378 for wtype, entry in self: 2379 if entry is None: 2380 continue 2381 2382 assert hasattr(entry, "type") 2383 assert hasattr(entry, "parent") 2384 assert wtype == entry.type 2385 if ( 2386 entry.parent is not None and entry.parent not in entries 2387 ): # self reference checked for `parent` field 2388 raise ValueError( 2389 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 2390 + f" formats: {entries}" 2391 ) 2392 2393 return self
2426 @property 2427 def available_formats(self): 2428 return { 2429 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 2430 **({} if self.onnx is None else {"onnx": self.onnx}), 2431 **( 2432 {} 2433 if self.pytorch_state_dict is None 2434 else {"pytorch_state_dict": self.pytorch_state_dict} 2435 ), 2436 **( 2437 {} 2438 if self.tensorflow_js is None 2439 else {"tensorflow_js": self.tensorflow_js} 2440 ), 2441 **( 2442 {} 2443 if self.tensorflow_saved_model_bundle is None 2444 else { 2445 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 2446 } 2447 ), 2448 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 2449 }
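Building on the pytorch_state_dict entry sketched earlier, a hedged example of combining it with a converted entry; ts_entry is assumed to be a valid TorchscriptWeightsDescr with parent="pytorch_state_dict":

from bioimageio.spec.model.v0_5 import WeightsDescr

weights = WeightsDescr(pytorch_state_dict=psd_entry, torchscript=ts_entry)

weights["torchscript"]     # __getitem__; raises KeyError for formats that are not present
weights.available_formats  # {"pytorch_state_dict": psd_entry, "torchscript": ts_entry}
weights.missing_formats    # e.g. {"keras_hdf5", "onnx", "tensorflow_js", ...}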
2462class LinkedModel(LinkedResourceBase): 2463 """Reference to a bioimage.io model.""" 2464 2465 id: ModelId 2466 """A valid model `id` from the bioimage.io collection."""
Reference to a bioimage.io model.
2488class ReproducibilityTolerance(Node, extra="allow"): 2489 """Describes what small numerical differences -- if any -- may be tolerated 2490 in the generated output when executing in different environments. 2491 2492 A tensor element *output* is considered mismatched to the **test_tensor** if 2493 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 2494 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 2495 2496 Motivation: 2497 For testing we can request the respective deep learning frameworks to be as 2498 reproducible as possible by setting seeds and chosing deterministic algorithms, 2499 but differences in operating systems, available hardware and installed drivers 2500 may still lead to numerical differences. 2501 """ 2502 2503 relative_tolerance: RelativeTolerance = 1e-3 2504 """Maximum relative tolerance of reproduced test tensor.""" 2505 2506 absolute_tolerance: AbsoluteTolerance = 1e-4 2507 """Maximum absolute tolerance of reproduced test tensor.""" 2508 2509 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 2510 """Maximum number of mismatched elements/pixels per million to tolerate.""" 2511 2512 output_ids: Sequence[TensorId] = () 2513 """Limits the output tensor IDs these reproducibility details apply to.""" 2514 2515 weights_formats: Sequence[WeightsFormat] = () 2516 """Limits the weights formats these details apply to."""
Describes what small numerical differences -- if any -- may be tolerated in the generated output when executing in different environments.
A tensor element output is considered mismatched to the test_tensor if abs(output - test_tensor) > absolute_tolerance + relative_tolerance * abs(test_tensor). (Internally we call numpy.testing.assert_allclose.)
Motivation:
For testing we can request the respective deep learning frameworks to be as reproducible as possible by setting seeds and choosing deterministic algorithms, but differences in operating systems, available hardware and installed drivers may still lead to numerical differences.
Maximum relative tolerance of reproduced test tensor.
Maximum absolute tolerance of reproduced test tensor.
Maximum number of mismatched elements/pixels per million to tolerate.
Limits the output tensor IDs these reproducibility details apply to.
Limits the weights formats these details apply to.
2519class BioimageioConfig(Node, extra="allow"): 2520 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 2521 """Tolerances to allow when reproducing the model's test outputs 2522 from the model's test inputs. 2523 Only the first entry matching tensor id and weights format is considered. 2524 """
Tolerances to allow when reproducing the model's test outputs from the model's test inputs. Only the first entry matching tensor id and weights format is considered.
2527class Config(Node, extra="allow"): 2528 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
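A hedged sketch of widening the tolerated mismatch for a single output tensor when testing with torchscript weights; the tensor id "mask" is an assumption:

from bioimageio.spec.model.v0_5 import (
    BioimageioConfig,
    Config,
    ReproducibilityTolerance,
    TensorId,
)

config = Config(
    bioimageio=BioimageioConfig(
        reproducibility_tolerance=[
            ReproducibilityTolerance(
                absolute_tolerance=1e-4,
                mismatched_elements_per_million=200,
                output_ids=[TensorId("mask")],      # only applies to this output tensor
                weights_formats=["torchscript"],    # only applies to this weights format
            )
        ]
    )
)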
2531class ModelDescr(GenericModelDescrBase): 2532 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 2533 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 2534 """ 2535 2536 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 2537 if TYPE_CHECKING: 2538 format_version: Literal["0.5.4"] = "0.5.4" 2539 else: 2540 format_version: Literal["0.5.4"] 2541 """Version of the bioimage.io model description specification used. 2542 When creating a new model always use the latest micro/patch version described here. 2543 The `format_version` is important for any consumer software to understand how to parse the fields. 2544 """ 2545 2546 implemented_type: ClassVar[Literal["model"]] = "model" 2547 if TYPE_CHECKING: 2548 type: Literal["model"] = "model" 2549 else: 2550 type: Literal["model"] 2551 """Specialized resource type 'model'""" 2552 2553 id: Optional[ModelId] = None 2554 """bioimage.io-wide unique resource identifier 2555 assigned by bioimage.io; version **un**specific.""" 2556 2557 authors: NotEmpty[List[Author]] 2558 """The authors are the creators of the model RDF and the primary points of contact.""" 2559 2560 documentation: FileSource_documentation 2561 """URL or relative path to a markdown file with additional documentation. 2562 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 2563 The documentation should include a '#[#] Validation' (sub)section 2564 with details on how to quantitatively validate the model on unseen data.""" 2565 2566 @field_validator("documentation", mode="after") 2567 @classmethod 2568 def _validate_documentation( 2569 cls, value: FileSource_documentation 2570 ) -> FileSource_documentation: 2571 if not get_validation_context().perform_io_checks: 2572 return value 2573 2574 doc_reader = get_reader(value) 2575 doc_content = doc_reader.read().decode(encoding="utf-8") 2576 if not re.search("#.*[vV]alidation", doc_content): 2577 issue_warning( 2578 "No '# Validation' (sub)section found in {value}.", 2579 value=value, 2580 field="documentation", 2581 ) 2582 2583 return value 2584 2585 inputs: NotEmpty[Sequence[InputTensorDescr]] 2586 """Describes the input tensors expected by this model.""" 2587 2588 @field_validator("inputs", mode="after") 2589 @classmethod 2590 def _validate_input_axes( 2591 cls, inputs: Sequence[InputTensorDescr] 2592 ) -> Sequence[InputTensorDescr]: 2593 input_size_refs = cls._get_axes_with_independent_size(inputs) 2594 2595 for i, ipt in enumerate(inputs): 2596 valid_independent_refs: Dict[ 2597 Tuple[TensorId, AxisId], 2598 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2599 ] = { 2600 **{ 2601 (ipt.id, a.id): (ipt, a, a.size) 2602 for a in ipt.axes 2603 if not isinstance(a, BatchAxis) 2604 and isinstance(a.size, (int, ParameterizedSize)) 2605 }, 2606 **input_size_refs, 2607 } 2608 for a, ax in enumerate(ipt.axes): 2609 cls._validate_axis( 2610 "inputs", 2611 i=i, 2612 tensor_id=ipt.id, 2613 a=a, 2614 axis=ax, 2615 valid_independent_refs=valid_independent_refs, 2616 ) 2617 return inputs 2618 2619 @staticmethod 2620 def _validate_axis( 2621 field_name: str, 2622 i: int, 2623 tensor_id: TensorId, 2624 a: int, 2625 axis: AnyAxis, 2626 valid_independent_refs: Dict[ 2627 Tuple[TensorId, AxisId], 2628 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2629 ], 2630 ): 2631 if isinstance(axis, BatchAxis) or isinstance( 2632 axis.size, (int, ParameterizedSize, 
DataDependentSize) 2633 ): 2634 return 2635 elif not isinstance(axis.size, SizeReference): 2636 assert_never(axis.size) 2637 2638 # validate axis.size SizeReference 2639 ref = (axis.size.tensor_id, axis.size.axis_id) 2640 if ref not in valid_independent_refs: 2641 raise ValueError( 2642 "Invalid tensor axis reference at" 2643 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 2644 ) 2645 if ref == (tensor_id, axis.id): 2646 raise ValueError( 2647 "Self-referencing not allowed for" 2648 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 2649 ) 2650 if axis.type == "channel": 2651 if valid_independent_refs[ref][1].type != "channel": 2652 raise ValueError( 2653 "A channel axis' size may only reference another fixed size" 2654 + " channel axis." 2655 ) 2656 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 2657 ref_size = valid_independent_refs[ref][2] 2658 assert isinstance(ref_size, int), ( 2659 "channel axis ref (another channel axis) has to specify fixed" 2660 + " size" 2661 ) 2662 generated_channel_names = [ 2663 Identifier(axis.channel_names.format(i=i)) 2664 for i in range(1, ref_size + 1) 2665 ] 2666 axis.channel_names = generated_channel_names 2667 2668 if (ax_unit := getattr(axis, "unit", None)) != ( 2669 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 2670 ): 2671 raise ValueError( 2672 "The units of an axis and its reference axis need to match, but" 2673 + f" '{ax_unit}' != '{ref_unit}'." 2674 ) 2675 ref_axis = valid_independent_refs[ref][1] 2676 if isinstance(ref_axis, BatchAxis): 2677 raise ValueError( 2678 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 2679 + " (a batch axis is not allowed as reference)." 2680 ) 2681 2682 if isinstance(axis, WithHalo): 2683 min_size = axis.size.get_size(axis, ref_axis, n=0) 2684 if (min_size - 2 * axis.halo) < 1: 2685 raise ValueError( 2686 f"axis {axis.id} with minimum size {min_size} is too small for halo" 2687 + f" {axis.halo}." 2688 ) 2689 2690 input_halo = axis.halo * axis.scale / ref_axis.scale 2691 if input_halo != int(input_halo) or input_halo % 2 == 1: 2692 raise ValueError( 2693 f"input_halo {input_halo} (output_halo {axis.halo} *" 2694 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 2695 + f" {tensor_id}.{axis.id}." 
2696 ) 2697 2698 @model_validator(mode="after") 2699 def _validate_test_tensors(self) -> Self: 2700 if not get_validation_context().perform_io_checks: 2701 return self 2702 2703 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs] 2704 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs] 2705 2706 tensors = { 2707 descr.id: (descr, array) 2708 for descr, array in zip( 2709 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 2710 ) 2711 } 2712 validate_tensors(tensors, tensor_origin="test_tensor") 2713 2714 output_arrays = { 2715 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 2716 } 2717 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 2718 if not rep_tol.absolute_tolerance: 2719 continue 2720 2721 if rep_tol.output_ids: 2722 out_arrays = { 2723 oid: a 2724 for oid, a in output_arrays.items() 2725 if oid in rep_tol.output_ids 2726 } 2727 else: 2728 out_arrays = output_arrays 2729 2730 for out_id, array in out_arrays.items(): 2731 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 2732 raise ValueError( 2733 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 2734 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 2735 + f" (1% of the maximum value of the test tensor '{out_id}')" 2736 ) 2737 2738 return self 2739 2740 @model_validator(mode="after") 2741 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 2742 ipt_refs = {t.id for t in self.inputs} 2743 out_refs = {t.id for t in self.outputs} 2744 for ipt in self.inputs: 2745 for p in ipt.preprocessing: 2746 ref = p.kwargs.get("reference_tensor") 2747 if ref is None: 2748 continue 2749 if ref not in ipt_refs: 2750 raise ValueError( 2751 f"`reference_tensor` '{ref}' not found. Valid input tensor" 2752 + f" references are: {ipt_refs}." 2753 ) 2754 2755 for out in self.outputs: 2756 for p in out.postprocessing: 2757 ref = p.kwargs.get("reference_tensor") 2758 if ref is None: 2759 continue 2760 2761 if ref not in ipt_refs and ref not in out_refs: 2762 raise ValueError( 2763 f"`reference_tensor` '{ref}' not found. Valid tensor references" 2764 + f" are: {ipt_refs | out_refs}." 2765 ) 2766 2767 return self 2768 2769 # TODO: use validate funcs in validate_test_tensors 2770 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 2771 2772 name: Annotated[ 2773 Annotated[ 2774 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 2775 ], 2776 MinLen(5), 2777 MaxLen(128), 2778 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 2779 ] 2780 """A human-readable name of this model. 2781 It should be no longer than 64 characters 2782 and may only contain letter, number, underscore, minus, parentheses and spaces. 2783 We recommend to chose a name that refers to the model's task and image modality. 
2784 """ 2785 2786 outputs: NotEmpty[Sequence[OutputTensorDescr]] 2787 """Describes the output tensors.""" 2788 2789 @field_validator("outputs", mode="after") 2790 @classmethod 2791 def _validate_tensor_ids( 2792 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 2793 ) -> Sequence[OutputTensorDescr]: 2794 tensor_ids = [ 2795 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 2796 ] 2797 duplicate_tensor_ids: List[str] = [] 2798 seen: Set[str] = set() 2799 for t in tensor_ids: 2800 if t in seen: 2801 duplicate_tensor_ids.append(t) 2802 2803 seen.add(t) 2804 2805 if duplicate_tensor_ids: 2806 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 2807 2808 return outputs 2809 2810 @staticmethod 2811 def _get_axes_with_parameterized_size( 2812 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2813 ): 2814 return { 2815 f"{t.id}.{a.id}": (t, a, a.size) 2816 for t in io 2817 for a in t.axes 2818 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 2819 } 2820 2821 @staticmethod 2822 def _get_axes_with_independent_size( 2823 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 2824 ): 2825 return { 2826 (t.id, a.id): (t, a, a.size) 2827 for t in io 2828 for a in t.axes 2829 if not isinstance(a, BatchAxis) 2830 and isinstance(a.size, (int, ParameterizedSize)) 2831 } 2832 2833 @field_validator("outputs", mode="after") 2834 @classmethod 2835 def _validate_output_axes( 2836 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 2837 ) -> List[OutputTensorDescr]: 2838 input_size_refs = cls._get_axes_with_independent_size( 2839 info.data.get("inputs", []) 2840 ) 2841 output_size_refs = cls._get_axes_with_independent_size(outputs) 2842 2843 for i, out in enumerate(outputs): 2844 valid_independent_refs: Dict[ 2845 Tuple[TensorId, AxisId], 2846 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 2847 ] = { 2848 **{ 2849 (out.id, a.id): (out, a, a.size) 2850 for a in out.axes 2851 if not isinstance(a, BatchAxis) 2852 and isinstance(a.size, (int, ParameterizedSize)) 2853 }, 2854 **input_size_refs, 2855 **output_size_refs, 2856 } 2857 for a, ax in enumerate(out.axes): 2858 cls._validate_axis( 2859 "outputs", 2860 i, 2861 out.id, 2862 a, 2863 ax, 2864 valid_independent_refs=valid_independent_refs, 2865 ) 2866 2867 return outputs 2868 2869 packaged_by: List[Author] = Field( 2870 default_factory=cast(Callable[[], List[Author]], list) 2871 ) 2872 """The persons that have packaged and uploaded this model. 2873 Only required if those persons differ from the `authors`.""" 2874 2875 parent: Optional[LinkedModel] = None 2876 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 2877 2878 @model_validator(mode="after") 2879 def _validate_parent_is_not_self(self) -> Self: 2880 if self.parent is not None and self.parent.id == self.id: 2881 raise ValueError("A model description may not reference itself as parent.") 2882 2883 return self 2884 2885 run_mode: Annotated[ 2886 Optional[RunMode], 2887 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 2888 ] = None 2889 """Custom run mode for this model: for more complex prediction procedures like test time 2890 data augmentation that currently cannot be expressed in the specification. 
2891 No standard run modes are defined yet.""" 2892 2893 timestamp: Datetime = Field(default_factory=Datetime.now) 2894 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 2895 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 2896 (In Python a datetime object is valid, too).""" 2897 2898 training_data: Annotated[ 2899 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 2900 Field(union_mode="left_to_right"), 2901 ] = None 2902 """The dataset used to train this model""" 2903 2904 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 2905 """The weights for this model. 2906 Weights can be given for different formats, but should otherwise be equivalent. 2907 The available weight formats determine which consumers can use this model.""" 2908 2909 config: Config = Field(default_factory=Config) 2910 2911 @model_validator(mode="after") 2912 def _add_default_cover(self) -> Self: 2913 if not get_validation_context().perform_io_checks or self.covers: 2914 return self 2915 2916 try: 2917 generated_covers = generate_covers( 2918 [(t, load_array(t.test_tensor)) for t in self.inputs], 2919 [(t, load_array(t.test_tensor)) for t in self.outputs], 2920 ) 2921 except Exception as e: 2922 issue_warning( 2923 "Failed to generate cover image(s): {e}", 2924 value=self.covers, 2925 msg_context=dict(e=e), 2926 field="covers", 2927 ) 2928 else: 2929 self.covers.extend(generated_covers) 2930 2931 return self 2932 2933 def get_input_test_arrays(self) -> List[NDArray[Any]]: 2934 data = [load_array(ipt.test_tensor) for ipt in self.inputs] 2935 assert all(isinstance(d, np.ndarray) for d in data) 2936 return data 2937 2938 def get_output_test_arrays(self) -> List[NDArray[Any]]: 2939 data = [load_array(out.test_tensor) for out in self.outputs] 2940 assert all(isinstance(d, np.ndarray) for d in data) 2941 return data 2942 2943 @staticmethod 2944 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2945 batch_size = 1 2946 tensor_with_batchsize: Optional[TensorId] = None 2947 for tid in tensor_sizes: 2948 for aid, s in tensor_sizes[tid].items(): 2949 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2950 continue 2951 2952 if batch_size != 1: 2953 assert tensor_with_batchsize is not None 2954 raise ValueError( 2955 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2956 ) 2957 2958 batch_size = s 2959 tensor_with_batchsize = tid 2960 2961 return batch_size 2962 2963 def get_output_tensor_sizes( 2964 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2965 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2966 """Returns the tensor output sizes for given **input_sizes**. 2967 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 
2968 Otherwise it might be larger than the actual (valid) output""" 2969 batch_size = self.get_batch_size(input_sizes) 2970 ns = self.get_ns(input_sizes) 2971 2972 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2973 return tensor_sizes.outputs 2974 2975 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2976 """get parameter `n` for each parameterized axis 2977 such that the valid input size is >= the given input size""" 2978 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2979 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2980 for tid in input_sizes: 2981 for aid, s in input_sizes[tid].items(): 2982 size_descr = axes[tid][aid].size 2983 if isinstance(size_descr, ParameterizedSize): 2984 ret[(tid, aid)] = size_descr.get_n(s) 2985 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2986 pass 2987 else: 2988 assert_never(size_descr) 2989 2990 return ret 2991 2992 def get_tensor_sizes( 2993 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2994 ) -> _TensorSizes: 2995 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2996 return _TensorSizes( 2997 { 2998 t: { 2999 aa: axis_sizes.inputs[(tt, aa)] 3000 for tt, aa in axis_sizes.inputs 3001 if tt == t 3002 } 3003 for t in {tt for tt, _ in axis_sizes.inputs} 3004 }, 3005 { 3006 t: { 3007 aa: axis_sizes.outputs[(tt, aa)] 3008 for tt, aa in axis_sizes.outputs 3009 if tt == t 3010 } 3011 for t in {tt for tt, _ in axis_sizes.outputs} 3012 }, 3013 ) 3014 3015 def get_axis_sizes( 3016 self, 3017 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3018 batch_size: Optional[int] = None, 3019 *, 3020 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3021 ) -> _AxisSizes: 3022 """Determine input and output block shape for scale factors **ns** 3023 of parameterized input sizes. 3024 3025 Args: 3026 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3027 that is parameterized as `size = min + n * step`. 3028 batch_size: The desired size of the batch dimension. 3029 If given **batch_size** overwrites any batch size present in 3030 **max_input_shape**. Default 1. 3031 max_input_shape: Limits the derived block shapes. 3032 Each axis for which the input size, parameterized by `n`, is larger 3033 than **max_input_shape** is set to the minimal value `n_min` for which 3034 this is still true. 3035 Use this for small input samples or large values of **ns**. 3036 Or simply whenever you know the full input shape. 3037 3038 Returns: 3039 Resolved axis sizes for model inputs and outputs. 
3040 """ 3041 max_input_shape = max_input_shape or {} 3042 if batch_size is None: 3043 for (_t_id, a_id), s in max_input_shape.items(): 3044 if a_id == BATCH_AXIS_ID: 3045 batch_size = s 3046 break 3047 else: 3048 batch_size = 1 3049 3050 all_axes = { 3051 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3052 } 3053 3054 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3055 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3056 3057 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3058 if isinstance(a, BatchAxis): 3059 if (t_descr.id, a.id) in ns: 3060 logger.warning( 3061 "Ignoring unexpected size increment factor (n) for batch axis" 3062 + " of tensor '{}'.", 3063 t_descr.id, 3064 ) 3065 return batch_size 3066 elif isinstance(a.size, int): 3067 if (t_descr.id, a.id) in ns: 3068 logger.warning( 3069 "Ignoring unexpected size increment factor (n) for fixed size" 3070 + " axis '{}' of tensor '{}'.", 3071 a.id, 3072 t_descr.id, 3073 ) 3074 return a.size 3075 elif isinstance(a.size, ParameterizedSize): 3076 if (t_descr.id, a.id) not in ns: 3077 raise ValueError( 3078 "Size increment factor (n) missing for parametrized axis" 3079 + f" '{a.id}' of tensor '{t_descr.id}'." 3080 ) 3081 n = ns[(t_descr.id, a.id)] 3082 s_max = max_input_shape.get((t_descr.id, a.id)) 3083 if s_max is not None: 3084 n = min(n, a.size.get_n(s_max)) 3085 3086 return a.size.get_size(n) 3087 3088 elif isinstance(a.size, SizeReference): 3089 if (t_descr.id, a.id) in ns: 3090 logger.warning( 3091 "Ignoring unexpected size increment factor (n) for axis '{}'" 3092 + " of tensor '{}' with size reference.", 3093 a.id, 3094 t_descr.id, 3095 ) 3096 assert not isinstance(a, BatchAxis) 3097 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3098 assert not isinstance(ref_axis, BatchAxis) 3099 ref_key = (a.size.tensor_id, a.size.axis_id) 3100 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3101 assert ref_size is not None, ref_key 3102 assert not isinstance(ref_size, _DataDepSize), ref_key 3103 return a.size.get_size( 3104 axis=a, 3105 ref_axis=ref_axis, 3106 ref_size=ref_size, 3107 ) 3108 elif isinstance(a.size, DataDependentSize): 3109 if (t_descr.id, a.id) in ns: 3110 logger.warning( 3111 "Ignoring unexpected increment factor (n) for data dependent" 3112 + " size axis '{}' of tensor '{}'.", 3113 a.id, 3114 t_descr.id, 3115 ) 3116 return _DataDepSize(a.size.min, a.size.max) 3117 else: 3118 assert_never(a.size) 3119 3120 # first resolve all , but the `SizeReference` input sizes 3121 for t_descr in self.inputs: 3122 for a in t_descr.axes: 3123 if not isinstance(a.size, SizeReference): 3124 s = get_axis_size(a) 3125 assert not isinstance(s, _DataDepSize) 3126 inputs[t_descr.id, a.id] = s 3127 3128 # resolve all other input axis sizes 3129 for t_descr in self.inputs: 3130 for a in t_descr.axes: 3131 if isinstance(a.size, SizeReference): 3132 s = get_axis_size(a) 3133 assert not isinstance(s, _DataDepSize) 3134 inputs[t_descr.id, a.id] = s 3135 3136 # resolve all output axis sizes 3137 for t_descr in self.outputs: 3138 for a in t_descr.axes: 3139 assert not isinstance(a.size, ParameterizedSize) 3140 s = get_axis_size(a) 3141 outputs[t_descr.id, a.id] = s 3142 3143 return _AxisSizes(inputs=inputs, outputs=outputs) 3144 3145 @model_validator(mode="before") 3146 @classmethod 3147 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 3148 cls.convert_from_old_format_wo_validation(data) 3149 return data 3150 3151 @classmethod 3152 def convert_from_old_format_wo_validation(cls, 
data: Dict[str, Any]) -> None: 3153 """Convert metadata following an older format version to this classes' format 3154 without validating the result. 3155 """ 3156 if ( 3157 data.get("type") == "model" 3158 and isinstance(fv := data.get("format_version"), str) 3159 and fv.count(".") == 2 3160 ): 3161 fv_parts = fv.split(".") 3162 if any(not p.isdigit() for p in fv_parts): 3163 return 3164 3165 fv_tuple = tuple(map(int, fv_parts)) 3166 3167 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3168 if fv_tuple[:2] in ((0, 3), (0, 4)): 3169 m04 = _ModelDescr_v0_4.load(data) 3170 if isinstance(m04, InvalidDescr): 3171 try: 3172 updated = _model_conv.convert_as_dict( 3173 m04 # pyright: ignore[reportArgumentType] 3174 ) 3175 except Exception as e: 3176 logger.error( 3177 "Failed to convert from invalid model 0.4 description." 3178 + f"\nerror: {e}" 3179 + "\nProceeding with model 0.5 validation without conversion." 3180 ) 3181 updated = None 3182 else: 3183 updated = _model_conv.convert_as_dict(m04) 3184 3185 if updated is not None: 3186 data.clear() 3187 data.update(updated) 3188 3189 elif fv_tuple[:2] == (0, 5): 3190 # bump patch version 3191 data["format_version"] = cls.implemented_format_version
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
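A ModelDescr is usually obtained by loading an existing resource description rather than being constructed field by field. A hedged sketch, assuming the top-level load_model_description helper provided by current bioimageio.spec releases:

from bioimageio.spec import load_model_description  # assumption: available in your version

model = load_model_description("path/to/rdf.yaml")   # hypothetical local model RDF
print(model.format_version)
print([t.id for t in model.inputs], [t.id for t in model.outputs])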
2943 @staticmethod 2944 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 2945 batch_size = 1 2946 tensor_with_batchsize: Optional[TensorId] = None 2947 for tid in tensor_sizes: 2948 for aid, s in tensor_sizes[tid].items(): 2949 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 2950 continue 2951 2952 if batch_size != 1: 2953 assert tensor_with_batchsize is not None 2954 raise ValueError( 2955 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 2956 ) 2957 2958 batch_size = s 2959 tensor_with_batchsize = tid 2960 2961 return batch_size
2963 def get_output_tensor_sizes( 2964 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 2965 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 2966 """Returns the tensor output sizes for given **input_sizes**. 2967 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 2968 Otherwise it might be larger than the actual (valid) output""" 2969 batch_size = self.get_batch_size(input_sizes) 2970 ns = self.get_ns(input_sizes) 2971 2972 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 2973 return tensor_sizes.outputs
Returns the tensor output sizes for given input_sizes. Only if input_sizes describes a valid input shape is the tensor output size exact; otherwise it might be larger than the actual (valid) output.
2975 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 2976 """get parameter `n` for each parameterized axis 2977 such that the valid input size is >= the given input size""" 2978 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 2979 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 2980 for tid in input_sizes: 2981 for aid, s in input_sizes[tid].items(): 2982 size_descr = axes[tid][aid].size 2983 if isinstance(size_descr, ParameterizedSize): 2984 ret[(tid, aid)] = size_descr.get_n(s) 2985 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 2986 pass 2987 else: 2988 assert_never(size_descr) 2989 2990 return ret
Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
2992 def get_tensor_sizes( 2993 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 2994 ) -> _TensorSizes: 2995 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 2996 return _TensorSizes( 2997 { 2998 t: { 2999 aa: axis_sizes.inputs[(tt, aa)] 3000 for tt, aa in axis_sizes.inputs 3001 if tt == t 3002 } 3003 for t in {tt for tt, _ in axis_sizes.inputs} 3004 }, 3005 { 3006 t: { 3007 aa: axis_sizes.outputs[(tt, aa)] 3008 for tt, aa in axis_sizes.outputs 3009 if tt == t 3010 } 3011 for t in {tt for tt, _ in axis_sizes.outputs} 3012 }, 3013 )
3015 def get_axis_sizes( 3016 self, 3017 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 3018 batch_size: Optional[int] = None, 3019 *, 3020 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 3021 ) -> _AxisSizes: 3022 """Determine input and output block shape for scale factors **ns** 3023 of parameterized input sizes. 3024 3025 Args: 3026 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 3027 that is parameterized as `size = min + n * step`. 3028 batch_size: The desired size of the batch dimension. 3029 If given **batch_size** overwrites any batch size present in 3030 **max_input_shape**. Default 1. 3031 max_input_shape: Limits the derived block shapes. 3032 Each axis for which the input size, parameterized by `n`, is larger 3033 than **max_input_shape** is set to the minimal value `n_min` for which 3034 this is still true. 3035 Use this for small input samples or large values of **ns**. 3036 Or simply whenever you know the full input shape. 3037 3038 Returns: 3039 Resolved axis sizes for model inputs and outputs. 3040 """ 3041 max_input_shape = max_input_shape or {} 3042 if batch_size is None: 3043 for (_t_id, a_id), s in max_input_shape.items(): 3044 if a_id == BATCH_AXIS_ID: 3045 batch_size = s 3046 break 3047 else: 3048 batch_size = 1 3049 3050 all_axes = { 3051 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 3052 } 3053 3054 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 3055 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 3056 3057 def get_axis_size(a: Union[InputAxis, OutputAxis]): 3058 if isinstance(a, BatchAxis): 3059 if (t_descr.id, a.id) in ns: 3060 logger.warning( 3061 "Ignoring unexpected size increment factor (n) for batch axis" 3062 + " of tensor '{}'.", 3063 t_descr.id, 3064 ) 3065 return batch_size 3066 elif isinstance(a.size, int): 3067 if (t_descr.id, a.id) in ns: 3068 logger.warning( 3069 "Ignoring unexpected size increment factor (n) for fixed size" 3070 + " axis '{}' of tensor '{}'.", 3071 a.id, 3072 t_descr.id, 3073 ) 3074 return a.size 3075 elif isinstance(a.size, ParameterizedSize): 3076 if (t_descr.id, a.id) not in ns: 3077 raise ValueError( 3078 "Size increment factor (n) missing for parametrized axis" 3079 + f" '{a.id}' of tensor '{t_descr.id}'." 
3080 ) 3081 n = ns[(t_descr.id, a.id)] 3082 s_max = max_input_shape.get((t_descr.id, a.id)) 3083 if s_max is not None: 3084 n = min(n, a.size.get_n(s_max)) 3085 3086 return a.size.get_size(n) 3087 3088 elif isinstance(a.size, SizeReference): 3089 if (t_descr.id, a.id) in ns: 3090 logger.warning( 3091 "Ignoring unexpected size increment factor (n) for axis '{}'" 3092 + " of tensor '{}' with size reference.", 3093 a.id, 3094 t_descr.id, 3095 ) 3096 assert not isinstance(a, BatchAxis) 3097 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 3098 assert not isinstance(ref_axis, BatchAxis) 3099 ref_key = (a.size.tensor_id, a.size.axis_id) 3100 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 3101 assert ref_size is not None, ref_key 3102 assert not isinstance(ref_size, _DataDepSize), ref_key 3103 return a.size.get_size( 3104 axis=a, 3105 ref_axis=ref_axis, 3106 ref_size=ref_size, 3107 ) 3108 elif isinstance(a.size, DataDependentSize): 3109 if (t_descr.id, a.id) in ns: 3110 logger.warning( 3111 "Ignoring unexpected increment factor (n) for data dependent" 3112 + " size axis '{}' of tensor '{}'.", 3113 a.id, 3114 t_descr.id, 3115 ) 3116 return _DataDepSize(a.size.min, a.size.max) 3117 else: 3118 assert_never(a.size) 3119 3120 # first resolve all , but the `SizeReference` input sizes 3121 for t_descr in self.inputs: 3122 for a in t_descr.axes: 3123 if not isinstance(a.size, SizeReference): 3124 s = get_axis_size(a) 3125 assert not isinstance(s, _DataDepSize) 3126 inputs[t_descr.id, a.id] = s 3127 3128 # resolve all other input axis sizes 3129 for t_descr in self.inputs: 3130 for a in t_descr.axes: 3131 if isinstance(a.size, SizeReference): 3132 s = get_axis_size(a) 3133 assert not isinstance(s, _DataDepSize) 3134 inputs[t_descr.id, a.id] = s 3135 3136 # resolve all output axis sizes 3137 for t_descr in self.outputs: 3138 for a in t_descr.axes: 3139 assert not isinstance(a.size, ParameterizedSize) 3140 s = get_axis_size(a) 3141 outputs[t_descr.id, a.id] = s 3142 3143 return _AxisSizes(inputs=inputs, outputs=outputs)
Determine input and output block shape for scale factors ns of parameterized input sizes.
Arguments:
- ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
- batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
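A hedged usage sketch resolving concrete axis sizes for a given input shape; model is assumed to be a loaded ModelDescr and the tensor and axis ids are made up:

from bioimageio.spec.model.v0_5 import AxisId, TensorId

input_sizes = {TensorId("raw"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}

ns = model.get_ns(input_sizes)                    # parameter n per parameterized input axis
sizes = model.get_tensor_sizes(ns, batch_size=1)  # resolved per-tensor axis sizes
print(sizes.inputs, sizes.outputs)

# shortcut for the output sizes only:
out_sizes = model.get_output_tensor_sizes(input_sizes)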
3151 @classmethod 3152 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 3153 """Convert metadata following an older format version to this classes' format 3154 without validating the result. 3155 """ 3156 if ( 3157 data.get("type") == "model" 3158 and isinstance(fv := data.get("format_version"), str) 3159 and fv.count(".") == 2 3160 ): 3161 fv_parts = fv.split(".") 3162 if any(not p.isdigit() for p in fv_parts): 3163 return 3164 3165 fv_tuple = tuple(map(int, fv_parts)) 3166 3167 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 3168 if fv_tuple[:2] in ((0, 3), (0, 4)): 3169 m04 = _ModelDescr_v0_4.load(data) 3170 if isinstance(m04, InvalidDescr): 3171 try: 3172 updated = _model_conv.convert_as_dict( 3173 m04 # pyright: ignore[reportArgumentType] 3174 ) 3175 except Exception as e: 3176 logger.error( 3177 "Failed to convert from invalid model 0.4 description." 3178 + f"\nerror: {e}" 3179 + "\nProceeding with model 0.5 validation without conversion." 3180 ) 3181 updated = None 3182 else: 3183 updated = _model_conv.convert_as_dict(m04) 3184 3185 if updated is not None: 3186 data.clear() 3187 data.update(updated) 3188 3189 elif fv_tuple[:2] == (0, 5): 3190 # bump patch version 3191 data["format_version"] = cls.implemented_format_version
Convert metadata following an older format version to this class's format without validating the result.
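A hedged sketch of upgrading an older model RDF dict in place before validating it against this class. The file path is hypothetical and PyYAML is assumed to be available; note that ModelDescr.model_validate triggers this conversion automatically via the _convert validator shown above:

import yaml  # assumption: PyYAML is installed for reading the RDF

from bioimageio.spec.model.v0_5 import ModelDescr

with open("rdf.yaml") as f:  # hypothetical path
    data = yaml.safe_load(f)

ModelDescr.convert_from_old_format_wo_validation(data)  # mutates `data` in place
model = ModelDescr.model_validate(data)                  # validate as format 0.5.4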
3416def generate_covers( 3417 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 3418 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 3419) -> List[Path]: 3420 def squeeze( 3421 data: NDArray[Any], axes: Sequence[AnyAxis] 3422 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 3423 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 3424 if data.ndim != len(axes): 3425 raise ValueError( 3426 f"tensor shape {data.shape} does not match described axes" 3427 + f" {[a.id for a in axes]}" 3428 ) 3429 3430 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 3431 return data.squeeze(), axes 3432 3433 def normalize( 3434 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 3435 ) -> NDArray[np.float32]: 3436 data = data.astype("float32") 3437 data -= data.min(axis=axis, keepdims=True) 3438 data /= data.max(axis=axis, keepdims=True) + eps 3439 return data 3440 3441 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 3442 original_shape = data.shape 3443 data, axes = squeeze(data, axes) 3444 3445 # take slice fom any batch or index axis if needed 3446 # and convert the first channel axis and take a slice from any additional channel axes 3447 slices: Tuple[slice, ...] = () 3448 ndim = data.ndim 3449 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 3450 has_c_axis = False 3451 for i, a in enumerate(axes): 3452 s = data.shape[i] 3453 assert s > 1 3454 if ( 3455 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 3456 and ndim > ndim_need 3457 ): 3458 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3459 ndim -= 1 3460 elif isinstance(a, ChannelAxis): 3461 if has_c_axis: 3462 # second channel axis 3463 data = data[slices + (slice(0, 1),)] 3464 ndim -= 1 3465 else: 3466 has_c_axis = True 3467 if s == 2: 3468 # visualize two channels with cyan and magenta 3469 data = np.concatenate( 3470 [ 3471 data[slices + (slice(1, 2),)], 3472 data[slices + (slice(0, 1),)], 3473 ( 3474 data[slices + (slice(0, 1),)] 3475 + data[slices + (slice(1, 2),)] 3476 ) 3477 / 2, # TODO: take maximum instead? 
3478 ], 3479 axis=i, 3480 ) 3481 elif data.shape[i] == 3: 3482 pass # visualize 3 channels as RGB 3483 else: 3484 # visualize first 3 channels as RGB 3485 data = data[slices + (slice(3),)] 3486 3487 assert data.shape[i] == 3 3488 3489 slices += (slice(None),) 3490 3491 data, axes = squeeze(data, axes) 3492 assert len(axes) == ndim 3493 # take slice from z axis if needed 3494 slices = () 3495 if ndim > ndim_need: 3496 for i, a in enumerate(axes): 3497 s = data.shape[i] 3498 if a.id == AxisId("z"): 3499 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3500 data, axes = squeeze(data, axes) 3501 ndim -= 1 3502 break 3503 3504 slices += (slice(None),) 3505 3506 # take slice from any space or time axis 3507 slices = () 3508 3509 for i, a in enumerate(axes): 3510 if ndim <= ndim_need: 3511 break 3512 3513 s = data.shape[i] 3514 assert s > 1 3515 if isinstance( 3516 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 3517 ): 3518 data = data[slices + (slice(s // 2 - 1, s // 2),)] 3519 ndim -= 1 3520 3521 slices += (slice(None),) 3522 3523 del slices 3524 data, axes = squeeze(data, axes) 3525 assert len(axes) == ndim 3526 3527 if (has_c_axis and ndim != 3) or ndim != 2: 3528 raise ValueError( 3529 f"Failed to construct cover image from shape {original_shape}" 3530 ) 3531 3532 if not has_c_axis: 3533 assert ndim == 2 3534 data = np.repeat(data[:, :, None], 3, axis=2) 3535 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 3536 ndim += 1 3537 3538 assert ndim == 3 3539 3540 # transpose axis order such that longest axis comes first... 3541 axis_order: List[int] = list(np.argsort(list(data.shape))) 3542 axis_order.reverse() 3543 # ... and channel axis is last 3544 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 3545 axis_order.append(axis_order.pop(c)) 3546 axes = [axes[ao] for ao in axis_order] 3547 data = data.transpose(axis_order) 3548 3549 # h, w = data.shape[:2] 3550 # if h / w in (1.0 or 2.0): 3551 # pass 3552 # elif h / w < 2: 3553 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 3554 3555 norm_along = ( 3556 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 3557 ) 3558 # normalize the data and map to 8 bit 3559 data = normalize(data, norm_along) 3560 data = (data * 255).astype("uint8") 3561 3562 return data 3563 3564 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 3565 assert im0.dtype == im1.dtype == np.uint8 3566 assert im0.shape == im1.shape 3567 assert im0.ndim == 3 3568 N, M, C = im0.shape 3569 assert C == 3 3570 out = np.ones((N, M, C), dtype="uint8") 3571 for c in range(C): 3572 outc = np.tril(im0[..., c]) 3573 mask = outc == 0 3574 outc[mask] = np.triu(im1[..., c])[mask] 3575 out[..., c] = outc 3576 3577 return out 3578 3579 ipt_descr, ipt = inputs[0] 3580 out_descr, out = outputs[0] 3581 3582 ipt_img = to_2d_image(ipt, ipt_descr.axes) 3583 out_img = to_2d_image(out, out_descr.axes) 3584 3585 cover_folder = Path(mkdtemp()) 3586 if ipt_img.shape == out_img.shape: 3587 covers = [cover_folder / "cover.png"] 3588 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 3589 else: 3590 covers = [cover_folder / "input.png", cover_folder / "output.png"] 3591 imwrite(covers[0], ipt_img) 3592 imwrite(covers[1], out_img) 3593 3594 return covers
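A hedged usage sketch mirroring what ModelDescr uses generate_covers for internally; model is assumed to be a ModelDescr whose test tensors can be resolved locally:

from bioimageio.spec.model.v0_5 import generate_covers
from bioimageio.spec._internal.io_utils import load_array  # internal helper, see module imports

covers = generate_covers(
    [(t, load_array(t.test_tensor)) for t in model.inputs],
    [(t, load_array(t.test_tensor)) for t in model.outputs],
)
print(covers)  # one or two PNG paths in a temporary folder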